RNN复杂度 cnn复杂度_卷积核


RNN复杂度 cnn复杂度_复杂度_02


  • Maximum path lengths:序列中两个元素进行交互所需经过的最大路径长度
  • per-layer complexity:每层的时间复杂度
  • minimum number of sequential operations:最少需要的序列操作数

计算效率

的矩阵,与另一个形状为

的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为

Self-Attention



  • : 与 运算,得到 矩阵,复杂度为
  • ,则n行的复杂度为
  • 与 运算,得到 矩阵,复杂度为

故最后self-attention的时间复杂度为



对于受限的self-attention,每个元素仅能和周围


个元素进行交互,即和



维向量做内积运算,复杂度为


,则


个元素的总时间复杂度为



Multi-Head Attention



对于multi-head attention,假设有


个head,这里


是一个常数,对于每个head,首先需要把三个矩阵分别映射到


维度。这里考虑一种简化情况:


。(对于dot-attention计算方式,



可以不同)。


  • 与 运算,忽略常系数,复杂度为 。
  • 与 运算,复杂度为
  • 输出线性映射的复杂度:concat操作拼起来形成 的矩阵,然后经过输出线性映射,保证输入输出相同,所以是 与 计算,复杂度为

故最后的复杂度为:



注意:多头的计算并不是通过循环完成的,而是通过 transposes and reshapes,用矩阵乘法来完成的。假设有


个head,则新的representation dimension:


。因为,我们将


的矩阵拆为


的张量,再利用转置操作转为


的张量。故


的计算为:



做计算,得到


的张量,复杂度为


,即


。注意,此处


实际是一个常数,故


复杂度为



Recurrent



  • : 与 运算,复杂度为 , 为input size
  • : 与 运算,复杂度为

故一次操作的时间复杂度为



次序列操作后的总时间复杂度为



Convolution

  • 为了保证输入和输出在第一个维度都相同,故需要对输入进行padding操作,因为这里kernel size为 ,(实际kernel的形状为 )如果不padding的话,那么输出的第一个维度为 ,因为这里stride是为1的。为了保证输入输出相同,则需要对序列的前后分别padding长度为 。
  • 的卷积核一次运算的复杂度为: ,一共做了 次,故复杂度为
  • 个卷积核,所以卷积操作总的时间复杂度为

序列操作数

  • 表明三种模型的并行程度:从计算方式上看,只有RNN才需要串行地完成 次序列操作,而self-attention和convolution的n次序列操作均可以并行完成。因为RNN还需要依赖于上一个时间步的隐藏层输出,而其他模型仅仅依赖于输入。

最大路径长度


RNN复杂度 cnn复杂度_复杂度_03


  • 的两个结点传递信息所经历的路径长度,表征了存在长距离依赖的结点在传递信息时,信息丢失的程度,长度越长,两个节点之间越难交互,信息丢失越严重
  • 的序列中,节点之间的最大路径长度为 ,即 。第一个token的信息需要经过 次迭代才能传到最后一个时间步的状态中,信息丢失严重,很难建立节点间的长距离依赖。
  • , 层数为 的CNN中,能看到最大的local context的大小为 ,最大路径长度为 ,例如图(b)中是一个两层的卷积核大小为3的CNN,顶层节点能看到的最大local context为2*2+1=5个token的输入。粗略来看,上图可以看作一个 叉树,深度为 的树,叶子节点个数为 ,解得最大路径长度为 ,即
  • Self-attention:任意两个结点都可以直接相连,即任意两个结点之间的距离为1,故最大路径长度为1
  • 受限的self-attention:类似于卷积核大小为 的CNN,最大路径长度为

参考资料

从 Transformer 说起tobiaslee.top

RNN复杂度 cnn复杂度_RNN复杂度_04

Why Self-Attention? A Targeted Evaluation of Neural Machine Translation Architecturesarxiv.org https://www.reddit.com/r/LanguageTechnology/comments/9gulm9/complexity_of_transformer_attention_network/www.reddit.com