概述
众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 \(\mathcal{O}(n^2)\) 级别的,\(n\) 是序列长度,所以当 \(n\) 比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 \(\mathcal{O}(n\log n)\) 甚至 \(\mathcal{O}(n)\)。
改变这一复杂度的思路主要有两种:
- 一是走稀疏化的思路,比如OpenAI的Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。经过特殊设计之后,Attention矩阵的大部分元素都是0,因此理论上它也能节省显存占用量和计算量。后续类似工作还有《Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection》、《Longformer: The Long-Document Transformer》等。
- 二是走线性化的思路,也就是本文要介绍的一系列工作
快速预览
其实linear attention的思想很简单,就是把
的softmax去掉,变成了
然后借助矩阵乘法结合律得到
在双向注意力里,比方说古代的bert时期,以及计算机视觉领域中,这样就已经足够了,大家开开心心地在线性时间内算两个很大的矩阵乘法,甚至都不需要写kernel就能很高效
但是在autoregressive modeling中,我们需要有causal mask。 训练和推理的形式分别是:
同样地,把去掉softmax之后,我们可以得到
由于这个 \(\mathbf{M}\) 的存在,我们不能直接利用矩阵乘法的结合律得到上面先算KV 矩阵乘法的线性形式(因为矩阵乘法跟矩阵点乘是不可以交换的)
Transformers are RNNs
自回归linear attention的开山之作
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的\(\mathcal{O}(n)\)级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 \(\mathcal{O}(n)\)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。
其对应的attention计算可以写为去掉softmax并带有kernel的形式:
利用矩阵乘法的结合律,可以将上式简化为:
详细可以参考:
Performer
从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。
详情可以参考:
The Devil in Linear Transformer
这个工作主要的贡献是可以去掉整体的normalization的分母项。并证明了分母带来数值问题,所以在最近的linear attention中几乎全部去掉了,取而代之的是加上output normalization
详情可以参考:
另外,Fine-Tuning Pre-trained Transformers into Decaying Fast Weights发现QK的activation啥也不设就good enough,后续的RetNet/GLA也不用激活函数,所以这两个term都省掉了。
FLASH
在 Transformers are Rnns 的实现中 linear attention 还存在一个问题:循环训练并行度太差了。此外,linear attention的recurrent update全部都是element-wise的操作(外积,点乘,...),根本没有一丁点矩阵乘法的影子,而矩阵乘法在GPU上非常高效(相同数量的FLOPs,在A100上用tensor cores算半精度矩阵乘法的效率是其他操作的16倍,所以现代的算法都是怎么矩阵乘法怎么来,这也是为什么注意力机制最先被提出来,然后直接席卷deep learning,因为它训的快呀。)
在parallel形式中,我们不需要算任何hidden state,只通过Q K V来得到output,但是需要 $\mathcal{O}(L^2) $ 复杂度。在Recurrent形式中,我们需要算每个time step的hidden state,但是只需要 $\mathcal{O}(L) $ 的复杂度。
那么存不存在一个介于两者之间算法,能够减少recurrent state的数量从而减少循环的次数,同时复杂度依然是线性的呢?
答案就是linear attention的第三种形式:chunkwise parallel form
最早应该是在Transformer Quality in Linear Time 提出的,现在所有的线性注意力训练都是基于chunkwise parallel form。当chunk size \(C=1\) 的时候,它和recurrent form等价,当 \(C=L\) 的时候,它跟parallel form等价。这是一个exact的算法,而不是一个approximate的算法,所以chunk size不影响output的大小
详见:
Lightning Attention
在全局线性注意力中,每个位置的 token 都可以看到整个序列,所以我们可以先计算所有位置的\(\psi(K)^T V\) ,然后再用 \(\phi(Q)\) 与之做点积。
我们可以直接使用矩阵乘法:
- 计算 \(\psi(K)^T V\):\(\begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个简单的向量乘法。
- 计算\(\phi(Q) (\psi(K)^T V)\) : \([q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个向量乘以一个标量的运算。
在这个过程中,我们可以使用矩阵乘法高效并行地完成计算,复杂度为\(O(n)\)。
但是,在LLM推理中,我们通常需要「因果性」,每个位置的 token 只能看到它之前的 tokens,所以我们需要为每个位置单独计算注意力,并且要考虑到每个位置可见的 tokens 的数量是不同的。比如:
- 位置 1:\(x_1\) 只能看到自己, \(text{Output}_1 = q_1 (k_1^T v_1)\)
- 位置 2:\(x_2\) 可以看到 \(x_1\) ** 和 \(x_2\) ,**\(text{Output}_2 = q_2 (k_1^T v_1 + k_2^T v_2)\)
- 位置 3:\(x_3\) 可以看到 \(x_1\)、\(x_2\) 和 \(x_3\),\(text{Output}_3 = q_3 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3)\)
- 位置 4:\(x_4\) 可以看到 \(x_1\)、\(x_2\)、\(x_3\) 和 \(x_4\),\(text{Output}_4 = q_4 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4)\)
这样的cumsum操作无法被高效地表达为矩阵乘法,因此虽然计算复杂度下来了,但实际运算的效率并不高。
Lightning Attention如何克服传统线性注意力的问题
传统线性注意力虽然降低了复杂度,但在实际实现中面临一个关键问题:cumsum操作。这种操作会导致严重的内存瓶颈和计算效率下降,特别是在处理长序列时。
Lightning Attention 利用了分块技术,有效地规避了cumsum操作带来的问题。从算法1可以看出其实现细节:

- IO感知的分块策略
- 注意力计算的双重分解
- 累积矩阵的巧妙应用
- 并行优化与硬件友好
这种优化很像FlashAttention的思路,它们的基本机制相近,都是「在块/小批量维度上将专用计算搬到高速缓存中进行,从而避免大矩阵在显存之间频繁交互」。所以,从工程实现角度,也可以把 Lightning Attention 看作「针对线性注意力的 FlashAttention 思路移植与优化」。