概述
众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 \(\mathcal{O}(n^2)\) 级别的,\(n\) 是序列长度,所以当 \(n\) 比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 \(\mathcal{O}(n\log n)\) 甚至 \(\mathcal{O}(n)\)。
改变这一复杂度的思路主要有两种:
- 一是走稀疏化的思路,比如OpenAI的Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。经过特殊设计之后,Attention矩阵的大部分元素都是0,因此理论上它也能节省显存占用量和计算量。后续类似工作还有《Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection》、《Longformer: The Long-Document Transformer》等。
- 二是走线性化的思路,也就是本文要介绍的一系列工作
快速预览
其实linear attention的思想很简单,就是把
的softmax去掉,变成了
然后借助矩阵乘法结合律得到
在双向注意力里,比方说古代的bert时期,以及计算机视觉领域中,这样就已经足够了,大家开开心心地在线性时间内算两个很大的矩阵乘法,甚至都不需要写kernel就能很高效
但是在autoregressive modeling中,我们需要有causal mask。 训练和推理的形式分别是:
同样地,把去掉softmax之后,我们可以得到
由于这个 \(\mathbf{M}\) 的存在,我们不能直接利用矩阵乘法的结合律得到上面先算KV 矩阵乘法的线性形式(因为矩阵乘法跟矩阵点乘是不可以交换的)
Attention
当前最流行的Attention机制当属 Scaled-Dot Attention,形式为
这里的 \(\boldsymbol{Q}\in\mathbb{R}^{n\times d_k}, \boldsymbol{K}\in\mathbb{R}^{m\times d_k}, \boldsymbol{V}\in\mathbb{R}^{m\times d_v}\),简单起见我们就没显式地写出Attention的缩放因子了。本文我们主要关心Self Attention场景,所以为了介绍上的方便统一设 \(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}\in\mathbb{R}^{n\times d}\),一般场景下都有\(n > d\) 甚至 \(n\gg d\)(BERT base里边\(d=64\))。
我们可以将Scaled-Dot Attention的(1)等价地改写为(本文的向量都是列向量)
所以,Scaled-Dot Attention其实就是以 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 为权重对 \(\boldsymbol{v}_j\) 做加权平均。所以我们可以提出一个Attention的一般化定义:
也就是把 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 换成 \(\boldsymbol{q}_i, \boldsymbol{k}_j\) 的一般函数 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\),为了保留Attention相似的分布特性,我们要求 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\) 恒成立。也就是说,我们如果要定义新式的Attention,那么要保留 式3 的形式,并且满足 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\)。
Transformers are RNNs
自回归linear attention的开山之作
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的\(\mathcal{O}(n)\)级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 \(\mathcal{O}(n)\)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。
其对应的attention计算可以写为去掉softmax并带有kernel的形式:
利用矩阵乘法的结合律,可以将上式简化为:
详细可以参考:Transformers are RNNs
Performer
从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。
详情可以参考:Preformer
The Devil in Linear Transformer
这个工作主要的贡献是可以去掉整体的normalization的分母项。并证明了分母带来数值问题,所以在最近的linear attention中几乎全部去掉了,取而代之的是加上output normalization
详情可以参考:The Devil in Linear Transformer
另外,Fine-Tuning Pre-trained Transformers into Decaying Fast Weights发现QK的activation啥也不设就good enough,后续的RetNet/GLA也不用激活函数,所以这两个term都省掉了。
FLASH
在 Transformers are Rnns 的实现中 linear attention 还存在一个问题:循环训练并行度太差了。此外,linear attention的recurrent update全部都是element-wise的操作(外积,点乘,...),根本没有一丁点矩阵乘法的影子,而矩阵乘法在GPU上非常高效(相同数量的FLOPs,在A100上用tensor cores算半精度矩阵乘法的效率是其他操作的16倍,所以现代的算法都是怎么矩阵乘法怎么来,这也是为什么注意力机制最先被提出来,然后直接席卷deep learning,因为它训的快呀。)
在parallel形式中,我们不需要算任何hidden state,只通过Q K V来得到output,但是需要 \(\mathcal{O}(L^2) \) 复杂度。在Recurrent形式中,我们需要算每个time step的hidden state,但是只需要 \(\mathcal{O}(L) \) 的复杂度。
那么存不存在一个介于两者之间算法,能够减少recurrent state的数量从而减少循环的次数,同时复杂度依然是线性的呢?
答案就是linear attention的第三种形式:chunkwise parallel form
最早应该是在Transformer Quality in Linear Time 提出的,现在所有的线性注意力训练都是基于chunkwise parallel form。当chunk size \(C=1\) 的时候,它和recurrent form等价,当 \(C=L\) 的时候,它跟parallel form等价。这是一个exact的算法,而不是一个approximate的算法,所以chunk size不影响output的大小
Lightning Attention
在全局线性注意力中,每个位置的 token 都可以看到整个序列,所以我们可以先计算所有位置的\(\psi(K)^T V\) ,然后再用 \(\phi(Q)\) 与之做点积。
我们可以直接使用矩阵乘法:
- 计算 \(\psi(K)^T V\):\(\begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个简单的向量乘法。
- 计算\(\phi(Q) (\psi(K)^T V)\) : \([q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个向量乘以一个标量的运算。
在这个过程中,我们可以使用矩阵乘法高效并行地完成计算,复杂度为\(O(n)\)。
但是,在LLM推理中,我们通常需要「因果性」,每个位置的 token 只能看到它之前的 tokens,所以我们需要为每个位置单独计算注意力,并且要考虑到每个位置可见的 tokens 的数量是不同的。比如:
- 位置 1:\(x_1\)只能看到自己, \(text{Output}_1 = q_1 (k_1^T v_1)\)
- 位置 2:\(x_2\) 可以看到 \(x_1\) 和 \(x_2\) ,\(text{Output}_2 = q_2 (k_1^T v_1 + k_2^T v_2)\)
- 位置 3:\(x_3\) 可以看到 \(x_1\)、\(x_2\) 和 \(x_3\),\(text{Output}_3 = q_3 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3)\)
- 位置 4:\(x_4\) 可以看到 \(x_1\)、\(x_2\)、\(x_3\) 和 \(x_4\),\(text{Output}_4 = q_4 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4)\)
这样的cumsum操作无法被高效地表达为矩阵乘法,因此虽然计算复杂度下来了,但实际运算的效率并不高。
Lightning Attention如何克服传统线性注意力的问题
传统线性注意力虽然降低了复杂度,但在实际实现中面临一个关键问题:cumsum操作。这种操作会导致严重的内存瓶颈和计算效率下降,特别是在处理长序列时。
Lightning Attention 利用了分块技术,有效地规避了cumsum操作带来的问题。从算法1可以看出其实现细节:

- IO感知的分块策略
Lightning Attention的核心创新在于"IO-aware"的设计思路:- 将输入矩阵分块处理:将长度为n的序列分成\(T=n/B\) 个块,每个块大小为 \(B×d\)
- 高效内存管理:每次只将当前需要的块(\(Q_t, K_t, V_t\))从慢速HBM加载到高速片上SRAM
- 避免大规模cumsum:通过分块计算避免了对整个序列进行一次性操作
- 注意力计算的双重分解
算法将注意力计算巧妙地分解为两个组件:- 块内计算(intra-block):\(O_{intra} = [(Q_tK_t^T) ⊙ M]V_t\)
- 使用左乘积注意力计算
- 通过掩码 \(M\) 确保因果关系
- 仅在当前块内进行计算,复杂度为\(O(B²)\)
- 块间计算(inter-block):\(O_{inter} = Q_t(KV)\)
- 使用右乘积注意力计算
- 利用累积矩阵KV捕获当前块与之前所有块的关系
- 避免了对之前块的重复计算
- 块内计算(intra-block):\(O_{intra} = [(Q_tK_t^T) ⊙ M]V_t\)
- 累积矩阵的巧妙应用
- KV矩阵的递增更新:\(KV = KV + K_t^TV_t\)
- 这种设计使得每个块都能高效地获取之前所有块的信息,而无需重新计算
- 累积矩阵的维度为\(d×d\),与序列长度 \(n\) 无关,保持了内存效率
- 并行优化与硬件友好
- 减少内存访问:通过在SRAM中完成大部分计算,减少了HBM与SRAM之间的数据传输
- 块级并行:不同块之间的操作可以结合硬件流水线进行并行处理
- 避免串行依赖:规避了传统cumsum操作带来的严格序列依赖
这种优化很像FlashAttention的思路,它们的基本机制相近,都是「在块/小批量维度上将专用计算搬到高速缓存中进行,从而避免大矩阵在显存之间频繁交互」。所以,从工程实现角度,也可以把 Lightning Attention 看作「针对线性注意力的 FlashAttention 思路移植与优化」。