最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系,
以下是要写的文章,本文是这个系列的第十二篇:
跟Longformer一样,Linformer也是为了解决Transformer中的Attention部分随着序列长度而有N^2复杂度的问题。
论文标题很exciting,但是实际做法却很简洁直接,就是在Attention计算的时候K和V部分加了一个线性映射映射到低维空间。当低维空间的大小是固定的时候,就达到了线性复杂度。
与简单直接的做法不同,论文中花了很大的篇幅去对映射到低维空间的做法做了证明。
在Wiki103和IMDB两个数据集上,在Roberta-large预训练好的模型上计算出Attention矩阵。然后做奇异值分解,然后从下图左两图中可以看到,通过奇异值的累积,可以看到,前128维的奇异值累计值已经占了到了0.9左右。
而在右图中可以看到,越高层,128个奇异值累积值就越高。在第11层,128个奇异值累积起来达到了0.96。
因而说明了,虽然Attention的计算结果是一个N x N的矩阵,但其实一个低秩矩阵比如N x 128可能就已经足够存储attention的所有信息。
首先回顾一下Attention的计算,如下图所示。Transformer中的Attention都是多头的,对于第i个头来说,计算如下。
注意,上面的公式表达跟我们在之前文章中写的略有不同,这里Q,K,V成了原始的embedding,WQ, WK和WV是转换矩阵。
因此,论文提出了一个定理,如下图所示。数字符号比较繁杂,我用汉语再翻译一下,就是对于任意的Q,K,V和WQ, WK和WV,存在一个低秩矩阵P,使得对于VWV中的任何一个列向量w,满足下面这个式子。更具体一点,就是用低秩矩阵对w做转换,其损失相对于用原始矩阵,被控制在一个可以接受的范围内,此时低秩矩阵的秩是log(n)。
证明我就不解释了,我们主要关注的是这个idea以及idea所产生的效果。对数学感兴趣可以直接去翻论文。
其实这里我有一个疑问,如果低秩矩阵的秩是logn,那么这个算法的复杂度应该是nlog(n)而不是线性?
有了这个方法之后,其实一个直接的手段就是使用SVD对矩阵做近似,这样复杂度就可以变成O(nk),k为采用的低秩矩阵的秩。
但是runtime这样做,还需要每次先对大矩阵做SVD,不划算。
根据上面所说,在inference的时候去做SVD更费事,所以需要在训练时做好。而做的方式就是在key和value上再各自加入一个线性变换。如下图所示:
上图中的右上部分还画出了不同的k,inference时间和序列长度的关系。可以看到,不管k是多少,Linformer的曲线都是平的。
公式如下,E是K上的转换,F是V上的转换。
针对上面的做法,论文又提出了定理二,对k的下界进行了理论上的限定。论证部分大家感兴趣可以去看原始论文。
上面模型部分添加了两个线性转换层。在这两个层上,其实还有很多技巧:
对MLM任务的训练结果如下,从a和b图可以看到,k越大效果越好,但它们和标准的transformer其实差别不大。
在下游任务上结果如下,也是和标准transformer类似的效果。
而在内存和速度上的提升,则在下图,左图是速度提升,右图是内存提升。可以看到,序列长度越长,k越小,提升越大。
这篇论文是一个观察法做优化的绝好案例,从对attention的SVD分解到映射层的添加水到渠成。但标题原因还是导致论文有些言过其实。主要是因为:
方案在长度比较长的时候才能显现为例,而在原始的bert上,长度512,此时如果k=128, 那么相当于内存占用量由512 * 512 变成128 * 512。
另一方面,Linformer在长度比较长的时候会更加有效,但论文却只做了性能和内存的比较,没有做在较长序列的情况下,Linformer在下游任务上的优势实验,虽然Roberta做不了baseline,但起码可以和Reformer,longformer比较。
证明部分有些奇怪,没有见到明确的线性的证明。
或许是我数学水平有限,k=5log(nd) / (ε^2 - ε^3) 我理解不是线性。(有理解不同的可以私信我)
对于序列较短的加速需求而言,还是MobileBert更靠谱一些。