Reading

LinearAttention 概述

概述

众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是 级别的, 是序列长度,所以当 比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到 甚至

改变这一复杂度的思路主要有两种:

快速预览

其实linear attention的思想很简单,就是把

的softmax去掉,变成了

然后借助矩阵乘法结合律得到

在双向注意力里,比方说古代的bert时期,以及计算机视觉领域中,这样就已经足够了,大家开开心心地在线性时间内算两个很大的矩阵乘法,甚至都不需要写kernel就能很高效

但是在autoregressive modeling中,我们需要有causal mask。 训练和推理的形式分别是:

同样地,把去掉softmax之后,我们可以得到

由于这个 的存在,我们不能直接利用矩阵乘法的结合律得到上面先算KV 矩阵乘法的线性形式(因为矩阵乘法跟矩阵点乘是不可以交换的)

Transformers are RNNs

自回归linear attention的开山之作

Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

主要的idea是去掉标准Attention中的Softmax,就可以使得Attention的复杂度退化为理想的级别(Linear Attention)。相比于其他类似的改进结构的工作,这种修改能在把复杂度降到 的同时,依然保留所有的“token-token“的注意力,同时还能保留用于做自回归生成的可能性。

其对应的attention计算可以写为去掉softmax并带有kernel的形式:

利用矩阵乘法的结合律,可以将上式简化为:

详细可以参考:

Performer

从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。

详情可以参考:

The Devil in Linear Transformer

这个工作主要的贡献是可以去掉整体的normalization的分母项。并证明了分母带来数值问题,所以在最近的linear attention中几乎全部去掉了,取而代之的是加上output normalization

详情可以参考:

另外,Fine-Tuning Pre-trained Transformers into Decaying Fast Weights发现QK的activation啥也不设就good enough,后续的RetNet/GLA也不用激活函数,所以这两个term都省掉了。

FLASH

在 Transformers are Rnns 的实现中 linear attention 还存在一个问题:循环训练并行度太差了。此外,linear attention的recurrent update全部都是element-wise的操作(外积,点乘,...),根本没有一丁点矩阵乘法的影子,而矩阵乘法在GPU上非常高效(相同数量的FLOPs,在A100上用tensor cores算半精度矩阵乘法的效率是其他操作的16倍,所以现代的算法都是怎么矩阵乘法怎么来,这也是为什么注意力机制最先被提出来,然后直接席卷deep learning,因为它训的快呀。)

在parallel形式中,我们不需要算任何hidden state,只通过Q K V来得到output,但是需要 复杂度。在Recurrent形式中,我们需要算每个time step的hidden state,但是只需要 的复杂度。

那么存不存在一个介于两者之间算法,能够减少recurrent state的数量从而减少循环的次数,同时复杂度依然是线性的呢?

答案就是linear attention的第三种形式:chunkwise parallel form

最早应该是在Transformer Quality in Linear Time 提出的,现在所有的线性注意力训练都是基于chunkwise parallel form。当chunk size 的时候,它和recurrent form等价,当 的时候,它跟parallel form等价。这是一个exact的算法,而不是一个approximate的算法,所以chunk size不影响output的大小

详见:

Lightning Attention

在全局线性注意力中,每个位置的 token 都可以看到整个序列,所以我们可以先计算所有位置的 ,然后再用  与之做点积。

我们可以直接使用矩阵乘法:

  1. 计算 这是一个简单的向量乘法。
  2. 计算 : 这是一个向量乘以一个标量的运算。
    在这个过程中,我们可以使用矩阵乘法高效并行地完成计算,复杂度为

但是,在LLM推理中,我们通常需要「因果性」,每个位置的 token 只能看到它之前的 tokens,所以我们需要为每个位置单独计算注意力,并且要考虑到每个位置可见的 tokens 的数量是不同的。比如:

  • 位置 1:** **只能看到自己,
  • 位置 2: **可以看到 ,**
  • 位置 3: **可以看到 ,**
  • 位置 4: **可以看到 ,**
    这样的cumsum操作无法被高效地表达为矩阵乘法,因此虽然计算复杂度下来了,但实际运算的效率并不高。

Lightning Attention如何克服传统线性注意力的问题

传统线性注意力虽然降低了复杂度,但在实际实现中面临一个关键问题:cumsum操作。这种操作会导致严重的内存瓶颈和计算效率下降,特别是在处理长序列时。

Lightning Attention 利用了分块技术,有效地规避了cumsum操作带来的问题。从算法1可以看出其实现细节:

image

  1. IO感知的分块策略
  2. 注意力计算的双重分解
  3. 累积矩阵的巧妙应用
  4. 并行优化与硬件友好
    这种优化很像FlashAttention的思路,它们的基本机制相近,都是「在块/小批量维度上将专用计算搬到高速缓存中进行,从而避免大矩阵在显存之间频繁交互」。所以,从工程实现角度,也可以把 Lightning Attention 看作「针对线性注意力的 FlashAttention 思路移植与优化」。

Reference

知乎:Lightning Attention 是如何克服传统线性注意力机制需要累加求和的缺陷的?