Reading

The Devil in Linear Transformer

Attention

当前最流行的Attention机制当属 Scaled-Dot Attention,形式为

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{equation}\]

这里的 \(\boldsymbol{Q}\in\mathbb{R}^{n\times d_k}, \boldsymbol{K}\in\mathbb{R}^{m\times d_k}, \boldsymbol{V}\in\mathbb{R}^{m\times d_v}\),简单起见我们就没显式地写出Attention的缩放因子了。本文我们主要关心Self Attention场景,所以为了介绍上的方便统一设 \(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}\in\mathbb{R}^{n\times d}\),一般场景下都有\(n > d\) 甚至 \(n\gg d\)(BERT base里边\(d=64\))。

我们可以将Scaled-Dot Attention的(1)等价地改写为(本文的向量都是列向量)

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}j}{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{equation}\]

所以,Scaled-Dot Attention其实就是以 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 为权重对 \(\boldsymbol{v}_j\) 做加权平均。所以我们可以提出一个Attention的一般化定义:

\[\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}j}{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{equation}\]

也就是把 \(e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\) 换成 \(\boldsymbol{q}_i, \boldsymbol{k}_j\) 的一般函数 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\),为了保留Attention相似的分布特性,我们要求 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\) 恒成立。也就是说,我们如果要定义新式的Attention,那么要保留 式3 的形式,并且满足 \(\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0\)

简介

承接 Transformers are RNNs 这篇论文

目的: 为了分析之前linear transformer的效果为什么不好。发现主要是两个原因造成的:

  1. 无界梯度(unbounded gradient),会导致模型在训练时不稳定,收敛不好;
  2. 注意力稀释(attention dilution),transformer在lower level时应该更关注局部特征,而higher level更关注全局特征,但线性transformer中的attention往往weight 更均匀化,不能聚焦在local区域上,因此称为attention稀释。

解决方案: 

  1. 对linear attention算出来的output接着做个normalization,形成NormFormer的结构,增加训练的稳定性。
  2. 在底层的layer用diagonal的local attention。

自注意力机制的统一表示

在自注意力模块中,无论是传统的vanilla注意力还是线性注意力,其注意力矩阵 \(\mathbf{P} \in \mathbb{R}^{n \times n}\) 可以用以下统一形式表示:

\[\begin{equation}p_{ij} = \frac{f(s_{ij})}{\sum_{k=1}^{n}f(s_{ik})}, \quad f: \mathbb{R} \rightarrow \mathbb{R}\tag{4}\end{equation}\]

其中 \(s_{ij}\) 表示token之间的相似度,注意这里19式其实就是 公式3, 只是在这篇论文写成了这个形式。

  • Vanilla注意力

在vanilla注意力中,\(s_{ij}\) 的计算方式为:

\[ s_{ij} = \frac{\mathbf{q}_i^T\mathbf{k}_j}{\sqrt{d}}, \quad f(x) = \exp(x)\]
  • 线性注意力

在线性注意力中,\(s_{ij}\) 可以使用核函数 \(\phi\) 分解为:

\[s_{ij} = \phi(\mathbf{q}_i)^T\phi(\mathbf{k}_j), \quad f(x) = x\]

无界梯度(unbounded gradient)

梯度的一般形式

注意力矩阵 \(\mathbf{P}\) 关于相似度 \(s_{ik}\) 的梯度可以推导为:

\[\begin{equation}\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{f(s_{ik})}(1_{j=k}p_{ij} - p_{ij}p_{ik})\tag{5}\end{equation}\]

推导过程

推导注意力矩阵 \(\mathbf{P}\) 关于相似度 \(s_{ik}\) 的梯度的一般形式。

我们需要计算 \(\frac{\partial p_{ij}}{\partial s_{ik}}\),即注意力权重 \(p_{ij}\) 关于相似度 \(s_{ik}\) 的偏导数。

首先,定义分子为 \(N_{ij} = f(s_{ij})\),分母为 \(D_i = \sum_{l=1}^{n}f(s_{il})\),则 \(p_{ij} = \frac{N_{ij}}{D_i}\)

使用商的求导法则:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{\partial}{\partial s_{ik}}\left(\frac{N_{ij}}{D_i}\right) = \frac{D_i \cdot \frac{\partial N_{ij}}{\partial s_{ik}} - N_{ij} \cdot \frac{\partial D_i}{\partial s_{ik}}}{D_i^2}\]

接下来,分别计算 \(\frac{\partial N_{ij}}{\partial s_{ik}}\)\(\frac{\partial D_i}{\partial s_{ik}}\)

  1. 对于 \(\frac{\partial N_{ij}}{\partial s_{ik}}\)
    • \(j = k\) 时,\(\frac{\partial N_{ij}}{\partial s_{ik}} = \frac{\partial f(s_{ij})}{\partial s_{ik}} = \frac{\partial f(s_{ik})}{\partial s_{ik}} = f'(s_{ik})\)
    • \(j \neq k\)时,\(\frac{\partial N_{ij}}{\partial s_{ik}} = 0\) (因为 \(N_{ij}\) 不依赖于 \(s_{ik}\))
  1. 对于 \(\frac{\partial D_i}{\partial s_{ik}}\)
\[\frac{\partial D_i}{\partial s_{ik}} = \frac{\partial}{\partial s_{ik}}\sum_{l=1}^{n}f(s_{il}) = \frac{\partial f(s_{ik})}{\partial s_{ik}} = f'(s_{ik})\]

将这些结果代入原式:

\(j = k\)时:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{D_i \cdot f'(s_{ik}) - N_{ik} \cdot f'(s_{ik})}{D_i^2} = \frac{f'(s_{ik})}{D_i} - \frac{N_{ik} \cdot f'(s_{ik})}{D_i^2} = \frac{f'(s_{ik})}{D_i}\left(1 - \frac{N_{ik}}{D_i}\right) = \frac{f'(s_{ik})}{D_i}(1 - p_{ik})\]

\(j \neq k\)时:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{0 - N_{ij} \cdot f'(s_{ik})}{D_i^2} = -\frac{N_{ij} \cdot f'(s_{ik})}{D_i^2} = -\frac{f'(s_{ik})}{D_i} \cdot \frac{N_{ij}}{D_i} = -\frac{f'(s_{ik})}{D_i} \cdot p_{ij}\]

将两种情况合并,并引入指示函数 \(1_{j=k}\)(当 \(j=k\) 时为1,否则为0):

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{D_i} \cdot [1_{j=k} - p_{ij}]\]

进一步,由于 \(p_{ij} = \frac{f(s_{ij})}{D_i}\),可以将 \(\frac{1}{D_i}\) 替换为 \(\frac{p_{ij}}{f(s_{ij})}\)

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik}) \cdot p_{ik}}{f(s_{ik})} \cdot [1_{j=k} - p_{ij}]\]

注意到 \(1_{j=k}p_{ij} = 1_{j=k}p_{ik}\)(因为当 \(j=k\) 时,\(p_{ij}=p_{ik}\)),所以:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik}) \cdot p_{ik}}{f(s_{ik})} \cdot [1_{j=k} - p_{ij}]= \frac{f'(s_{ik})}{f(s_{ik})}(1_{j=k}p_{ij} - p_{ij}p_{ik})\]

最终,我们得到了论文中的表达式:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{f(s_{ik})}(1_{j=k}p_{ij} - p_{ij}p_{ik})\]

这个公式是注意力矩阵梯度的一般形式,适用于任何激活函数 \(f\)。在不同类型的注意力机制中,只需代入相应的激活函数及其导数即可得到具体的梯度表达式。

  • Vanilla注意力的梯度

对于vanilla注意力,有 \(f'(x) = \exp(x) = f(x)\),因此:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = 1_{j=k}p_{ij} - p_{ij}p_{ik} = \begin{cases} p_{ik} - p_{ik}p_{ik} \in [0, 1/4] & j = k \\ -p_{ij}p_{ik} \in [-1/4, 0] & j \neq k \end{cases}\]

这表明vanilla注意力的梯度是有界的:

\[\left|\frac{\partial p_{ij}}{\partial s_{ik}}\right| < \frac{1}{4}\]
  • 线性注意力的梯度

对于线性注意力,有 \(f'(x) = 1\),因此:

\[\frac{\partial p_{ij}}{\partial s_{ik}} = \frac{1}{s_{ik}}(1_{j=k}p_{ij} - p_{ij}p_{ik}) = \begin{cases} \frac{1}{s_{ik}}(p_{ik} - p_{ik}p_{ik}) & j = k \\ \frac{1}{s_{ik}}(-p_{ij}p_{ik}) & j \neq k \end{cases}\]

这导致线性注意力的梯度上界为:

\[ \left|\frac{\partial p_{ij}}{\partial s_{ik}}\right| < \frac{1}{4|s_{ik}|}\]

由于 \(|s_{ik}|^{-1} = |\phi(\mathbf{q}_i)\phi(\mathbf{k}_j)^T|^{-1}\)可以任意大,线性注意力的梯度没有上界

更进一步,可以证明线性注意力的梯度也没有下界。

无界梯度会导致优化过程不稳定,在初步研究中表现为更差的收敛结果。这是线性注意力相比vanilla注意力性能下降的重要原因之一。

注意力稀释问题(Attention Dilution)

image

作者通过评估不同层级上query在邻域内其他query上的attention权重占比发现问题

  • vanilla的attention主要集中在对角线附近
  • linear attention 由于是low-rank来逼近的,所以必然是dense的(PS:有很多low-rank+sparse的方法来解决这个问题,可以从图中看到,注意力过多的给了分布较远的token。这种注意力"稀释"现象导致模型无法有效学习层次化特征表示
  • 该文提出的方法(ie., TransNormer)的注意力则主要集中在对角线附近。

解决方案架构设计

image

NormAttention模块

  • 设计思路
    • \(O = Q(K^TV)\)
    • \(O_{norm} = \text{XNorm}(Q(K^TV))\)
  • 实现细节
    • XNorm可以是LayerNorm或RMSNorm
    • RMSNorm定义为:\(\text{RMSNorm}(x) = \frac{x}{\sqrt{\sigma^2 + \epsilon}}\),其中\(\sigma^2 = \sum_{i=1}^d x_i^2 /d\)
    • 作者证明此设计使梯度有上界,提高训练稳定性

DiagAttention模块

  • 基于pattern的attention设计,将query按距离划分为不重叠的window
  • 每个window内进行attention计算,使用vanilla attention
  • 这种设计使早期层能更好地关注局部特征

作者首先分析了vanilla attention和线性attention(1+elu)的组合,发现在靠底层的layer用vanilla(更sparse)的attention效果要更好(感觉也合理,模型最开始的时候应该更关注临近的token,到后期开始关注比较远的token)。

image

为了保持整体架构线性复杂度,作者把靠底层的layer换成了local attention,然后后面的layer用去掉了分母的,normalize过后的线性attention。

Reference

[EMNLP‘22 简读] The Devil in Linear Transformer