Reading

Transformers are RNNs

摘掉Softmax

制约Attention性能的关键因素,其实是定义里边的Softmax!事实上,简单地推导一下就可以得到这个结论。 这一步我们得到一个 的矩阵,就是这一步决定了Attention的复杂度是 ;如果没有Softmax,那么就是三个矩阵连乘,而矩阵乘法是满足结合率的,所以我们可以先算 ,得到一个 的矩阵,然后再用 左乘它,由于,所以这样算大致的复杂度只是 (就是 左乘那一步占主导)。

也就是说,去掉Softmax的Attention的复杂度可以降到最理想的线性级别!这显然就是我们的终极追求:Linear Attention,复杂度为线性级别的Attention。所以,本文的主题就是探究摘掉Softmax后的线形Attention。

问题是,直接去掉Softmax还能算是Attention吗?它还能有标准的Attention的效果吗?

几个例子

根据 式3 如果直接去掉Softmax,那么就是 ,问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们简单介绍几种可取的方案。

值得指出的是,下面介绍的这几种Linear Attention,前两种来自CV领域,第三种是苏神自己构思的,所以都还没有在NLP任务上做过什么实验。

核函数形式

一个自然的想法是:如果 的每个元素都是非负的,那么内积自然也就是非负的。为了完成这点,我们可以给 各自加个激活函数 ,即

其中是值域非负的激活函数。在论文Transformers are RNNs中选择的是

非要讲故事的话,式4 可以联想到“核方法(kernal method)”,尤其是 就相当于一个核函数,而 就是通过核函数所定义的内积。这方面的思考可以参考论文原文,此处不做过多延伸。

妙用Softmax

另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 中,,如果“ 那一维是归一化的、并且 那一维是归一化的”,那么 就是自动满足归一化了,所以它给出的选择是:

其中 分别指在第一个()、第二个维度()进行Softmax运算。也就是说,这时候我们是各自给 加Softmax,而不是算完之后才加Softmax。

如果直接取,那么很显然这个形式也是式4 的一个特例。另外这个设计在CV中出现过不止一次,比如A^2-Nets也包含了同样的做法。

自己的构思

在这里,苏神给出自己的一种构思。这个构思的出发点不再是式4,而是源于我们对原始 定义2 的近似。由泰勒展开我们有

如果,那么就可以保证右端的非负性,从而可以让 。到这里读者可能已经想到了,想要保证,只需要分别对 归一化。所以,笔者最终提出的方案就是:

这不同于形式4,但理论上它更加接近原始的Scaled-Dot Attention。

Linformer

跟本文所介绍的Linear Attention很相似的一个工作是Facebook的Linformer,它依然保留原始的Scaled-Dot Attention形式,但在进行Attention之前,用两个 的矩阵分别对进行投影,即变为

这样一来, 就只是一个 的矩阵,而作者声称对于哪怕对于很大的序列长度n,m也可以保持为一个适中的常数,从而这种Attention也是线性的。

但是,笔者认为“对于超长序列m可以保持不变”这个结论是值得质疑的,对于长序列原论文只做了MLM任务,而很明显MLM并不那么需要长程依赖,所以这个实验没什么说服力。因此,Linformer是不是真的Linear,还有待商榷。

自回归生成

Linformer的另一个缺点是 这两个运算直接把整个序列的信息给“糅合”起来了,所以它没法简单地把将来信息给Mask掉(Causal Masking),从而无法做语言模型、Seq2Seq等自回归生成任务,这也是刚才说的原作者只做了MLM任务的原因。相比之下,前面介绍的几种Linear Attention都能做到这一点。以 式3式4 为例,如果要Mask掉未来信息,那么只需要把求和 改为

实现上式有两种方式:第一方式是设以及,我们有

这说明这种Attention可以作为一个RNN模型用递归的方式实现,它的空间复杂度最低,但是要串行计算,适合预测解码时使用;

第二种是直接将 做外积,得到一个 的矩阵,然后对 n 那一维执行 运算,这样就一次性得到 了,它的速度最快,但空间占用最大,适合训练时使用,不过很多时候都有,一般情况下训练时都很难承受这个空间复杂度,因此多数还是用RNN形式。

但是如果用这种最基本的循环来训练有问题。

如果使用autograd,那么,每一个时刻的hidden state 都会被保存下来给反向传播算梯度。保存中间hidden state在之前的RNN训练里面没啥问题,但是在这里却是一个大问题。为什么呢?因为传统RNN 每个时刻的hidden state的大小是 ,而linear attention每个时刻的hidden state大小是 ,如果全部存下来需要 ,那是万万不可接受的(你可以想象,输入的QKV才, 一般 , 那么你基本就显存爆炸寄了不用玩了)

Transformers are RNNs 里面提出了一个memory-efficient 训练的解决方案。

在这个方案里, 在前向传播的过程中是不会被保存的。 由于只有 的梯度依赖于 。那我们索性在backward的时候重新算一遍 ,但这时的output算的不是 ,而是

的梯度不依赖于 ,可以直接通过平常的BPTT来算出来。

image

小结

上面介绍了一些从结构上对Attention进行修改从而降低其计算复杂度的工作,其中最主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。

Reference

🔖 https://spaces.ac.cn/archives/7546