Transformers are RNNs

Mar 30, 2025
1 views
NLP

摘掉Softmax

制约Attention性能的关键因素,其实是定义里边的Softmax!事实上,简单地推导一下就可以得到这个结论。\(\boldsymbol{Q}\boldsymbol{K}^{\top}\) 这一步我们得到一个 \(n\times n\) 的矩阵,就是这一步决定了Attention的复杂度是 \(\mathcal{O}(n^2)\);如果没有Softmax,那么就是三个矩阵连乘\(\boldsymbol{Q}\boldsymbol{K}^{\top}\boldsymbol{V}\),而矩阵乘法是满足结合率的,所以我们可以先算 \(\boldsymbol{K}^{\top}\boldsymbol{V}\),得到一个 \(d\times d\) 的矩阵,然后再用 \(\boldsymbol{Q}\) 左乘它,由于\(d \ll n\),所以这样算大致的复杂度只是 \(\mathcal{O}(n)\)(就是 \(\boldsymbol{Q}\) 左乘那一步占主导)。

也就是说,去掉Softmax的Attention的复杂度可以降到最理想的线性级别\(\mathcal{O}(n)\)!这显然就是我们的终极追求:Linear Attention,复杂度为线性级别的Attention。所以,本文的主题就是探究摘掉Softmax后的线形Attention。

问题是,直接去掉Softmax还能算是Attention吗?它还能有标准的Attention的效果吗?

几个例子

根据 式3 如果直接去掉Softmax,那么就是 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\),问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们简单介绍几种可取的方案。

值得指出的是,下面介绍的这几种Linear Attention,前两种来自CV领域,第三种是苏神自己构思的,所以都还没有在NLP任务上做过什么实验。

核函数形式

一个自然的想法是:如果 \(\boldsymbol{q}_i,\boldsymbol{k}_j\) 的每个元素都是非负的,那么内积自然也就是非负的。为了完成这点,我们可以给 \(\boldsymbol{q}_i,\boldsymbol{k}_j\) 各自加个激活函数 \(\phi,\varphi\),即

\[ \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\end{equation} \]

其中\(\phi(\cdot),\varphi(\cdot)\)是值域非负的激活函数。在论文Transformers are RNNs中选择的是\(\phi(x)=\varphi(x)=\text{elu}(x)+1\)

非要讲故事的话,式4 可以联想到“核方法(kernal method)”,尤其是 \(\phi=\varphi\)\(\phi\) 就相当于一个核函数,而 \(\langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle\) 就是通过核函数所定义的内积。这方面的思考可以参考论文原文,此处不做过多延伸。

妙用Softmax

另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在 \(\boldsymbol{Q}\boldsymbol{K}^{\top}\) 中,\(\boldsymbol{Q}, \boldsymbol{K}, \in\mathbb{R}^{n\times d}\),如果“\(\boldsymbol{Q}\)\(d\) 那一维是归一化的、并且\(\boldsymbol{K}\)\(n\) 那一维是归一化的”,那么 \(\boldsymbol{Q}\boldsymbol{K}^{\top}\) 就是自动满足归一化了,所以它给出的选择是:

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\end{equation} \]

其中\(softmax_1\)\(softmax_2\) 分别指在第一个(\(n\))、第二个维度(\(d\))进行Softmax运算。也就是说,这时候我们是各自给\(\boldsymbol{Q},\boldsymbol{K}\) 加Softmax,而不是\(\boldsymbol{Q}\boldsymbol{K}^{\top}\)算完之后才加Softmax。

如果直接取\(\phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i),\varphi(\boldsymbol{k}_j)=softmax(\boldsymbol{k}_j)\),那么很显然这个形式也是式4 的一个特例。另外这个设计在CV中出现过不止一次,比如A^2-Nets也包含了同样的做法。

自己的构思

在这里,苏神给出自己的一种构思。这个构思的出发点不再是式4,而是源于我们对原始 定义2 的近似。由泰勒展开我们有

\[ \begin{equation}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\end{equation} \]

如果\(\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1\),那么就可以保证右端的非负性,从而可以让 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\)。到这里读者可能已经想到了,想要保证\(\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1\),只需要分别对 \(\boldsymbol{q}_i,\boldsymbol{k}_j\)\(l_2\) 归一化。所以,笔者最终提出的方案就是:

\[ \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\end{equation} \]

这不同于形式4,但理论上它更加接近原始的Scaled-Dot Attention。

Linformer

跟本文所介绍的Linear Attention很相似的一个工作是Facebook的Linformer,它依然保留原始的Scaled-Dot Attention形式,但在进行Attention之前,用两个 \(m\times n\) 的矩阵\(\boldsymbol{E},\boldsymbol{F}\)分别对\(\boldsymbol{K},\boldsymbol{V}\)进行投影,即变为

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}(\boldsymbol{E}\boldsymbol{K})^{\top}\right)\boldsymbol{F}\boldsymbol{V}\end{equation} \]

这样一来,\(\boldsymbol{Q}(\boldsymbol{E}\boldsymbol{K})^{\top}\) 就只是一个 \(n\times m\) 的矩阵,而作者声称对于哪怕对于很大的序列长度n,m也可以保持为一个适中的常数,从而这种Attention也是线性的。

但是,笔者认为“对于超长序列m可以保持不变”这个结论是值得质疑的,对于长序列原论文只做了MLM任务,而很明显MLM并不那么需要长程依赖,所以这个实验没什么说服力。因此,Linformer是不是真的Linear,还有待商榷。

自回归生成

Linformer的另一个缺点是 \(\boldsymbol{E}\boldsymbol{K},\boldsymbol{F}\boldsymbol{V}\) 这两个运算直接把整个序列的信息给“糅合”起来了,所以它没法简单地把将来信息给Mask掉(Causal Masking),从而无法做语言模型、Seq2Seq等自回归生成任务,这也是刚才说的原作者只做了MLM任务的原因。相比之下,前面介绍的几种Linear Attention都能做到这一点。以 式3式4 为例,如果要Mask掉未来信息,那么只需要把求和 \(\sum\limits_{j=1}^n\) 改为\(\sum\limits_{j=1}^i\)

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^i \left(\phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\right)\boldsymbol{v}_j}{\sum\limits_{j=1}^i \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}=\frac{ \phi(\boldsymbol{q}_i)^{\top} \sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}}{ \phi(\boldsymbol{q}_i)^{\top} \sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)}\end{equation} \]

实现上式有两种方式:第一方式是设\(\boldsymbol{S}_i=\sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}\)以及\(\boldsymbol{z}_i=\sum\limits_{j=1}^i\varphi(\boldsymbol{k}_j)\),我们有

\[ \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i =\frac{ \phi(\boldsymbol{q}_i)^{\top} \boldsymbol{S}_i}{ \phi(\boldsymbol{q}_i)^{\top} \boldsymbol{z}_i},\quad \begin{aligned}&\boldsymbol{S}_i=\boldsymbol{S}_{i-1}+\varphi(\boldsymbol{k}_i)\boldsymbol{v}_i^{\top}\\ &\boldsymbol{z}_i=\boldsymbol{z}_{i-1}+\varphi(\boldsymbol{k}_i) \end{aligned}\end{equation} \]

这说明这种Attention可以作为一个RNN模型用递归的方式实现,它的空间复杂度最低,但是要串行计算,适合预测解码时使用;

第二种是直接将\(\varphi(\boldsymbol{K}),\boldsymbol{V}\in\mathbb{R}^{n\times d}\) 做外积,得到一个 \(n\times d\times d\) 的矩阵,然后对 n 那一维执行\(\text{cumsum}\) 运算,这样就一次性得到 \(\boldsymbol{S}_1,\boldsymbol{S}_2,\dots,\boldsymbol{S}_n\) 了,它的速度最快,但空间占用最大,适合训练时使用,不过很多时候都有\(d^2\gg n\),一般情况下训练时都很难承受这个空间复杂度,因此多数还是用RNN形式。

但是如果用这种最基本的循环来训练有问题。

如果使用autograd,那么,每一个时刻的hidden state \(\mathbf{S}_t\) 都会被保存下来给反向传播算梯度。保存中间hidden state在之前的RNN训练里面没啥问题,但是在这里却是一个大问题。为什么呢?因为传统RNN 每个时刻的hidden state的大小是 \(\mathcal{O}(d)\) ,而linear attention每个时刻的hidden state大小是 \(\mathcal{O}(d^2)\) ,如果全部存下来需要 \(\mathcal{O}(Ld^2)\) ,那是万万不可接受的(你可以想象,输入的QKV才\(3Ld\), 一般 \(d\gg3\), 那么你基本就显存爆炸寄了不用玩了)

Transformers are RNNs 里面提出了一个memory-efficient 训练的解决方案。

在这个方案里, \(\mathbf{S}_t\) 在前向传播的过程中是不会被保存的。 由于只有 \(\mathbf{q}_{t}\) 的梯度依赖于 \(\mathbf{S}_t\) 。那我们索性在backward的时候重新算一遍 \(\mathbf{S}_t\) ,但这时的output算的不是 \(\mathbf{o}_t\),而是 \(\mathbf{dq}_t = \mathbf{S}_t \mathbf{do}_t\)

而 $\mathbf{k}_t, \mathbf{v}_t $ 的梯度不依赖于 \(\mathbf{S}_t\) ,可以直接通过平常的BPTT来算出来。

image

小结

上面介绍了一些从结构上对Attention进行修改从而降低其计算复杂度的工作,其中最主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的\(\mathcal{O}(n)\)级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 \(\mathcal{O}(n)\)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。

Reference

🔖 https://spaces.ac.cn/archives/7546