Reading

LinearAttention 概述

概述

众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 \(\mathcal{O}(n^2)\) 级别的,\(n\) 是序列长度,所以当 \(n\) 比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 \(\mathcal{O}(n\log n)\) 甚至 \(\mathcal{O}(n)\)

改变这一复杂度的思路主要有两种:

快速预览

其实linear attention的思想很简单,就是把

\[\mathbf{O} = \operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top) \mathbf{V}\]

的softmax去掉,变成了

\[\mathbf{O} = (\mathbf{Q}\mathbf{K}^\top) \mathbf{V}\]

然后借助矩阵乘法结合律得到

\[\mathbf{O} = \mathbf{Q}(\mathbf{K}^\top \mathbf{V})\]

在双向注意力里,比方说古代的bert时期,以及计算机视觉领域中,这样就已经足够了,大家开开心心地在线性时间内算两个很大的矩阵乘法,甚至都不需要写kernel就能很高效

但是在autoregressive modeling中,我们需要有causal mask。 训练和推理的形式分别是:

\[\begin{align*} \mathbf{O} &= \operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}) \mathbf{V} &&\in \mathbb{R}^{L\times d} \\ \mathbf{o_t} &= \sum_{j=1}^t \frac{\exp(\mathbf{q}_t^\top \mathbf{k}j)}{\sum_{l=1}^t\exp(\mathbf{q}^\top_t \mathbf{k}_l)}\mathbf{v}_j && \in \mathbb{R}^d \end{align*}\]

同样地,把去掉softmax之后,我们可以得到

\[\begin{aligned}\mathbf{O} &= (\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}) \mathbf{V} && \in \mathbb{R}^{L \times d} \\ \mathbf{o}_t &= \sum_{j=1}^t (\mathbf{q}_t^\top \mathbf{k}_j) \mathbf{v}_j && \in \mathbb{R}^d \end{aligned}\]

由于这个 \(\mathbf{M}\) 的存在,我们不能直接利用矩阵乘法的结合律得到上面先算KV 矩阵乘法的线性形式(因为矩阵乘法跟矩阵点乘是不可以交换的)

Attention

当前最流行的Attention机制当属 Scaled-Dot Attention,形式为

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{equation}\]

这里的 \(\boldsymbol{Q}\in\mathbb{R}^{n\times d_k}, \boldsymbol{K}\in\mathbb{R}^{m\times d_k}, \boldsymbol{V}\in\mathbb{R}^{m\times d_v}\),简单起见我们就没显式地写出Attention的缩放因子了。本文我们主要关心Self Attention场景,所以为了介绍上的方便统一设 \(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}\in\mathbb{R}^{n\times d}\),一般场景下都有\(n > d\) 甚至 \(n\gg d\)(BERT base里边\(d=64\))。

我们可以将Scaled-Dot Attention的(1)等价地改写为(本文的向量都是列向量)

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}j}{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{equation}\]

所以,Scaled-Dot Attention其实就是以 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 为权重对 \(\boldsymbol{v}_j\) 做加权平均。所以我们可以提出一个Attention的一般化定义:

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}j}{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{equation}\]

也就是把 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 换成 \(\boldsymbol{q}_i, \boldsymbol{k}_j\) 的一般函数 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\),为了保留Attention相似的分布特性,我们要求 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\) 恒成立。也就是说,我们如果要定义新式的Attention,那么要保留 式3 的形式,并且满足 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\)

Transformers are RNNs

自回归linear attention的开山之作

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的\(\mathcal{O}(n)\)级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 \(\mathcal{O}(n)\)的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。

其对应的attention计算可以写为去掉softmax并带有kernel的形式:

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}\tag{4}\end{equation}\]

利用矩阵乘法的结合律,可以将上式简化为:

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\phi(q_i)^T \sum_{j=1}^{N} \phi(k_j) v_j^T}{\phi(q_i)^T \sum_{j=1}^{N} \phi(k_j)}\tag{5}\end{equation}\]

详细可以参考:Transformers are RNNs

Performer

从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。

详情可以参考:Preformer

The Devil in Linear Transformer

这个工作主要的贡献是可以去掉整体的normalization的分母项。并证明了分母带来数值问题,所以在最近的linear attention中几乎全部去掉了,取而代之的是加上output normalization

\[\begin{equation}O_{norm} = \text{XNorm}(Q(K^TV))\tag{6}\end{equation}\]

详情可以参考:The Devil in Linear Transformer

另外,Fine-Tuning Pre-trained Transformers into Decaying Fast Weights发现QK的activation啥也不设就good enough,后续的RetNet/GLA也不用激活函数,所以这两个term都省掉了。

FLASH

在 Transformers are Rnns 的实现中 linear attention 还存在一个问题:循环训练并行度太差了。此外,linear attention的recurrent update全部都是element-wise的操作(外积,点乘,...),根本没有一丁点矩阵乘法的影子,而矩阵乘法在GPU上非常高效(相同数量的FLOPs,在A100上用tensor cores算半精度矩阵乘法的效率是其他操作的16倍,所以现代的算法都是怎么矩阵乘法怎么来,这也是为什么注意力机制最先被提出来,然后直接席卷deep learning,因为它训的快呀。)

在parallel形式中,我们不需要算任何hidden state,只通过Q K V来得到output,但是需要 \(\mathcal{O}(L^2) \) 复杂度。在Recurrent形式中,我们需要算每个time step的hidden state,但是只需要 \(\mathcal{O}(L) \) 的复杂度。

那么存不存在一个介于两者之间算法,能够减少recurrent state的数量从而减少循环的次数,同时复杂度依然是线性的呢?

答案就是linear attention的第三种形式:chunkwise parallel form

最早应该是在Transformer Quality in Linear Time 提出的,现在所有的线性注意力训练都是基于chunkwise parallel form。当chunk size \(C=1\) 的时候,它和recurrent form等价,当 \(C=L\) 的时候,它跟parallel form等价。这是一个exact的算法,而不是一个approximate的算法,所以chunk size不影响output的大小

详见:FLASH:高效Transformer设计

Lightning Attention

在全局线性注意力中,每个位置的 token 都可以看到整个序列,所以我们可以先计算所有位置的\(\psi(K)^T V\) ,然后再用 \(\phi(Q)\) 与之做点积。

\[\text{Global Attention}(Q, K, V) = \phi(Q) (\psi(K)^T V) = [q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\]

我们可以直接使用矩阵乘法:

  1. 计算 \(\psi(K)^T V\)\(\begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个简单的向量乘法。
  2. 计算\(\phi(Q) (\psi(K)^T V)\) : \([q_1, q_2, q_3, q_4] \begin{bmatrix} k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4 \end{bmatrix}\) 这是一个向量乘以一个标量的运算。

在这个过程中,我们可以使用矩阵乘法高效并行地完成计算,复杂度为\(O(n)\)

但是,在LLM推理中,我们通常需要「因果性」,每个位置的 token 只能看到它之前的 tokens,所以我们需要为每个位置单独计算注意力,并且要考虑到每个位置可见的 tokens 的数量是不同的。比如:

  • 位置 1:\(x_1\)只能看到自己, \(text{Output}_1 = q_1 (k_1^T v_1)\)
  • 位置 2:\(x_2\) 可以看到 \(x_1\) \(x_2\)\(text{Output}_2 = q_2 (k_1^T v_1 + k_2^T v_2)\)
  • 位置 3:\(x_3\) 可以看到 \(x_1\)\(x_2\) \(x_3\)\(text{Output}_3 = q_3 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3)\)
  • 位置 4:\(x_4\) 可以看到 \(x_1\)\(x_2\)\(x_3\) \(x_4\)\(text{Output}_4 = q_4 (k_1^T v_1 + k_2^T v_2 + k_3^T v_3 + k_4^T v_4)\)
    这样的cumsum操作无法被高效地表达为矩阵乘法,因此虽然计算复杂度下来了,但实际运算的效率并不高。

Lightning Attention如何克服传统线性注意力的问题

传统线性注意力虽然降低了复杂度,但在实际实现中面临一个关键问题:cumsum操作。这种操作会导致严重的内存瓶颈和计算效率下降,特别是在处理长序列时。

Lightning Attention 利用了分块技术,有效地规避了cumsum操作带来的问题。从算法1可以看出其实现细节:

image
  1. IO感知的分块策略
    Lightning Attention的核心创新在于"IO-aware"的设计思路:
    • 将输入矩阵分块处理:将长度为n的序列分成\(T=n/B\) 个块,每个块大小为 \(B×d\)
    • 高效内存管理:每次只将当前需要的块(\(Q_t, K_t, V_t\))从慢速HBM加载到高速片上SRAM
    • 避免大规模cumsum:通过分块计算避免了对整个序列进行一次性操作
  2. 注意力计算的双重分解
    算法将注意力计算巧妙地分解为两个组件:
    • 块内计算(intra-block)\(O_{intra} = [(Q_tK_t^T) ⊙ M]V_t\)
      • 使用左乘积注意力计算
      • 通过掩码 \(M\) 确保因果关系
      • 仅在当前块内进行计算,复杂度为\(O(B²)\)
    • 块间计算(inter-block)\(O_{inter} = Q_t(KV)\)
      • 使用右乘积注意力计算
      • 利用累积矩阵KV捕获当前块与之前所有块的关系
      • 避免了对之前块的重复计算
  3. 累积矩阵的巧妙应用
    • KV矩阵的递增更新\(KV = KV + K_t^TV_t\)
    • 这种设计使得每个块都能高效地获取之前所有块的信息,而无需重新计算
    • 累积矩阵的维度为\(d×d\),与序列长度 \(n\) 无关,保持了内存效率
  4. 并行优化与硬件友好
    • 减少内存访问:通过在SRAM中完成大部分计算,减少了HBM与SRAM之间的数据传输
    • 块级并行:不同块之间的操作可以结合硬件流水线进行并行处理
    • 避免串行依赖:规避了传统cumsum操作带来的严格序列依赖

这种优化很像FlashAttention的思路,它们的基本机制相近,都是「在块/小批量维度上将专用计算搬到高速缓存中进行,从而避免大矩阵在显存之间频繁交互」。所以,从工程实现角度,也可以把 Lightning Attention 看作「针对线性注意力的 FlashAttention 思路移植与优化」。

Reference

知乎:Lightning Attention 是如何克服传统线性注意力机制需要累加求和的缺陷的?