SSM奠基之作-HiPPO

Apr 23, 2025
1 views
NLP

概述

HiPPO(High-order Polynomial Projection Operators)是目前大热的structured state space model (S4)及其后续工作的backbone. State space mode主要是控制学科里的内容,最近被引入深度学习领域来解决长距离依赖问题。长距离依赖建模的核心问题是如何通过有限的memory来尽可能记住之前所有的历史信息。当前的主流序列建模模型(即Transformer和RNN) 存在着普遍的遗忘问题

  • fixed-size context windows: Transformer的window size通常是有限的,一般来说quadratic的attention最多建模到大约10k的token就到计算极限了
  • vanishing gradient: RNN通过hidden state来存储历史信息,理论上能记住之前所有内容,但实际上的effective memory大概是<1k个token的level,可能的原因是gradient vanishing
    HiPPO 通过数学方法分析来得到closed-form solution,并用数学证明HiPPO不会遗忘历史信息。

快速理解

  • HiPPO的核心思想是用有限维向量储存连续函数的信息。
  • 当我们试图用正交基去逼近一个动态更新的函数时,其结果就是线性ODE系统
  • HiPPO不仅告诉我们线性系统可以逼近复杂函数,还告诉我们如何逼近以及近似程度。
    HiPPO 框架利用了逼近理论中的经典工具,如正交多项式。最终,解决方案采取了一种简单的线性微分方程的形式,被称为 HiPPO 算子:

基本形式

对于事先已经对SSM有所了解的读者,想必知道SSM建模所用的是线性ODE系统:

\[ \begin{equation}\begin{aligned} x'(t) =&\, A x(t) + B u(t) \\ y(t) =&\, C x(t) + D u(t) \end{aligned}\end{equation} \]

其中\(u(t)\in\mathbb{R}^{d_i}, x(t)\in\mathbb{R}^{d}, y(t)\in\mathbb{R}^{d_o}, A\in\mathbb{R}^{d\times d}, B\in\mathbb{R}^{d\times d_i}, C\in\mathbb{R}^{d_o\times d}, D\in\mathbb{R}^{d_o\times d_i}\)。当然我们也可以将它离散化,那么就变成一个线性RNN模型,这部分我们在后面的文章再展开。不管离散化与否,其关键词都是“线性”,那么马上就有一个很自然的问题:为什么是线性系统?线性系统够了吗?

我们可以从两个角度回答这个问题:线性系统既足够简单,也足够复杂

  • 简单是指从理论上来说,线性化往往是复杂系统的一个最基本近似,所以线性系统通常都是无法绕开的一个基本点;
  • 复杂是指即便如此简单的系统,也可以拟合异常复杂的函数,为了理解这一点,我们只需要考虑一个 \(\mathbb{R}^4\) 的简单例子:
    $$

x'(t) =\begin{pmatrix} 1 & 0 & 0 & 0 \
0 & -1 & 0 & 0 \
0 & 0 & 0 & 1 \
0 & 0 & -1 & 0
\end{pmatrix}x(t)
$$

这个例子的基本解是 \(x(t) = (e^t, e^{-t}, \sin t, \cos t)\)。这意味着什么呢?意味着只要d足够大,该线性系统就可以通过指数函数和三角函数的组合来拟合足够复杂的函数,而我们知道拟合能力很强的傅里叶级数也只不过是三角函数的组合,如果在加上指数函数显然就更强了,因此可以想象线性系统也有足够复杂的拟合能力。

当然,这些解释某种意义上都是“马后炮”。HiPPO给出的结果更加本质:当我们试图用正交基去逼近一个动态更新的函数时,其结果就是如上的线性系统。这意味着,HiPPO不仅告诉我们线性系统可以逼近足够复杂的函数,还告诉我们怎么去逼近,甚至近似程度如何。

有限压缩

接下来,我们都只考虑 \(d_i=1\) 的特殊情形,\(d_i > 1\) 只不过是 \(d_i=1\) 时的平行推广。此时,\(u(t)\) 的输出是一个标量,进一步地,作为开头我们先假设 \(t\in[0, 1]\),HiPPO的目标是:用一个有限维的向量来储存这一段 \(u(t)\) 的信息。

看上去这是一个不大可能的需求,因为 \(t\in[0,1]\) 意味着 \(u(t)\) 可能相当于无限个点组成的向量,压缩到一个有限维的向量可能严重失真。不过,如果我们对 \(u(t)\) 做一些假设,并且允许一些损失,那么这个压缩是有可能做到的,并且大多数读者都已经尝试过。比如,当 \(u(t)\) 在某点 \(n+1\) 阶可导的,它对应的 \(n\) 阶泰勒展开式往往是 \(u(t)\) 的良好近似,于是我们可以只储存展开式的 \(n+1\) 个系数来作为 \(u(t)\) 的近似表征,这就成功将 \(u(t)\) 压缩为一个 \(n+1\) 维向量。

当然,对于实际遇到的数据来说,“\(n+1\) 阶可导”这种条件可谓极其苛刻,我们通常更愿意使用在平方可积条件下的正交函数基展开,比如傅里叶(Fourier)级数,它的系数计算公式为

\[ \begin{equation}c_n = \int_0^1 u(t) e^{-2i\pi n t}dt \end{equation} \]

这时候取一个足够大的整数 \(N\),只保留 \(|n|\leq N\) 的系数,那么就将 \(u(t)\) 压缩为一个 \(2N + 1\) 维的向量了。

接下来,问题难度就要升级了。刚才我们说 \(t\in[0,1]\),这是一个静态的区间,而实际中 \(u(t)\) 代表的是持续采集的信号,所以它是不断有新数据进入的,比如现在我们近似了 \([0,1]\) 区间的数据,马上就有 \([1,2]\) 的数据进来,你需要更新逼近结果来试图记忆整个\([0,2]\) 区间,接下来是 \([0,3]\)\([0,4]\) 等等,这我们称为“在线函数逼近”。而上面的傅里叶系数公式 2,只适用于区间 \([0,1]\),因此需要将它进行推广。

为此,我们设 \(t\in[0,T]\)\(s\mapsto t_{\leq T}(s)\)\([0,1]\)\([0,T]\) 的一个映射,那么 \(u(t_{\leq T}(s))\) 作为 \(s\) 的函数时,它的定义区间就是 \([0,1]\),于是就可以复用 式2

\[ \begin{equation}c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds \end{equation} \]

这里我们已经给系数加了标记\((T)\),以表明此时的系数会随着T的变化而变化。

线性初现

能将\([0,1]\) 映射到\([0,T]\)的函数有无穷多,而最终结果也因 \(t_{\leq T}(s)\) 而异,一些比较直观且相对简单的选择如下:

  1. \(t_{\leq T}(s) = sT\),即将 \([0,1]\) 均匀地映射到 \([0,T]\)
  2. 注意 \(t_{\leq T}(s)\) 并不必须是满射,所以像 \(t_{\leq T}(s)=s + T - 1\) 也是允许的,这意味着只保留了最邻近窗口 \([T-1,T]\) 的信息,丢掉了更早的部分,更一般地有 \(t_{\leq T}(s)=sw + T - w\),其中\(w\) 是一个常数,这意味着 \(T-w\) 前的信息被丢掉了;
  3. 也可以选择非均匀映射,比如 \(t_{\leq T}(s) = T\sqrt{s}\),它同样是 \([0,1]\)\([0,T]\) 的满射,但 \(s=1/4\)时就映射到 \(T/2\) 了,这意味着我们虽然关注全局的历史,但同时更侧重于 \(T\) 时刻附近的信息。
    现在我们以\(t_{\leq T}(s)=sw + T - w\) 为例,代入 式3 得到
\[ \begin{equation}c_n(T) = \int_0^1 u(sw + T - w) e^{-2i\pi n s}ds\end{equation} \]

现在我们两边求关于T的导数:\(v(x) = \frac{e^{-2i\pi ns}}{-2i\pi n}\)

\[ \begin{equation}\begin{aligned} \frac{d}{dT}c_n(T) =&\, \int_0^1 u'(sw + T - w) e^{-2i\pi n s}ds \\ =&\, \left.\frac{1}{w} u(sw + T - w) e^{-2i\pi n s}\right|_{s=0}^{s=1} + \frac{2i\pi n}{w}\int_0^1 u(sw + T - w) e^{-2i\pi n s}ds \\ =&\, \frac{1}{w} u(T) - \frac{1}{w} u(T-w) + \frac{2i\pi n}{w} c_n(T) \\ \end{aligned}\ \ \ \ \ \end{equation} \]

其中第二个等号我们用了分部积分公式。由于我们只保留了 \(|n|\leq N\) 的系数,所以根据傅立叶级数的公式,可以认为如下是 \(u(sw + T - w)\) 的一个良好近似:

\[ \begin{equation}u(sw + T - w) \approx \sum_{k=-N}^{k=N} c_k(T) e^{2i\pi k s}\end{equation} \]

那么 \(u(T - w) = u(sw + T - w)|_{s=0}\approx \sum\limits_{k=-N}^{k=N} c_k(T)\),代入 式5 得:

\[ \begin{equation}\frac{d}{dT}c_n(T) \approx \frac{1}{w} u(T) - \frac{1}{w} \sum_{k=-N}^{k=N} c_k(T) + \frac{2i\pi n}{w} c_n(T)\end{equation} \]

\(T\) 换成 \(t\),然后所有的 \(c_n(t)\) 堆在一起记为 \(x(t) = (c_{-N},c_{-(N-1)},\cdots,c_0,\cdots,c_{N-1},c_N)\),并且不区分 \(\approx\)\(=\),那么就可以写出

\[ \begin{equation}x'(t) = Ax(t) + Bu(t),\quad A_{n,k} = \left\{\begin{array}{l}(2i\pi n - 1)/w, &k=n \\ -1/w,&k\neq n\end{array}\right.,\quad B_n = 1/w\end{equation} \]

这就出现了如 式1 所示的线性ODE系统。即当我们试图用傅里叶级数去记忆一个实时函数的最邻近窗口内的状态时,结果自然而言地导致了一个线性ODE系统。

HIPPO推导

一般框架

当然,目前只是选择了一个特殊的 \(t_{\leq T}(s)\),换一个 \(t_{\leq T}(s)\) 就不一定有这么简单的结果了。此外,傅里叶级数的结论是在复数范围内的,进一步实数化也可以,但形式会变得复杂起来。所以,我们要将上一节的过程推广成一个一般化的框架,从而得到更一般、更简单的纯实数结论。

\(t\in[a,b]\),并且有目标函数 \(u(t)\) 和函数基 \(\{g_n(t)\}_{n=0}^N\)我们希望有后者的线性组合来逼近前者,目标是最小化 \(L_2\) 距离:

\[ \begin{equation}\mathop{\text{argmin}}_{c_1,\cdots,c_N}\int_a^b \left[u(t) - \sum_{n=0}^N c_n g_n(t)\right]^2 dt\end{equation} \]

这里我们主要在实数范围内考虑,所以方括号直接平方就行,不用取模。更一般化的目标函数还可以再加个权重函数\(\rho(t)\),但我们这里就不考虑了,毕竟HiPPO的主要结论其实也没考虑这个权重函数。

对目标函数展开,得到

\[ \begin{equation}\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{m=0}^N\sum_{n=0}^N c_m c_n \int_a^b g_m(t) g_n(t) dt\end{equation} \]

这里我们只考虑标准正交函数基,其定义为 \(\int_a^b g_m(t) g_n(t) dt = \delta_{m,n}\)\(\delta_{m,n}\)克罗内克δ函数,此时上式可以简化成

\[ \begin{equation}\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{n=0}^N c_n^2 \end{equation} \]

这只是一个关于 \(c_n\) 的二次函数,它的最小值是有解析解的:

\[ \begin{equation}c^*_n = \int_a^b u(t) g_n(t)dt\end{equation} \]

这也被称为 \(u(t)\)\(g_n(t)\) 的内积,它是有限维向量空间的内积到函数空间的平行推广。简单起见,在不至于混淆的情况下,我们默认 \(c_n\) 就是 \(c^*_n\)

接下来的处理跟上一节是一样的,我们要对一般的 \(t\in[0, T]\) 考虑 \(u(t)\) 的近似,那么找一个 \([a,b]\)\([0,T]\) 的映射 \(s\mapsto t_{\leq T}(s)\),然后计算系数

\[ \begin{equation}c_n(T) = \int_a^b u(t_{\leq T}(s)) g_n(s) ds\end{equation} \]

同样是两边求T的导数,然后用分部积分法

\[ \begin{equation}\scriptsize\begin{aligned} \frac{d}{dT}c_n(T) =&\, \int_a^b u'(t_{\leq T}(s)) \frac{\partial t_{\leq T}(s)}{\partial T} g_n(s) ds = \int_a^b \left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s) d u(t_{\leq T}(s)) \\ =&\,\left.u(t_{\leq T}(s))\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right|_{s=a}^{s=b} - \int_a^b u(t_{\leq T}(s)) \,d\left[\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right] \end{aligned}\ \ \ \ \ \ \ \end{equation} \]

准备工作—勒让德多项式

接下来的计算,就依赖于 \(g_n(t)\)\(t_{\leq T}(s)\) 的具体形式了。HiPPO的全称是High-order Polynomial Projection Operators,第一个P正是多项式(Polynomial)的首字母,所以HiPPO的关键是选取多项式为基。现在我们请出继傅里叶之后又一位大牛——勒让德(Legendre),接下来我们要选取的函数基正是以他命名的“勒让德多项式”。

勒让德多项式 \(p_n(t)\) 是关于 \(t\)\(n\) 次函数,定义域为 \([-1,1]\),满足

\[ \begin{equation}\int_{-1}^1 p_m(t) p_n(t) dt = \frac{2}{2n+1}\delta_{m,n}\end{equation} \]

所以 \(p_n(t)\) 之间只是正交,还不是标准(平分积分为1),\(g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)\)才是标准正交基。

当我们对函数基 \(\{1,t,t^2,\cdots, t^n\}\) 执行施密特正交化时,其结果正是勒让德多项式。相比傅里叶基,勒让德多项式的好处是它是纯粹定义在实数空间中的,并且多项式的形式能够有助于简化部分\(t_{\leq T}(s)\) 的推导过程,这一点我们后面就可以看到。勒让德多项式有很多不同的定义和性质,这里我们不一一展开,有兴趣的读者自行看链接中维基百科介绍即可。

接下来我们用到两个递归公式来推导一个恒等式,这两个递归公式是

\[ \begin{align} p_{n+1}'(t) - p_{n-1}'(t) = (2n+1)p_n(t)\\[5pt] p_{n+1}'(t) = (n + 1)p_n(t) + t p_n'(t)\\ \end{align} \]

由第一个公式16 迭代得到:

\[ \begin{equation}\begin{aligned} p_{n+1}'(t) =&\, (2n+1)p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \cdots \\ =&\, \sum_{k=0}^n (2k+1) \chi_{n-k} p_k(t) \end{aligned}\end{equation} \]

其中当 \(k\) 是偶数时 \(\chi_k=1\) 否则 \(\chi_k=0\)。代入第二个 公式17 得到

\[ \begin{equation}t p_n'(t) = n p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \cdots\end{equation} \]

继而有

\[ \begin{equation}\begin{aligned} (t+1) p_n'(t) =&\, n p_n(t) + (2n-1)p_{n-1}(t) + (2n-3)p_{n-2}(t) + \cdots\\ =&\,-(n+1) p_n(t) + \sum_{k=0}^n (2k + 1) p_k(t) \end{aligned}\end{equation} \]

这些就是等会要用到的恒等式。此外,勒让德多项式满足 \(p_n(1)=1,p_n(-1)=(-1)^n\),这个边界值后面也会用到。

正如n维空间中不止有一组正交基也一样,正交多项式也不止有勒让德多项式一种,比如还有切比雪夫(Chebyshev)多项式,如果算上加权的目标函数(即\(\rho(t)\not\equiv 1\)),还有拉盖尔多项式等,这些在原论文中都有提及,但HiPPO的主要结论还是基于勒让德多项式展开的,所以剩余部分这里也不展开讨论了。

邻近窗口(LegT)

完成准备工作后,我们就可以代入具体的 \(t_{\leq T}(s)\) 进行计算了,计算过程跟傅里叶级数的例子大同小异,只不过基函数换成了勒让德多项式构造的标准正交基 \(g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)\)。作为第一个例子,我们同样先考虑只保留最邻近窗口的信息,此时 \(t_{\leq T}(s) = (s + 1)w / 2 + T - w\)\([-1,1]\) 映射到 \([T-w,T]\),原论文将这种情形称为“LegT(Translated Legendre)”。

直接代入 式14,马上得到

\[ \small\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{w}\left[u(T) - (-1)^n u(T-w)\right] - \frac{2}{w}\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \]

我们首先处理 \(u(T-w)\) 项,跟傅里叶级数那里同样的思路,我们截断 \(n\leq N\) 作为 \(u((s + 1)w / 2 + T - w)\) 的一个近似:

\[ \begin{equation}u((s + 1)w / 2 + T - w)\approx \sum_{k=0}^N c_k(T)g_k(s)\end{equation} \]

从而有 \(u(T-w)\approx \sum\limits_{k=0}^N c_k(T)g_k(-1) = \sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}\)。接着,利用 式18 得到

\[ \begin{equation}\begin{aligned} &\,\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \\ =&\,\int_{-1}^1 u((s + 1)w / 2 + T - w) \sqrt{\frac{2n+1}{2}} p_n'(s) ds \\ =&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} (2k+1) \chi_{n-1-k} p_k(s)\right]ds \\ =&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} g_k(s)\right]ds \\ =&\, \sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \end{aligned}\end{equation} \]

将这些结果整合起来,就有

\[ \begin{equation}\begin{aligned} \frac{d}{dT}c_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2(2n+1)}}{w} (-1)^n \overbrace{\sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}}^{u(T-w)} \\ &\quad- \frac{2}{w}\overbrace{\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T)}^{\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds} \\[12pt] = &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=0}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\ &\quad- \frac{2}{w}\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \\[12pt] = &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\ &\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \sqrt{2k+1} \underbrace{\left(2\chi_{n-1-k} + (-1)^{n-k}\right)}_{\equiv 1}c_k(T) \\ \end{aligned}\end{equation} \]

再次地,将 \(T\) 换回 \(t\),并将所有的 \(c_n(t)\) 堆在一起记为 \(x(t) = (c_0,c_1,\cdots,c_N)\),那么根据上式可以写出

\[ \begin{equation}\begin{aligned} x'(t) =&\, Ax(t) + Bu(t)\\[8pt] \quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{l}\sqrt{(2n+1)(2k+1)}, &k < n \\ (-1)^{n-k}\sqrt{(2n+1)(2k+1)}, &k \geq n\end{array}\right.\\[8pt] B_n =&\, \frac{1}{w}\sqrt{2(2n+1)} \end{aligned}\end{equation} \]

我们还可以给每个 \(c_n(T)\) 都引入一个缩放因子,来使得上述结果更一般化。比如我们设 \(c_n(T) = \lambda_n \tilde{c}_n(T)\),代入式23 整理得

\[ \begin{equation}\begin{aligned} \frac{d}{dt}\tilde{c}_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w\lambda_n}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} \tilde{c}_k(T) \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \\ &\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \tilde{c}_k(T) \\ \end{aligned}\end{equation} \]

如果取 \(\lambda_n = \sqrt{2}\),那么 \(A\) 不变,\(B_n = \frac{1}{w}\sqrt{2n+1}\),这就对齐了原论文的结果,如果取 \(\lambda_n = \frac{2}{\sqrt{2n+1}}\),那么就得到了 Legendre Memory Units 中的结果

\[ \begin{equation}\begin{aligned} x'(t) =&\, Ax(t) + Bu(t)\\[8pt] \quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{l}2n+1, &k < n \\ (-1)^{n-k}(2n+1), &k \geq n\end{array}\right.\\[8pt] B_n =&\, \frac{1}{w}(2n+1) \end{aligned}\end{equation} \]

这些形式在理论上都是等价的,但可能存在不同的数值稳定性。比如一般来说当 \(u(t)\) 的性态不是特别糟糕时,我们可以预期 \(n\) 越大,\(|c_n|\) 的值就相对越小,这样直接用 \(c_n\) 的话 \(x(t)\) 向量的每个分量的尺度就不大对等,这样的系统在实际计算时容易出现数值稳定问题,而取 \(\lambda_n = \frac{2}{\sqrt{2n+1}}\) 改用 \(\tilde{c}_n\) 的话意味着数值小的分量会被适当放大,可能有助于缓解多尺度问题从而使得数值计算更稳定。

整个区间(LegS)

现在我们继续计算另一个例子:\(t_{\leq T}(s) = (s + 1)T / 2\),它将 \([-1,1]\) 均匀映射到 \([0,T]\),这意味着我们没有舍弃任何历史信息,并且平等地对待所有历史,原论文将这种情形称为“LegS(Scaled Legendre)”。

同样地,通过代入 式14 得到