Reading

离散扩散模型用于文本生成

引言

Diffusion模型近年来在图像生成这一连续域任务中取得了显著成果,展现出强大的生成能力。然而,在文本生成这一离散域任务中整体效果仍不尽如人意,未能在该领域引起广泛关注。

去年,一篇研究离散扩散模型在文本生成的文章《Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution》获得ICML 2024的Best Paper,引发了学术界的广泛兴趣,也激发了新一轮的研究热潮。随后在2025年,越来越多高校和企业也开始积极探索基于Diffusion的文本生成方法。其中,近期备受关注的Block Diffusion也成功入选ICLR oral,进一步推动了该方向的发展。

对于文本生成,此时的空间不是连续的,而是离散的,所以Diffusion连续加噪的方式在处理离散数据变得不可用。有一种直观的方法是把离散数据编码成连续的,然后套用DDPM的连续加噪去噪模型,去噪后得到的连续域的编码再映射回离散域中。这种方式感觉有点暴力,而且对离散数据的处理也不够优雅,虽然有些也取得了一定成效,但不在本文的讨论范围内。另外一种就是直接在离散域上进行扩散,即 \(p(\mathbb x_t|\mathbb x_{t-1})\) 表示离散域上的概率,每一步都是在离散域中取值,更加直接,本文重点讨论这种方式的工作。

本文简单记录一下几个比较经典的研究:D3PM,Concrete Score Matching,A Continuous Time Framework for Discrete Denoising Models,Score Entropy Discrete Diffusion models(SEDD)和Block Diffusion。重点聚焦于这些方法的定义,损失函数的推导等理论框架。

D3PM

D3PM(Discrete Denoising Diffusion Probabilistic Models)发表于NeurIPS 2021,是一个非常经典的离散扩散模型。具体来说,它利用离散时间马尔可夫链,直接在离散空间上进行扩散

以下先考虑一维(单词)的情况,多维(句子)的情况可以看成是独立同分布的。对于具有 \(K\) 个类别的标量离散随机变量 \(x_t, x_{t-1} \in {1, ..., K}\) ,其前向转移概率可以通过矩阵表示: \([Q_t]_{ij} = q(x_t = j \mid x_{t-1} = i)\) 。这里转移矩阵的定义与 \(t\) 有关,所以D3PM的马尔可夫链是非齐次的,这样更加一般化,不过在这里非齐次的讨论和齐次的区别不大。如果将 \(x_{t-1}\) 编码成 \(K\) 维的独热向量(1的索引就代表所属的类别),那么有:

\[q(x_t\mid x_{t−1})=x_{t−1}Q_t\tag{1.1}\]

注意这里的 \(q(x_t\mid x_{t−1})\) 也是 \(K\) 维的向量,每个维度表示对应类别的概率;如果是 \(q(x_t=j\mid x_{t−1})\) ,那就是标量,表示 \(x_{t-1}\) 这个状态到 \(j\) 状态的概率。在实际操作中,通常选择 \(q(x_t\mid x_{t−1})\) 中概率最大的值的索引作为 \(x_t\) 的类别,因此这个采样通过argmax函数。或对于概率分布而言有:

\[q(x_t)=q(x_{t−1})Q_t\]

对于初始状态 \(x_0\) ,如果想求得 \(q(x_t\mid x_0)\) ,只需要通过:

\[q(x_t\mid x_{0})=x_{0}\overline Q_t\]

其中 \(\overline Q_t=Q_1Q_2...Q_t\) 。这是因为:

\[x_{0}\overline Q_t=x_0Q_1Q_2...Q_t=q(x_1\mid x_0)Q_2...Q_t=q(x_2\mid x_0)...Q_t=q(x_t\mid x_{0})\\\]

和DDPM类似,如果优化KL散度的ELBO,那么最终的损失函数为:

\[ \mathcal{L} =\sum_{t}KL(q(x_{t-1}\mid x_t,x_0)\|q_\theta(x_{t-1}\mid x_t))\]

左边 \(q(x_{t-1}\mid x_t,x_0)\) 可以用公式:

\[\begin{aligned} q(x_{t-1} | x_t, x_0) &= \frac{q(x_t \mid x_{t-1}, x_0) q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \\ &= \frac{q(x_t \mid x_{t-1}) q(x_{t-1} \mid x_0)}{q(x_t \mid x_0)} \\ &= \frac{x_t Q_t^\top \odot x_0 \overline{Q}_{t-1}}{x_0 \overline Q_{t} x_t} \end{aligned}\tag{1.2}\]

关键看第三个等号。首先要明确的是 \(x_t, x_0\) 为给定的独热向量, \(q(x_{t-1} | x_t, x_0)\) 是 \(K\) 维的 \(x_{t-1}\) 概率分布,所以等式的右边应该是用 \(x_t, x_0\) 来表示 \(x_{t-1}\) 。所以分子的左边 \(q(x_t \mid x_{t-1})\) 是不能用(1.1)的,因为 \(x_{t-1}\) 是不确定的,而 \(x_t\) 才是确定给定的。事实上,这个概率展开应该是 \(q(x_t \mid x_{t-1})=[q(x_t \mid x_{t-1}=1),q(x_t \mid x_{t-1}=2),...,q(x_t \mid x_{t-1}=K)]\) ,这恰好是转移矩阵 \(Q_t\) 的第 \(x_t\) 列,表示成行向量形式则为 \(Q^\top_t\) 的第 \(x_t\) 行,因此有 \(q(x_t \mid x_{t-1})=x_t Q_t^\top\) 。分子的右边显然。而对于分母 \(q(x_t \mid x_0)\) 来说, \(x _t, x_0\) 是已经确定的,所以这个概率表示的就是转移矩阵 \(\overline Q_t\) 的第 \(x_0\) 行,第 \(x_t\) 列的元素,所以有 \(q(x_t \mid x_0)=x_0 \overline Q_{t} x_t\) 。

右边 \(q_\theta(x_{t-1}\mid x_t)\) 跟DDPM类似,我们不直接输入输入 \(x_t\) 预测 \(x_{t-1}\) 的 \(K\) 维概率分布,而是先预测 \(x_0\) ,然后再用前向得到 \(x_{t-1}\) 。具体的有:

\[q_\theta(x_{t-1}\mid x_t)=\sum_{x_0\in K}q(x_{t-1}\mid x_t,x_0)q_\theta(x_0\mid x_t)\tag{1.3}\]

也就是神经网络输入 \(x_t\) ,预测得到 \(x_0\) 每个类别的概率值,每个类别的预测值乘以对应类别的 \(q(x_{t-1}\mid x_t,x_0)\) 再累加。由于直接从 \(x_t\) 预测 \(x_0\) ,所以还加入了 \(x_t\) 到 \(x_0\) 再到 \(x_t\) 的重构损失,实验发现有利于提升图像质量:

\[ \mathcal{L} =\sum_{t}KL(q(x_{t-1}\mid x_t,x_0)\|q_\theta(x_{t-1}\mid x_t))+\lambda \mathbb{E}_{q(x_0)} \mathbb{E}_{q(x_t | x_0)} \left[ - \log q_\theta (x_0 | x_t) \right]\tag{1.4}\]

对于向量来说,(1.4)实际上就是计算交叉熵,因此最后(1.2)(1.3)的交叉熵再加上 \(\lambda\) 倍的 \(x_0\) 和预测 \(x_0\) 的概率分布的交叉熵作为损失函数训练即可。推理过程同样也是使用(1.3),得到逆向分布后,使用argmax函数找到概率最大的索引作为 \(x_{t-1}\) 。

接下来还有一个非常重要的点,就是关于转移矩阵 \(Q_t\) 的选取。我们的目的是想把 \(x_0\) 扩散到我们熟悉的先验分布中,然后再从先验分布中采样逆向生成,对于离散分布来说,均匀分布是一种非常常见的分布,因此如果令:

\[[Q_t]_{ij} = \begin{cases} 1 - \frac{K-1}{K} \beta_t & \text{if } i = j \\ \frac{1}{K} \beta_t & \text{if } i \neq j \end{cases}\tag{1.5} \]

其中 \(\beta_t\rightarrow1\) 。当 \(t\) 很大时, \(Q_t\) 的每一个元素,所有的转移概率,几乎都是 \(1/K\) ,这说明不管当前状态是啥,再经过一次扩散后下一次状态出现的概率都是均匀分布的,也就是能把 \(p(x_0)\) 扩散到均匀分布,那么采样的时候就从 \(K\) 个状态里面均匀选取,然后逆向生成即可。更简单的,我们还可以把(1.5)写作 \(Q_t = (1 - \beta_t) I + \frac{\beta_t}{K} \mathbf{1} \mathbf{1}^T \) 。其中 \(\mathbf 1\) 表示元素全为1的向量。

此外,我们还可以令:

\[[Q_t]_{ij} = \begin{cases} 1 & \text{if } i = j = m \\ 1 - \beta_t & \text{if } i = j \neq m \\ \beta_t & \text{if } j = m, i \neq m \\ 0, &\text{otherwise} \end{cases}\tag{1.6} \]

其中 \(m\in[0,K]\) 。为了更直观理解 \(Q_t\) 的性质,我们把它展开:

\[Q_t=\left[ \begin{matrix} 1-\beta _t& 0& \cdots& \beta _t& \cdots& 0\\ 0& 1-\beta _t& \cdots& \beta _t& \cdots& 0\\ 0& 0& & \vdots& & 0\\ 0& 0& \cdots& 1& \cdots& 0\\ 0& 0& & \vdots& & 0\\ 0& 0& \cdots& \beta _t& \cdots& 1-\beta _t\\ \end{matrix} \right] \tag{1.7}\]

可以看到,对于某一个当前不属于 \(m\) 的状态,它在转移的时候只有两种选择,要么不转移,要么转移到 \(m\) 状态;若当前已经处于 \(m\) 状态了,则会一直保留在 \(m\) 状态,就好像被“吸收”住了,所以这种情况下把 \(m\) 称作吸收态(Absorbing State)。当 \(t\) 比较大时, \(\beta_t\rightarrow1\) ,所以状态最终都会处于吸收态,先验分布就是吸收态,采样的时候就从吸收态开始,再应用逆向过程即可。实际操作中,一般会令吸收态为[Mask],因此一个单词最终会被扩散到一个[Mask],在这种场景下跟自回归中的BERT模型思想其实也是非常接近的,BERT就可以看成是一步吸收态的D3PM模型。更简单的,我们还可以把(1.6)写作 \((1 - \beta_t)I + \beta_t 1 e_m^T \) ,其中 \(e_m^T\) 表示第 \(m\) 个元素为1,其余元素为0的单位行向量。

介绍完了一维(单词)的情况,那么多维(句子)实际上就是输入神经网络的时候把 \([1*K]\) 换成 \([L*K]\) 即可, \(L\) 是句子的长度,输出也是同样的维度。前向扩散过程则为独立同分布的。另外,在定义好了 \(Q_t\) 后,可以事先把所有的 \(\overline Q_{t}\) 算出来,节省训练和推理的时间复杂度。

PixPin_2026-01-20_10-56-38.png
a)为吸收态扩散 b)为均匀扩散

可以看看具体的实验结果:

PixPin_2026-01-20_10-57-48.png
在LM1B上训练,128步,吸收态token为[Mask]
PixPin_2026-01-20_10-59-56.png
在LM1B上训练,1000步,均匀分布

可以看到,实际效果有点不太行,感觉句子有点莫名奇妙的,毕竟也是早期工作了。不过可以明显发现,这种扩散模型做文本生成,是去噪后直接生成一整段,而不像自回归是一个一个生成。哪怕句子非常长,有2w个字符,自回归可能就需要2w步,但是扩散模型仍然只需要128步或者1000步的去噪,因此还是很有研究潜力的。

Concrete Score Matching

在连续空间中有Score(概率分布的梯度)的概念,那么在离散空间有没有类似的“Score”?因为有了Score就可以应用某些方式比如连续空间中的Langevin采样,直接从概率分布直接采样,实现生成。而斯坦福NeurIPS 2022的一篇文章就介绍了离散空间中的Concrete Score。这篇文章的方法并不算扩散模型,没有承接D3PM,只是传统的离散生成模型。虽然现在几乎也没人使用这样的生成方法,但概念还是比较重要,并在后面工作也有所提及,具有启发式意义,因此就简单介绍一下。

令 \(p_{\text{data}}(\mathbf x)\) 为数据分布,这里的 \(\mathbf x\) 是多维的,可以表示一个句子,并假设 \(\mathbf x\in \mathcal{X}\) 。我们用 \(\mathcal{N}: \mathcal{X} \to \mathcal{X}^K\) 来表示将每个数据 \(\mathbf x \in \mathcal{X}\) 映射到一组邻居的函数,使得 \(\mathcal{N}(\mathbf x) = \{ \mathbf x_{n1}, \dots, \mathbf x_{nk} \}\) 。这个邻居关系会诱导出一种特定的图结构,我们称之为“邻居诱导图”,它将在构建替代梯度时发挥关键作用,以下是正式定义:

定义 2.1(邻居诱导图):令 \(p_{\text{data}}(\mathbf x)\) 为数据分布, \(\mathcal{N}\) 为映射每个节点 \(\mathbf x\in \mathcal{X}\) 到其邻居集的函数。邻居诱导图 \(G\) 是通过为每个 \(\mathbf x\) 添加一条从 \(\mathbf x\) 到其邻居集 \(\mathcal{N}(\mathbf x)\) 中每个节点的有向边而产生的有向图。

一个重要的点是,邻居结构可以是非对称的,因为邻居诱导图是一个有向图。这意味着可能存在这种情况,即 \(\mathcal{N}(\mathbf x_1) = \{ \mathbf x \} \) 并不一定意味着 \(\mathcal{N}(\mathbf x) = \{ \mathbf x_1 \}\) 。比如下图中的星星图,显然边点的邻居是中心点,但是中心点的邻居不是边点。

PixPin_2026-01-20_11-32-17.png
定义 2.2(Concrete Score):令 \(\mathcal{N}\) 是一个将每个点 \(\mathbf x\in \mathcal{X}\) 映射到其邻居集合的函数 \(\mathcal{N}(\mathbf x) = \{ \mathbf x_{n1}, \dots, \mathbf x_{nk} \}\) 。Concrete Score \(c_{p_{\text{data}}}(\mathbf x; \mathcal{N}):\mathcal X\rightarrow\mathbb{R}^{|\mathcal{N}(\mathbf x)|}\) 定义为:
\[c_{p_{\text{data}}}(\mathbf x; \mathcal{N}) \triangleq \left[ \frac{p_{\text{data}}(\mathbf x_{n_1}) - p_{\text{data}}(\mathbf x)}{p_{\text{data}}(\mathbf x)}, \dots, \frac{p_{\text{data}}(\mathbf x_{n_k}) - p_{\text{data}}(\mathbf x)}{p_{\text{data}}(\mathbf x)} \right]^T\tag{2.1}\]

容易发现:

\[c_{p_{\text{data}}}(\mathbf x; \mathcal{N}) +1= \left[ \frac{p_{\text{data}}(\mathbf x_{n_1})}{p_{\text{data}}(\mathbf x)}, \dots, \frac{p_{\text{data}}(\mathbf x_{n_k})}{p_{\text{data}}(\mathbf x)} \right]^T\tag{2.2}\]

每个元素表示的是相邻数据的概率分布的比值。为啥要定义这种概率的比值呢?因为假如已知 \(p(\mathbf x_T)\) ,假设有条路径 \(\mathbf x_T\rightarrow \mathbf x_{T-1}\rightarrow...\rightarrow \mathbf x_0\) ,那么我们就可以通过 \(p(\mathbf x_0)=p(\mathbf x_T)\frac{p(\mathbf x_{T-1})}{p(\mathbf x_T)}...\frac{p(\mathbf x_0)}{p(\mathbf x_1)}=p(\mathbf x_T)(c_{p_\theta}(\mathbf x_T; \mathcal{N})_{\mathbf x_{T-1}} +1)...(c_{p_\theta}(\mathbf x_1; \mathcal{N})_{\mathbf x_0} +1)\) 来得到 \(p(\mathbf x_0)\) ,或者任意连通的点的概率分布。那为什么不直接定义成 \( \left[ p_{\text{data}}(\mathbf x_{n_1}), \dots, p_{\text{data}}(\mathbf x_{n_k}) \right]^T\) 呢?因为对离散概率来说,这样定义的话各个维度中数据的概率值实际上就是生成模型,我们的最终目标就是这个。换句话说,如果能知道每个句子的概率值,那么直接采样就好了,所以定义成这样有点执果索因的意味了。事实上,这种比率的定义方式除了能够被学习,间接的进行采样生成,还能够与连续空间下的Score产生联系,后面后提到。此外,这样定义的Concrete Score是满足完整性(Completeness)的,即,若通过数据学到的 \(c_{\theta}(\mathbf x; \mathcal{N}) = c_{p_{\text{data}}}(\mathbf x; \mathcal{N}) \) ,那么 \(p_\theta(\mathbf x)=p_{\text{data}}(\mathbf x)\) 。简单理解,是因为Concrete Score能够唯一确定相邻节点概率的比率,那么任意两个连通节点之间的比率也能被唯一确定。因此在此之前,我们需要人为的对各个句子建立邻居关系,这样才能计算出Concrete Score。如果建立成完全图,即每个句子都跟其他所有句子相连,显然是不现实的,因为句子空间很大,Concrete Score的维度会非常大,不利于学习。一种比较现实的方式将两个仅仅只相差一个单词的句子建立邻居关系,比如“I love cat”和“I love dog”建立关系,那么不但固定住了Concrete Score的维度(规定句子长度为固定值 \(l\) ),而且也是连通的,因为两个完全不同的句子可以通过一个一个单词的替换得到另一个。

接下来还有两个问题,第一要怎么去训练得到 \(c_{\theta}(\mathbf x; \mathcal{N}) \) ,第二要怎么从 \(c_{\theta}(\mathbf x; \mathcal{N}) \) 采样得到具体的样本 \(\mathbf x\) 。

首先看采样。假设已经学好了 \(c_{\theta}(\mathbf x; \mathcal{N}) \approx c_{p_{\text{data}}}(\mathbf x; \mathcal{N}) \) ,由于我们得到的是各个状态的分布的比率,那么就可以用Metropolis-Hastings 算法。该算法是一种基于马尔可夫链蒙特卡洛(MCMC)方法的采样技术,基本思想是通过构造一个马尔可夫链,使其最终的稳态分布是我们想要的目标分布。操作也是非常简单,首先从状态空间 \( \mathcal{X}\) 中随意选取一个初始点 \(\mathbf x_0\) 。接着定义一个提议分布 \(q(x'\mid x_t)\) ,即给定当前状态,问下一个可能状态 \(x'\) 的概率,这个可以随意定义,为了对称性,我们一般定义为所有邻居中的均匀分布,这说明给定当前状态,到下一刻状态的转移是均匀的。最后还要以 \(A(\mathbf x' \mid \mathbf x_t) = \min \left( 1, \frac{p(\mathbf x') q(\mathbf x_t \mid \mathbf x')}{p(\mathbf x_t) q(\mathbf x' \mid \mathbf x_t)} \right) \) 的概率接受并更新状态,否则不改变状态。由于对称性,可以简化为 \(A(\mathbf x' \mid \mathbf x_t) = \min \left( 1, \frac{p(\mathbf x') }{p(\mathbf x_t) } \right)\) ,也就是说当 \(\frac{p(\mathbf x') }{p(\mathbf x_t) } \) 比较大时,状态会更倾向于转到 \(\mathbf x'\) ,不断重复这个过程。随着迭代次数的增加,Metropolis-Hastings 算法的马尔可夫链将逐渐收敛到目标分布。尽管算法的每一步可能会偏离目标分布,但通过多次迭代,采样结果会趋近于目标分布。这种采样方式跟连续空间下的Langevin采样有类似的味道,也是定义了一个马尔可夫采样链,无限迭代,最终趋于稳态分布。

接下来考虑训练。非常朴实无华,直接考虑MSE作为损失函数即可:

\[\mathcal{L}(\theta) = \sum_{\mathbf{x}} p_{\text{data}}(\mathbf{x}) \parallel c_{\theta}(\mathbf{x}; \mathcal{N}) - c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \parallel_2^2\tag{2.3}\]

显然 \(c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \) 的计算是比较复杂的,我们可以学习得分匹配中的思想,跟《概率视角下的生成模型》中(3.27)式类似,把上式改写成与 \(c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \) 无关的损失函数:

\[\begin{aligned} \mathcal{L}(\theta) &= \sum_{\mathbf{x}} p_{\text{data}}(\mathbf{x}) \parallel c_{\theta}(\mathbf{x}; \mathcal{N}) - c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \parallel_2^2 \\ &= \sum_{\mathbf{x}} p_{\text{data}}(\mathbf{x}) \left[ \parallel c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \parallel_2^2 - 2 c_{\theta}(\mathbf{x}; \mathcal{N})^T c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) + \parallel c_{\theta}(\mathbf{x}; \mathcal{N}) \parallel_2^2 \right] \\ &= \sum_{\mathbf{x}} p_{\text{data}}(\mathbf{x}) \left[ \parallel c_{\theta}(\mathbf{x}; \mathcal{N}) \parallel_2^2 - 2 c_{\theta}(\mathbf{x}; \mathcal{N})^T c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \right] +C\\ &= \ \sum_{\mathbf{x}} \sum_{i=1}^{|\mathcal{N}(\mathbf{x})|} p_{\text{data}}(\mathbf{x}) \left( c_{\theta}(\mathbf{x}; \mathcal{N})^2_i + 2 c_{\theta}(\mathbf{x}; \mathcal{N})_i \right) - \sum_{\mathbf{x}} \sum_{i=1}^{|\mathcal{N}(\mathbf{x})|} 2 p_{\text{data}}(\mathbf{x}_{n_i}) c_{\theta}(\mathbf{x}; \mathcal{N})_i \\ &= \mathcal{J}_{CSM}(\theta) \end{aligned}\tag{2.4}\]

其中第四个等号就是把 \(c_{p_{\text{data}}}(\mathbf x; \mathcal{N})\) 用(2.1)展开。(2.4)式就避免了计算复杂的 \(c_{p_{\text{data}}}(\mathbf{x}; \mathcal{N}) \) 。在实际操作中,左边可以用以下步骤进行无偏估计:随机取 \(\mathbf x\) ,然后从它的邻居中均匀选择一个 \(\mathbf x_{n_i}\) ,计算 \(|\mathcal N(\mathbf x)|\cdot(c_{\theta}(\mathbf{x}; \mathcal{N})^2_i + 2 c_{\theta}(\mathbf{x}; \mathcal{N})_i)\) ,这样就不用计算每一个邻居的值了。对于右边,因为要先从邻居 \(p_{\text{data}}(\mathbf{x}_{n_i})\) 中采样,所以得逆着来:我们先要随机取 \(\mathbf x'\) ,然后从哪些点的邻居是 \(\mathbf x'\) 的集合中再均匀选择一个 \(\mathbf x\) ,这样的集合我们记作 \(\mathcal N^{-1}(\mathbf x')=\{(\mathbf x,i),\mathcal N(\mathbf x)_{i}=\mathbf x'\}\) ,再计算 \(|\mathcal N^{-1}(\mathbf x')|\cdot 2c_{\theta}(\mathbf{x}; \mathcal{N})_i\) 。

最后,再来考虑一下Concrete Score和Score之间的关系。

定理 2.1:给定一个 \(D\) -维连续数据分布 \(p(\mathbf{x})\) 和 \(\delta > 0\) ,我们定义一个特定的邻居结构: \(\mathcal{N}(\mathbf{x}) = \{ \mathbf{x}_{n_i} \}_{i=1}^D\) ,其中 \(\mathbf{x}_{n_i} = \mathbf{x} + \delta e_i\) 且 \(e_i \) 是标准的(one-hot)基向量。则有:
\[c_{\theta}(\mathbf{x}, \mathcal{N}) / \delta = \frac{1}{p(\mathbf{x})} \left[ \frac{p(\mathbf{x} + \delta e_1) - p(\mathbf{x})}{\delta}, \frac{p(\mathbf{x} + \delta e_2) - p(\mathbf{x})}{\delta}, \dots, \frac{p(\mathbf{x} + \delta e_D) - p(\mathbf{x})}{\delta} \right]^T\tag{2.5}\]

从公式中,我们可以看到 \(\left[ \frac{p(\mathbf{x} + \delta e_1) - p(\mathbf{x})}{\delta}, \dots, \frac{p(\mathbf{x} + \delta e_D) - p(\mathbf{x})}{\delta} \right]^T\) 近似的表示了对 \(p(\mathbf{x})\) 的方向导数。因此,缩放的Concrete得分函数 \(c_{\theta}(\mathbf{x}, \mathcal{N}) / \delta\) 在 \(\delta \to 0\) 时:

\[\lim_{\delta \to 0} \frac{c_{\theta}(\mathbf{x}, \mathcal{N})}{\delta} = \frac{1}{p(\mathbf{x})} \lim_{\delta \to 0} \left[ \frac{p(\mathbf{x} + \delta e_1) - p(\mathbf{x})}{\delta}, \dots, \frac{p(\mathbf{x} + \delta e_D) - p(\mathbf{x})}{\delta} \right]^T = \frac{1}{p(\mathbf{x})}\nabla_{\mathbf{x}}p(\mathbf{x}) = \nabla_{\mathbf{x}} \log p(\mathbf{x})\tag{2.6}\]

这正是Score。这是因为,这种特殊的邻居结构 \(\mathcal{N}(\mathbf{x}) = \{ \mathbf{x}_{n_i} \}_{i=1}^D\) 其实就是 \(D\) 维的网格,由于 \(\delta\) 的引入,每个维度的取值变得连续了起来,因此现在的数据就可以看做是分布在 \(D\) 维直角坐标中,也就和 \(D\) 维连续空间中的得分联系了起来。

Continuous Time Framework

在扩散模型图像生成的发展中,DDPM最先被提出,它基于离散时间步扩散框架。随后的VP-SDE则推广到了连续时间中,连续时间框架是兼容离散时间的,而且还允许我们使用连续的微积分工具进行更深度的处理。因此类似的,我们也可以把离散时间的D3PM框架推广到连续时间的框架上,而D3PM实际上就是基于离散时间马尔可夫链的,那么自然的就可以在连续时间马尔可夫链上构造生成模型。NeurIPS 2022上的一篇文章就介绍了连续时间马尔可夫链的生成框架,用于离散数据的生成。

一维情况

和D3PM一样,同样从一维数据开始考虑。假设 \(\tilde x\) 表示 \(t\) 时刻的状态,状态空间为 \(S\) ,那么 \(\tilde x\) 到下一个状态 \(x\) 的转移概率为 \(p_t(x\mid \tilde x)\) 。而这个转移概率可以用转移速率矩阵 \(R_t(\tilde x,x)\in \mathbb{R}^{S\times S}\) 来近似表示,我们考虑一段极小的时间 \(\Delta t\) ,根据转移速率矩阵的定义,很容易可以写出前向过程:

\[q_{t+\Delta t|t}(x\mid \tilde{x}) = \delta_{\tilde{x}, {x}} + R_t(\tilde{x}, {x})\Delta t + o(\Delta t)\tag{3.1}\]

其中 \(o(\Delta t)\) 代表那些比 \(\Delta t\) 下降得更快的项。与离散时间的情况相比,我们看到 \(R_t\) 在定义前向过程时类似于离散时间前向核 \(q_{t+1|t}\) 。因此,就像在离散时间中一样,我们会设计合适的 \(R_t\) 使得前向过程快速地到一个容易采样(平稳)分布 \(p_{\text{ref}}\) ,例如均匀分布或者全是Mask的吸收态分布;

马尔可夫链的转移概率矩阵:
在离散齐次的情况下,我们只需要定义单位转移矩阵 \(P\) ,就能通过幂运算求得 \(P^{(n)}\) 。由于连续时间没有最小单位时间的概念,所以也没有步数的概念,那么有没有什么概念能取代它?自然我们会联想到,对于一段时间,速率是描述单位的最好概念。因此转移速率矩阵定义为:
\[Q=\lim_{h\rightarrow0}\frac{P(h)-P(0)}{h}=\lim_{h\rightarrow0}\frac{P(h)-E}{h}=P'(0)\]
观察 \(Q\) 中的每个元素,对于对角线上的元素,有 \(q_{ii}=\lim\limits_{h\rightarrow0}\frac{p_{ii}(h)-1}{h}\) ,非对角线上的元素有 \(q_{ij}=\lim\limits_{h\rightarrow0}\frac{p_{ij}(h)}{h}\) 。可以发现,对角线上的元素都是负数,非对角线上的为正数。另外这两个极限是存在的,从直观上感受,在 \(h\) 接近于0时,转移概率 \(p_{ii}(h)\) 是趋向于1的,而 \(q_{ij}(h)\) 是趋近于0的,这是因为对于连续情况来说,0时刻就是不转移,然后随着时间的流逝,会慢慢向其它其它状态转移,那么自己到自己的概率就会逐渐减小,到其它的概率会逐渐增大,并且这也解释了为什么对角线上的元素都是负数而其它元素都是正数。严格证明这里不介绍了。另外,由于目前只考虑齐次的,与当前时间无关,所以转移速率矩阵 \(Q\) 只需要考虑0时刻的导数即可。由于 \(\sum_{j\in I}p_{ij}(h)=1\) ,所以有:
\[q_{ii}=-\sum_{i\ne j}q_{ij}\]
如果把 \(q_{ii}\) 看做状态流出的速率,那么总流出的速率等于流入各个状态速率的和,这个性质也在后面的公式化简中经常出现。

有了转移速率矩阵 \(Q\) 后,自然的就会想要求得任意的 \(P(t)\) 。

而且我们可以通过一些方式(后续会介绍)直接获得 \(q_{t|0}(x_t|x_0)\) 分布,以便高效训练。对于逆向过程,通过逆向转移速率矩阵 \(\hat R_t \in \mathbb{R}^{S\times S}\) ,给出:

\[q_{t|t+\Delta t}( \tilde x\mid{x}) = \delta_{{x}, \tilde{x}} + \hat R_t({x}, \tilde{x})\Delta t + o(\Delta t)\tag{3.2}\]

所以问题的关键就是估计逆向转移速率矩阵 \(\hat R_t\) ,直接对这个矩阵参数化再用神经网络估计是比较困难的,因为接受的输入和输出空间都是 \(S\times S\) ,范围太大了。事实上,由于贝叶斯公式有 \(p_t(\tilde x\mid x)=p_t(x\mid \tilde x)\frac{p_t(\tilde x)}{p_t(x)}\) ,所以正向和逆向转移速率矩阵之间有如下关系:

\[\begin{align} \hat R_t(x, \tilde x)&=R_t(\tilde x, x)\frac{p_t(\tilde x)}{p_t(x)}\\ &=R_t(\tilde x, x)\sum_{x_0\in S}\frac{p_t(\tilde x, x_0)}{p_t(x)}\\ &=R_t(\tilde x, x)\sum_{x_0\in S}\frac{p_0(x_0)}{p_t(x)}p_{t|0}(\tilde x\mid x_0)\\ &=R_t(\tilde x, x)\sum_{x_0\in S}\frac{p_0(x_0)p_{t|0}(x\mid x_0)}{p_t(x)p_{t|0}(x\mid x_0)}p_{t|0}(\tilde x\mid x_0)\\ &=R_t(\tilde x, x)\sum_{x_0\in S}\frac{p_{0|t}(x_0\mid x)}{p_{t|0}(x\mid x_0)}p_{t|0}(\tilde x\mid x_0)\\ \end{align}\tag{3.3}\]

注意上式推导的时间关系。并且注意逆向速率矩阵的逆向代表的是时间上的逆向,而不是状态上的。其中前向的 \(R_t(\tilde x, x)\) 和 \(p_{t|0}\) 能直接得到,所以参数化逆向转移矩阵只需要参数化逆向转移概率 \(p_{0|t}(x_0\mid x)\) 为 \(p_{0|t}^{\theta}(x_0\mid x)\) 即可,这回到了D3PM,甚至DDPM的思想,从 \(x_t\) 预测 \(x_0\) 。

接下来,则是最关键的损失函数。首先记由前向转移速率矩阵 \(R_t\) 定义的路径测度为 \(\mathbb Q\) ,这里路径表示所有时刻的联合状态,所以路径的样本空间就是所有时刻状态的无穷维联合分布。比如在离散时间中, \((x_1,x_1,x_3)\) 表示3个时刻的某个路径,在连续时间中,就变成无穷维了,但注意由于状态空间是离散随机变量,所以无穷维联合分布还是离散的,因此在上面的测度 \(\mathbb Q\) 就是“概率值”,不存在概率密度。同样,我们定义 \(\mathbb {\hat Q}\) 为 \(\hat R_t\) 逆向路径测度, \(\mathbb {\hat P}^{\theta}\) 为 \(\hat R_t^{\theta}\) 的逆向测度,他们两个的区别是一个是真实的,一个是学习到的。另外 \(\mathbb Q\) 和 \(\mathbb {\hat Q}\) 的区别在于 \(\mathbb Q\) 的 \(p_T(x_T)\) 和 \(\mathbb {\hat Q}\) 的 \(p_{\text{ref}}(x_T)\) 的不同

有了这些记号后,首先考虑负对数似然:

\[- \log p_0^\theta(x_0) = - \log \int p_{\text{ref}}(\mathrm{d}x_T) \int_{\{\hat{W}_T = x_0\}} \mathbb{P}^{\theta, x_T}(\mathrm{d}w)\tag{3.4}\]

其中 \(p_{\text{ref}}(\mathrm{d}x_T)\) 就表示 \(\mathrm{d}p_{\text{ref}}\) ,只不过前者能看出是在 \(x_T\) 空间上的测度积分; \(\mathbb{P}^{\theta, x_T}\) 表示以 \(x_T\) 为起点, \(W\) 表示路径, \(\hat{W}_T = x_0\) 表示终点为 \(x_0\) 的逆向路径,所以 \(\int_{\{\hat{W}_T = x_0\}} \mathbb{P}^{\theta, x_T}(\mathrm{d}w)\) 中的路径 \(w\) 的表示的是 \((0,T)\) 间的无穷维联合分布,不包括 \(0,T\) 。如果换成离散的初等形式,其实就是熟悉的全概率公式和条件概率公式: \(- \log \int p_{\text{ref}}(x_T)\mathrm{d}x_T \int \mathbb{P}^{\theta}(x_{0:T-1}\mid x_T)\mathrm{d}x_{1:T-1}\) ,只不过对于无穷维的路径积分,写成(3.4)更加严谨。当然要严谨证明无穷维乘积空间的(正则)条件概率的性质还需要用到一些额外的工具,不过理解上还是非常容易的。

由于 \(p_{\text{ref}}\) 对 \(q_{T|0}\) 是绝对连续的, \(\mathbb{P}^{\theta}\) 对 \(\mathbb {\hat Q}\) 也是绝对连续的,这是因为参考测度 \(q_{T|0}\) 和 \(\mathbb {\hat Q}\) 定义的马尔可夫链是比较“宽松的”,包含了所有可能出现的情况,所以他们的零测集在 \(p_{\text{ref}}\) 和 \(\mathbb{P}^{\theta}\) 上肯定也是零测的,因此两个对应的R-N导数 \(\frac{\mathrm{d}p_{\text{ref}}}{\mathrm{d}q_{T|0}}\) 和 \(\frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\mathbb{{\hat Q}}^{x_T}}\) 存在。根据测度变换公式,进而有:

\[\begin{align*} - \log p_0^\theta(x_0) &= - \log \int q_{T|0}(\mathrm{d}x_T) \int_{\{\hat{W}_T = x_0\}} \frac{\mathrm{d}p_{\text{ref}}}{\mathrm{d}q_{T|0}} \frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\hat{\mathbb{Q}}^{x_T}} \hat{\mathbb{Q}}^{x_T}(\mathrm{d}w) \\ &= - \log \int q_{T|0}(\mathrm{d}x_T) \int \frac{\mathrm{d}p_{\text{ref}}}{\mathrm{d}q_{T|0}} \frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\hat {\mathbb{Q}}^{x_T}} \hat {\mathbb{Q}}^{x_T}\{\hat{W}_T = x_0\} \hat{\mathbb{Q}}(\mathrm{d}w \mid \hat{W}_0 = x_T, \hat{W}_T = x_0) \\ &\le \int q_{T|0}(\mathrm{d}x_T) \int \left\{ - \log \frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\hat{\mathbb{Q}}^{x_T}}\right\}\hat{\mathbb{Q}}(\mathrm{d}w \mid \hat{W}_0 = x_T, \hat{W}_T = x_0) + C, \end{align*}\tag{3.5}\]

其中第三行用到了Jensen不等式。接着重点考虑 \(-\log \frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\hat{\mathbb{Q}}^{x_T}}\) ,由R-N导数的定义和测度变换公式可知,R-N导数就是其测度(概率)密度的比值,即 \(-\log \frac{\mathrm{d}\mathbb{P}^{\theta, x_T}}{\mathrm{d}\hat{\mathbb{Q}}^{x_T}}(w)=-\log \frac{p^{\theta, x_T}(w)}{\hat{q}^{x_T}(w)}\) 。因为分母与 \(\theta\) 无关,因此我们只需要考虑 \(-\log p^{\theta, x_T}(w)\) 即可。 \(\mathbb{P}^{\theta, x_T}\) 是我们学习到的马尔可夫链,其转移速率矩阵为 \(\hat R_t^{\theta}\) ,所以在某个时刻一个状态可能转移,也有可能不转移。假设它的一条逆向路径为 \(\hat W=((x_T,T),(x_{T-s},T-S),...,(x_0,0))\) ,从 \(x_T\) 出发,经过若干次跳跃。那么它 \([T-s,T)\) 路径的概率密度应该由两部分组成:1. \((T-s,T)\) 之间不转移; \(2\) . \(T-s\) 时刻从 \(x_T\) 转移到 \(x_{T-s}\) 。

关于第一点,它不转移的概率就是在 \((T-s,T)\) 状态自己到自己的概率,由于是非齐次的,我们在 \(t\) 时刻考虑一小段 \(dt\) ,那么根据转移速率矩阵的定义,不转移的概率是 \(1+\hat R_{T-t}^{\theta}(x_{T-t})dt\approx1+\int_{t}^{t+dt} \hat R_{T-z}^{\theta}(x_{T-z})dz\) ,其中 \(\hat R_{T-t}^{\theta}(x_{T-t})=-\hat R_{T-t}^{\theta}(x_{T-t},x_{T-t})\) 。根据指数函数的性质又可以转化成 \(e^{-\int_{t}^{t+dt}\hat R_{T-z}^{\theta}(x_{T-z})dz}\) ,由于每一小段都是独立的,因此 \((T-s,T)\) 上不转移的概率就为: 

\[\prod_{t\in(T-s,T)}e^{-\int_{t}^{t+dt}\hat R_{T-z}^{\theta}(x_{T-z})dz}=e^{-\int_{0}^{s}\hat R_{T-z}^{\theta}(x_{T-z})dz}=e^{-\int_{0}^{s}\hat R_{T-z}^{\theta}(\hat W_z)dz}\tag{3.6}\]

(3.6)也能看出,不转移的概率服从参数为 \(\hat R_{T-z}^{\theta}(\hat W_z)\) 的泊松分布,所以转移(事件发生)服从参数为 \(\hat R_{T-z}^{\theta}(\hat W_z)\) 泊松过程。这也是很好理解的,因为定义的连续时间马尔可夫链本身是无记忆性的,所以必然服从指数分布,且发现泊松速率恰好就是转出速率。

关于第二点,又可以根据转移速率矩阵的定义可知 \(s\) 时刻转移速率就是概率转移密度,所以概率转移密度就为: \(\hat R_{T-s}^{\theta}(x_T,x_{T-s})\) ,也可以写为 \(\hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\) ,其中 \(\hat W_{s^{-}}\) 表示 \(\hat W_s\) 之前的状态。

所以 \((T,T-s]\) 的路径密度就为: \(e^{-\int_{0}^{s}\hat R_{T-z}^{\theta}(\hat W_z)dz}\cdot \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\) 。有了 \((T,T-s]\) 的路径概率后,由于独立性,只需要将路径的每段的概率密度乘起来即可,所以:

\[\begin{align} -\log \mathbb{P}^{\theta, x_T}(\hat W)&=-\log \prod_{s:\hat W_{s^-}\ne \hat W_{s}}e^{-\int_{0}^{s}\hat R_{T-z}^{\theta}(\hat W_z)dz}\hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\\ &=\int_{0}^{T}\hat R_{T-z}^{\theta}(\hat W_z)dz-\sum_{s:\hat W_{s^-}\ne \hat W_{s}}\log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\\ \end{align}\tag{3.7}\]

这里面的累乘号和累加号下标表示的是对所有跳转的时刻进行累乘和累加。

再回到(3.5),由条件概率公式记 \(q_{T|0}(\mathrm{d}x_T)\hat{\mathbb{Q}}(\mathrm{d}w \mid \hat{W}_0 = x_T, \hat{W}_T = x_0)=\hat{\mathbb{Q}}^{\hat{W}_T = x_0}(\mathrm{d}w)\) 。然后把测度方向和路径方向换成正向,即 \(\hat{\mathbb{Q}}^{\hat{W}_T = x_0}(\mathbb dw)={\mathbb{Q}}^{x_0}(\mathbb dw)\) , \(\hat R_{T-z}^{\theta}(\hat W_z)=\hat R_{T-z}^{\theta}(W_{T-z})\) , \(\hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)=\hat R_{T-s}^{\theta}(W_{T-s},W_{(T-s)^{-}})\) 。那么损失函数就为负对数似然的期望:

\[\begin{align*} \int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] &\leq \int p_{\text{data}}(\mathrm{d}x_0) \mathbb{Q}^{x_0}(\mathrm{d}w) \Bigg\{ \int_0^T \hat{R}^\theta_{T-s}(W_{T-s}) \mathrm{d}s \\ &\quad - \sum_{s : W_{(T-s)^-} \neq W_{T-s}} \log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s) \Bigg\} + C \end{align*}\tag{3.8} \]

其中 \(p_{\text{data}}(\mathrm{d}x_0) \mathbb{Q}^{x_0}(\mathrm{d}w) =\mathbb{Q}(\mathrm{d}w)\) ,所以(3.8)表达的意思是对整个真实路径的期望,而且(3.8)还存在对跳跃时刻的累加,这在实际情况中没法计算。而Dynkin's lemma告诉我们,可以把路径的期望沿着时间和状态拆解成对某时刻状态的期望再对所有时刻的积分。所以对于(3.8)第一项,可以理解成从某时刻 \(T-s\) 的真实边缘分布 \(p_{T-s}(x)\) 取一点 \(x\) ,然后计算 \(\hat{R}^\theta_{T-s}(x) \) 的期望,最后再对时间求积分,因此:

\[\begin{align} \int p_{\text{data}}(\mathrm{d}x_0) \mathbb{Q}^{x_0}(\mathrm{d}w) \int_0^T \hat{R}^\theta_{T-s}(W_{T-s}) \mathrm{d}s &=\int_{0}^{T}\mathrm{d}s\int p_{T-s}(\mathrm{d}x) \hat{R}^\theta_{T-s}(x)\\ &=\int_{0}^{T}\mathrm{d}s\int p_{s}(\mathrm{d}x) \hat{R}^\theta_{s}(x)\\ \end{align}\tag{3.9}\]

而第二项要表示所有跳跃步对函数 \(\log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\) 的期望,那么可以理解为某时刻跳跃概率对函数 \(\log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\) 的期望,最后再对所有时刻求积分。某时刻 \(T-s\) 跳跃的期望应该是给定状态 \(x\) 后,首先得跳出,而真实跳出的密度是 \(R_{T-s}(x)\) ,然后假设在跳出的条件下,跳到 \(y\) 状态的真实概率是 \(r_{T-s}(y\mid x)\) ,满足 \(r_{T-s}(y\mid x)=(1-\delta_{xy})R_{T-s}(x,y)/\sum_{y\ne x}R_{T-s}(x,y)\) 。这是因为 \(r_{T-s}(y\mid x)\) 表示的是在跳跃已经发生情况下的概率,所以自己到自己的概率是0,其次 \(x\) 到 \(y\) 的概率就是 \(x\) 到所有不包括自己的可能的状态的速率占比。所以跳跃的期望就是上述项乘起来且对 \(x,y\) 进行积分:

\[\begin{align} &\int p_{\text{data}}(\mathrm{d}x_0) \mathbb{Q}^{x_0}(\mathrm{d}w) \sum_{s : W_{(T-s)^-} \neq W_{T-s}} \log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s) \\=&\int_{0}^{T}\mathrm{d}s\int p_{T-s}(\mathrm{d}x) R_{T-s}(x)r_{T-s}(\mathrm dy\mid x) \log \hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\\ =&\int_{0}^{T}\mathrm{d}s\int p_{s}(\mathrm{d}x) r_{s}(\mathrm dy\mid x) R_{s}(x)\log \hat R_{s}^{\theta}(y,x)\\ \end{align}\tag{3.10}\]

其中第二行是因为求期望的状态 \(x\) 在 \(y\) 的左边,而对于逆向的 \(\hat R_{T-s}^{\theta}(\hat W_{s^{-}},\hat W_s)\) 表示的是从右向左的,所以实际上应该是 \(\hat R_{T-s}^{\theta}(y,x)\) 。最终(3.8)化简为:

\[\begin{align*} \int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] &\leq \int_0^T \mathrm d s\int p_{s}(\mathrm{d}x) r_{s}(\mathrm dy\mid x) \Bigg\{ \hat{R}^\theta_{s}(x) - R_{s}(x)\log \hat R_{s}^{\theta}(y,x) \Bigg\} + C \end{align*}\tag{3.11}\]

这就导出了原文中的Proposition 2,另外根据转移速率矩阵的特点有 \(R_{s}(x)=\sum_{x'\ne x}R_{s}(x,x')\) 。(3.11)实际上是边缘分布形式的损失函数,所以为了计算它,我们可以采样一批 \(x_0\) ,随机一个时间 \(t\) ,然后由前向分布 \(p_{t|0}(x_t\mid x_0)\) 得到一批 \(x_t\) ,这样的一批 \(x_t\) 可以近似的看成从 \(p_t(x_t)\) 中进行采样。接下来用跳转概率 \(r_t(y_t\mid x_t)\) 对每个 \(x_t\) 得到对应的 \(y_t\) ,而参数化的逆向速率矩阵用(3.3)式计算即可,即: \(\hat R_t^{\theta}(x, \tilde x)=R_t(\tilde x, x)\sum_{x_0\in S}\frac{p^{\theta}_{0|t}(x_0\mid x)}{p_{t|0}(x\mid x_0)}p_{t|0}(\tilde x\mid x_0)\) 。

多维情况

假如数据(句子) \(\mathbf x^{1:D}=(x^1,x^2,...,x^D)\in \mathbb R^{S^D}\) ,如果当成一个整体,那么整个状态空间会非常大,假如一个句子20个单词,词典里面有3w个单词,那么一个句子的状态空间就是 \(30000^{20}\) ,这对于计算机的实现显然不太现实,而且这样将一个句子当成整体也没有考虑各个单词(元素)之间的关系。所以我们需要对多维情况下的逆向速率矩阵和损失函数等进行元素上的拆解与化简。

对于前向过程,我们定义各个元素的扩散满足条件独立性:

\[q_{t \mid s}(\mathbf{x}_t^{1:D} \mid \mathbf{x}_s^{1:D}) = \prod_{d=1}^{D} q_{t \mid s}(x_t^d \mid x_s^d), \quad t > s\tag{3.12}\]

首先考虑多维逆向速率矩阵 \(\hat{R}_t^{1:D}(\mathbf{x}^{1:D}, \tilde{\mathbf{x}}^{1:D}) \) ,这里为了简化把数据的下标省略了。根据Kolmogorov前向方程(把矩阵展开)有:

\[\partial_t q_{t \mid s}(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) = \sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{y}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) R_t^{1:D}(\mathbf{y}^{1:D}, \mathbf{x}^{1:D}) \tag{3.13}\]

对于左边有:

\[\begin{align} \partial_t q_{t \mid s}(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) &= \partial_t \left\{ \prod_{d=1}^D q_{t \mid s}(x^d \mid \tilde{x}^d) \right\} \\ &= \sum_{d=1}^D q_{t \mid s}(\mathbf{x}^{1:D \setminus d} \mid \tilde{\mathbf{x}}^{1:D \setminus d}) \partial_t q_{t \mid s}(x^d \mid \tilde{x}^d) \\ &= \sum_{d=1}^D q_{t \mid s}(\mathbf{x}^{1:D \setminus d} \mid \tilde{\mathbf{x}}^{1:D \setminus d}) \sum_{y^d} q_{t \mid s}(y^d \mid \tilde{x}^d) R_t^d(y^d, x^d) \\ &= \sum_{d=1}^D \sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{x}^{1:D \setminus d} \mid \tilde{\mathbf{x}}^{1:D \setminus d}) q_{t \mid s}(y^d \mid \tilde{x}^d) R_t^d(y^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \mathbf{y}^{1:D \setminus d}} \\ &= \sum_{d=1}^D \sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{y}^{1:D \setminus d} \mid \tilde{\mathbf{x}}^{1:D \setminus d}) q_{t \mid s}(y^d \mid \tilde{x}^d) R_t^d(y^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \mathbf{y}^{1:D \setminus d}} \\ &= \sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{y}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) \sum_{d=1}^D R_t^d(y^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \mathbf{y}^{1:D \setminus d}} \end{align}\tag{3.14}\]

其中,第四个等号由于Kronecker \(\delta\) 函数 \(\delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \) 的筛选性,能将除 \(d\) 个元素外其它相等的项筛出来,所以总共也只在 \(y^d\) 上有遍历。结合(3.13),(3.14)有:

\[\sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{y}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) R_t^{1:D}(\mathbf{y}^{1:D}, \mathbf{x}^{1:D}) = \sum_{\mathbf{y}^{1:D}} q_{t \mid s}(\mathbf{y}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}) \sum_{d=1}^D R_t^d(y^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \mathbf{y}^{1:D \setminus d}}\tag{3.15}\]

由于 \(q_{t \mid s}(\mathbf{y}^{1:D} \mid \tilde{\mathbf{x}}^{1:D})\) 是任意的,所以有以下关系:

\[R_t^{1:D}(\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}) = \sum_{d=1}^D R_t^d(\tilde{x}^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \tag{3.16}\]

从上式可以看出,只有两个句子之间只差一个元素时才发生转移,而差两个以上元素的情况转移速率均为0,也就是说转移只发生在单个元素上。因此从微观上,前向过程同一时刻只有一个元素在转移,而不是直觉理解的所有元素在同一时刻互不干扰发生转移,只是这种转移(扩散)的独立性在宏观上是成立的,如定义的(3.12)所示。根据(3.3),多维逆向速率矩阵有:

\[\begin{align} \hat{R}_t^{1:D}(\mathbf{x}^{1:D}, \tilde{\mathbf{x}}^{1:D}) &=R_t^{1:D}(\tilde{\mathbf{x}}^{1:D}, {\mathbf{x}}^{1:D})\sum_{{\mathbf{x}_0}^{1:D}}\frac{q_{0|t}({\mathbf{x}_0}^{1:D}\mid {\mathbf{x}}^{1:D})}{q_{t|0}({\mathbf{x}}^{1:D}\mid {\mathbf{x}_0}^{1:D})}q_{t|0}(\tilde{\mathbf{x}}^{1:D}\mid {\mathbf{x}_0}^{1:D})\\ &= \sum_{\mathbf{x}_0^{1:D}} \sum_{d=1}^D R_t^d(\tilde{x}^d, x^d) \frac{q_{t|0}(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D})}{q_{t|0}(\mathbf{x}^{1:D} \mid \mathbf{x}_0^{1:D})} q_{0|t}(\mathbf{x}_0^{1:D} \mid \mathbf{x}^{1:D}) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \\ &= \sum_{\mathbf{x}_0^{1:D}} \sum_{d=1}^D R_t^d(\tilde{x}^d, x^d) \frac{q_{t|0}(\tilde{x}^d \mid x_0^d)}{q_{t|0}(x^d \mid x_0^d)} q_{0|t}(\mathbf{x}_0^{1:D} \mid \mathbf{x}^{1:D}) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \\ &= \sum_{d=1}^D R_t^d(\tilde{x}^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \sum_{x_0^d}q_{0|t}({x}_0^{d} \mid \mathbf{x}^{1:D}) \frac{q_{t|0}(\tilde{x}^d \mid x_0^d)}{q_{t|0}(x^d \mid x_0^d)} \sum_{\mathbf{x}_0^{1:D \setminus d}} q_{0|t}(\mathbf{x}_0^{1:D \setminus d} \mid x_0^d, \mathbf{x}^{1:D}) \\ &= \sum_{d=1}^D R_t^d(\tilde{x}^d, x^d) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \sum_{x_0^d} q_{0|t}(x_0^d \mid \mathbf{x}^{1:D}) \frac{q_{t|0}(\tilde{x}^d \mid x_0^d)}{q_{t|0}(x^d \mid x_0^d)} \end{align} \tag{3.17}\]

其中,第三个等号是由于 \(\delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \) 的筛选性,和(3.12)的独立扩散。从上式同样能看出,逆向过程在同一时刻也是只发生在一个元素上,所以需要学习的空间就只有 \(D\times S\) 了。与前向不同的是,逆向速率与整个 \(\mathbf x_t^{1:D}\) 有关,所以我们实际上可以参数化 \(q_{0|t}(x_0^d \mid \mathbf{x}^{1:D})\) 这种预测元素的形式,输入为 \(D\times S\) ,每一行表示一个token的独热编码向量,输出同样为 \(D\times S\) ,第 \(d\) 行表示预测的 \(q_{0|t}(x_0^d \mid \mathbf{x}^{1:D})\) 。为什么不定义成微观上的独立扩散?因为此时转移速率的空间为全空间 \(S^D\) ,学习成本太高,不好建模。

接下来还需要计算多维损失函数,并化简成含有 \(q_{0|t}(x_0^d \mid \mathbf{x}^{1:D})\) 的形式。根据(5.11),很容易写出多维损失函数为:

\[\begin{align} \mathcal L_{\text{CT}}(\theta)= \int_0^T \mathrm d t\int p_{t}(\mathrm{d}\mathbf x^{1:D}) r_{t}(\mathrm d \tilde{\mathbf{x}}^{1:D}\mid \mathbf x^{1:D}) \Bigg\{ \hat{R}^\theta_{t}(\mathbf x^{1:D}) - R_{t}(\mathbf x^{1:D})\log \hat R_{t}^{\theta}(\tilde{\mathbf{x}}^{1:D},\mathbf x^{1:D}) \Bigg\} + C \end{align}\tag{3.18}\]

其中:

\[r_{t}(\tilde{\mathbf{x}}^{1:D}\mid \mathbf x^{1:D})=(1-\delta_{\mathbf x^{1:D},\tilde{\mathbf{x}}^{1:D}})R_{t}(\mathbf x^{1:D},\tilde{\mathbf{x}}^{1:D})/\sum_{\tilde{\mathbf{x}}^{1:D}\ne \mathbf x^{1:D}}R_{t}(\mathbf x^{1:D},\tilde{\mathbf{x}}^{1:D})\tag{3.19}\]

(3.17)的右边,我们记:

\[\hat{R}_t^{\theta, d}(\mathbf{x}^{1:D}, \tilde{x}^d) = R_t^d(\tilde{x}^d, x^d) \sum_{x_0^d} p_{0|t}^\theta(x_0^d \mid \mathbf{x}^{1:D}) \frac{q_{t|0}(\tilde{x}^d \mid x_0^d)}{q_{t|0}(x^d \mid x_0^d)} \tag{3.20}\]

所以(3.17)可以简写为:

\[\hat{R}_t^{\theta, 1:D}(\mathbf{x}^{1:D}, \tilde{\mathbf{x}}^{1:D}) = \sum_{d=1}^D \hat{R}_t^{\theta, d}(\mathbf{x}^{1:D}, \tilde{x}^d) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}}\tag{3.21}\]

接下来先处理损失函数(3.18)第一项:

\[\begin{align} \hat{R}^{\theta_{t},1:D}(\mathbf x^{1:D})&= \sum_{\mathbf{x}'^{1:D} \ne \mathbf{x}^{1:D}} \hat{R}^{\theta_{t},1:D}(\mathbf x^{1:D}, \mathbf{x}'^{1:D})\\ &=\sum_{\mathbf{x}'^{1:D} \ne \mathbf{x}^{1:D}} \sum_{d=1}^D \hat{R}_t^{\theta, d}(\mathbf{x}^{1:D}, x'^d) \delta_{\mathbf{x}^{1:D \setminus d}, \mathbf{x}'^{1:D \setminus d}} \\ &= \sum_{d=1}^D \sum_{x'^d \ne x^d} \hat{R}_t^{\theta, d}(\mathbf{x}^{1:D}, x'^d) \end{align}\tag{3.22} \]

这样,就化简成了预测元素的形式了。

对于损失函数(3.18)第二项,我们先考虑期望概率 \(p_{t}(\mathrm{d}\mathbf x^{1:D}) r_{t}(\mathrm d \tilde{\mathbf{x}}^{1:D}\mid \mathbf x^{1:D})\) 。由于第二项与 \(\mathbf x^{1:D}_0\) 无关,所以可以写成加噪形式的期望概率 \(p_0(\mathrm{d}\mathbf x^{1:D}_0)p_{t}(\mathrm{d}\mathbf x^{1:D}\mid \mathbf x^{1:D}_0) r_{t}(\mathrm d \tilde{\mathbf{x}}^{1:D}\mid \mathbf x^{1:D})\) ,表示采样 \(\mathbf x^{1:D}_0\) 后,再采样 \(t\) 时刻的 \(\mathbf x^{1:D}\) ,再同时采样与当前时刻不同的转移状态 \(\tilde{\mathbf{x}}^{1:D}\) 。但现在考虑直接从 \(\mathbf x^{1:D}_0\) 到 \(\tilde{\mathbf{x}}^{1:D}\) 的概率,所以根据联合分布可以写出:

\[p_{0}(\mathbf{x}_0^{1:D}) q_{t|0}(\mathbf{x}^{1:D} \mid \mathbf{x}_0^{1:D}) r_t(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) = p_{0}(\mathbf{x}_0^{1:D}) \psi_t(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) \phi_t(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D})\tag{3.23}\]

其中 \(\psi_t(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D})\) 就是前向过程,剩下的关键是:

\[\begin{align} \phi_t(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) &= \frac{q_{t|0}(\mathbf{x}^{1:D} \mid \mathbf{x}_0^{1:D}) q(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}^{1:D},\mathbf{x}_0^{1:D})}{q_{t|0}(\tilde{\mathbf{x}}^{1:D}\mid \mathbf{x}_0^{1:D})}\\ &\propto q_{t|0}(\mathbf{x}^{1:D} \mid \mathbf{x}_0^{1:D}) r_t(\tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}^{1:D}) \\ &= q_{t|0}(\mathbf{x}^{1:D} \mid \mathbf{x}_0^{1:D}) \left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right) \sum_{d=1}^D \frac{R_t^d(x^d, \tilde{x}^d)\delta_{ \mathbf{x}^{1:D\setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}}}{R_t({\mathbf{x}}^{1:D} )} \\ &= \sum_{d=1}^D \frac{R_t^d(x^d, \tilde{x}^d)}{R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d)} q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d \mid \mathbf{x}_0^{1:D}) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right) \end{align}\tag{3.24} \]

其中,第一个等号跟之前求解 \(q(x_{t-1}\mid x_t,x_0)\) 如出一辙;而第二个正比号也是根据马尔可夫性质得到当前时刻的转移速率;第三个等号是用(3.19)展开 \(r_t\) 后,再用(3.16)展开前向转移速率矩阵;而第四个等号是由于 \(\delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right)\) 的筛选性,只有刚好 \(\mathbf{x}^{1:D}, \tilde{\mathbf{x}}^{1:D }\) 在 \(d\) 维不同时才不为0,所以有关系 \({\mathbf{x}}^{1:D } =\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d\) , \(\circ\) 表示concat。(3.24)表示的是分子,那么分母可以对分子的 \(\mathbf{x}^{1:D}\) 求和得到:

\[\begin{align} q_{t|0}(\tilde{\mathbf{x}}^{1:D}\mid \mathbf{x}_0^{1:D})&= \sum_{\mathbf{x}^{1:D}} \sum_{d=1}^D \frac{R_t^d(x^d, \tilde{x}^d)}{R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d)} q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d \mid \mathbf{x}_0^{1:D}) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right) \\ &= \sum_{d=1}^D \sum_{x^d \ne \tilde{x}^d} \frac{R_t^d(x^d, \tilde{x}^d)}{R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d)} q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d \mid \mathbf{x}_0^{1:D}) \end{align}\tag{3.25}\]

因此,与(3.21)的简记思想一样,如果记:

\[\begin{align} \phi_t(x^d \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) &=\frac{\frac{R_t^d(x^d, \tilde{x}^d)}{R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d)} q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d \mid \mathbf{x}_0^{1:D}) }{q_{t|0}(\tilde{\mathbf{x}}^{1:D}\mid \mathbf{x}_0^{1:D})} \\ &= \frac{R_t^d(x^d, \tilde{x}^d) \, q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d \mid \mathbf{x}_0^{1:D})} {R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d) \sum_{d'=1}^D \sum_{x^{d'} \ne \tilde{x}^{d'}} \frac{R_t^{d'}(x^{d'}, \tilde{x}^{d'})}{R_t(\tilde{\mathbf{x}}^{1:D \setminus d'} \circ x^{d'})} q_{t|0}(\tilde{\mathbf{x}}^{1:D \setminus d'} \circ x^{d'} \mid \mathbf{x}_0^{1:D})} \end{align}\tag{3.26}\]

那么:

\[\phi_t(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) = \sum_{d=1}^D \phi_t(x^d \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) \delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}}\left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right) \tag{3.27}\]

最终损失函数第二项为:

\[\begin{align} &-\int_0^T \mathrm dt \int p_{0}(\mathrm d \mathbf{x}_0^{1:D}) \psi_t(\mathrm d \tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) \phi_t(\mathrm d\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) R_{t}(\mathbf x^{1:D})\log \hat R_{t}^{\theta}(\tilde{\mathbf{x}}^{1:D},\mathbf x^{1:D})\\ =&-\int_0^T \mathrm dt \int p_{0}(\mathrm d \mathbf{x}_0^{1:D}) \psi_t(\mathrm d \tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) \left[\sum_{\mathbf{x}^{1:D}} \phi_t(\mathbf{x}^{1:D} \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) R_{t}(\mathbf x^{1:D})\log \hat R_{t}^{\theta}(\tilde{\mathbf{x}}^{1:D},\mathbf x^{1:D})\right]\\ =&-\int_0^T \mathrm dt \int p_{0}(\mathrm d \mathbf{x}_0^{1:D}) \psi_t(\mathrm d \tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) \left[\sum_{d=1}^D \sum_{x^d \ne \tilde{x}^d} \phi_t(x^d \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d) \log\left( \hat{R}_t^{\theta, d}(\tilde{\mathbf{x}}^{1:D}, x^d) \right) \right] \end{align} \tag{3.28}\]

其中,第二个等号是由于 \(\delta_{\mathbf{x}^{1:D \setminus d}, \tilde{\mathbf{x}}^{1:D \setminus d}} \left(1 - \delta_{\tilde{\mathbf{x}}^{1:D}, \mathbf{x}^{1:D}} \right)\) 的筛选性与(3.21)。于是(3.28)的参数化项也变成了预测单个元素的形式了。

最终的损失函数为:

\[\begin{align} \mathcal L_{\text{CT}}(\theta)=& \int_0^T \mathrm dt \int p_{0}(\mathrm d \mathbf{x}_0^{1:D}) p_{t|0}(\mathrm{d}\mathbf x^{1:D} \mid \mathbf x^{1:D}_0 )\left[\sum_{d=1}^D \sum_{x'^d \ne x^d} \hat{R}_t^{\theta, d}(\mathbf{x}^{1:D}, x'^d) \right]\\ &-\int_0^T \mathrm dt \int p_{0}(\mathrm d \mathbf{x}_0^{1:D}) \psi_t(\mathrm d \tilde{\mathbf{x}}^{1:D} \mid \mathbf{x}_0^{1:D}) \left[\sum_{d=1}^D \sum_{x^d \ne \tilde{x}^d} \phi_t(x^d \mid \tilde{\mathbf{x}}^{1:D}, \mathbf{x}_0^{1:D}) R_t(\tilde{\mathbf{x}}^{1:D \setminus d} \circ x^d) \log\left( \hat{R}_t^{\theta, d}(\tilde{\mathbf{x}}^{1:D}, x^d) \right) \right]\\ &+ C \end{align}\tag{3.29}\]

可以看到,多维情况下的损失函数还是挺复杂的,而且推理的时候逆向速率矩阵还得用(3.21)进行计算再采样,因为里面有非常多的求和运算,所以即使是纯数学上的运算,有可能也会带来很大的计算开销。

对于多维情况的推理,我们当然会将连续的时间 \([0,T]\) 离散成一些步数,假设是 \([0,...,t_{n},t_{n+1},..,T]\) ,那么推理的时候从直接从 \(\mathbf x_{t_{n+1}}\) 预测到 \(\mathbf x_{t_{n}}\) 即可。但是这里要注意的是,对于连续马尔可夫链,它的逆向转移速率与当前状态是有关的,所以真实的情况下假设 \(t_{n+1}\) 时刻状态为 \(\mathbf x_{t_{n+1}}\) ,用当前的转移速率来预测 \(\mathbf x_{t_{n+1}-\Delta t}\) ,假设某些维度状态发生了改变,那么预测下一时段的 \(\mathbf x_{t_n}\) 用的应该是改变后状态 \(\mathbf x_{t_{n+1}-\Delta t}\) 产生的转移速率而不是之前 \(t_{n+1}\) 时刻 \(\mathbf x_{t_{n+1}}\) 所定义的。简单来说,预测所需要的转移速率会随着状态一直变化,所以理论上应该要一个状态一个状态的更新才能最最接近真实的效果,但显然这样是不现实的。而tau-leaping的思想就是直接从从 \(\mathbf x_{t_{n+1}}\) 预测到 \(\mathbf x_{t_{n}}\) ,不考虑一个状态一个状态的更新了。当然这样做会产生误差,原论文还给了误差的界限。另外,作者还训了得分,每次预测后做predictor-corrector,这本质就是Langevin采样,用来让新变量尽量服从 \(p_t(\mathbf x_t)\) ,提高采样质量。注意到D3PM是没有这个缺陷的,因为离散的时间已经把 \(t\) 定死了,所以 \(t_{n}\) 的前一时刻就是 \(t_{n+1}\) 。

最后,关于转移速率矩阵 \(R_t\) 的选取。首先,如果任意给定一个转移速率矩阵形式为 \(R_t=\beta(t)R\) , \(R\) 是参考的常数矩阵满足对角化 \(R=Q\Lambda Q^{-1}\) ,那么它 \(0\) 到 \(t\) 的转移概率矩阵可以用:

\[P_t=Qe^{\Lambda \int_{0}^{t}\beta(s)\mathrm ds }Q^{-1}\tag{3.30}\]

表示,证明只需要把(3.12)带入Kolmogorov方程即可。所以选择合适的 \(\beta(t)\) 和 \(R\) 就可以一步算出转移概率,并且可以得到想要的分布 \(p_{\text{ref}}\) ,比如均匀分布或者全为吸收态的Mask。比如:

\[\begin{align*} R_{\text{uniform}} &= \begin{bmatrix} 1 - S & 1 & \cdots & 1 \\ 1 & 1 - S & \cdots & 1 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \cdots & 1 - S \end{bmatrix} \\[1em] R^{\text{absorb}} &= \begin{bmatrix} -1 & 0 & \cdots & 0 & 0 \\ 0 & -1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & -1 & 0 \\ 1 & 1 & \cdots & 1 & 0 \end{bmatrix} \end{align*}\tag{3.31}\]

对于 \(R_{\text{uniform}}\) 来说,一个状态转移到其它状态的速率是相等的;对于 \(R^{\text{absorb}}\) 来说,只能转移到吸收态或者不转移,转入转出率相等。

至此基本说明完了连续时间马尔可夫链生成框架,它可以用于处理状态空间离散情况的生成,而SDE适用于状态空间连续情况的生成。

SEDD

有了前面D3PM,Concrete Score Matching和Continuous Time Framework的铺垫,ICML 2024年的Best Paper《Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution》介绍的SEDD(Score Entropy Discrete Diffusion models)理解起来就非常容易了。

一维情况

对于一维情况,SEDD同样是定义在连续时间马尔可夫链上,假设状态空间 \(\mathcal X=\{1,...,N\}\) ,转移速率矩阵为 \(Q_t\) ,那么演化过程服从:

\[\frac{\mathrm d p_t}{\mathrm d t}=Q_tp_t\tag{4.1}\]

这就是Fokker-Planck方程,要注意的是, 这里的 \([Q_t]_{i,j}=q_t(x_{t+\Delta t}=i\mid x_{t}=j)\) ,跟之前的速率矩阵互为转置,所以概率要乘在矩阵的右边。

前向过程的离散形式和(3.1)类似:

\[p(x_{t+\Delta t} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x)\Delta t + O(\Delta t) \tag{4.2}\]

逆向过程同理:

\[p(x_{t} = x \mid x_{t+\Delta t} = y) = \delta_{yx} + \overline Q_t(x, y)\Delta t + O(\Delta t) \tag{4.3}\]

所以问题的关键同样是如何学习逆矩阵。重新看(3.3)关于逆矩阵的推导,注意到第一行的 \(\frac{p_t(\tilde x)}{p_t(x)}\) 其实就是Concrete Score,所以SEDD关于逆向矩阵的推导到第一步就停止了:

\[\overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y)\tag{4.4}\]

因此SEDD模型想要参数化Concrete Score \(\frac{p_t(y)}{p_t(x)}\) ,而不是逆向概率 \(p_{0|t}\) ,这是与之前Continuous Time Framework的最大的不同。关于作者为什么要参数化Concrete Score,首先一个重要原因就是学好Concrete Score后,乘以前向转移速率矩阵可以直接得到逆向转移速率矩阵了,避免了繁琐的求和运算。另外离散扩散模型学习Concrete Score,就跟连续扩散模型学习Score一样(两者的联系在定理2.1),在这种视角下也能完成离散连续的统一。

具体的,定义 \(s_{\theta}(x,t)\approx \left[\frac{p_t(y)}{p_t(x)}\right]_{y\ne x}\) ,但是作者并不用Concrete Score Matching那种MSE直接学习Concrete Score,主要是会比较难学,效果不好,原文附录D.1有消融实验说明。

所以最自然的当然是考虑负对数似然的期望作为损失函数,更准确的应该是它的ELBO,在前面(3.11)已经推导出来了,这也是模型名字SEDD中的Score Entropy的由来。把(3.11)用 \(Q_t\) 重新表示,并更改下积分变量,则:

\[\begin{align*} \int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] &\leq \int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\int r_{t}(z\mid x)\mathrm dz \Bigg\{ \overline{Q}^\theta_{t}(x) - Q_t(x)\log \overline{Q}_t^{\theta}(x,z) \Bigg\} + C \end{align*}\tag{4.5}\]

事实上,由于转移变量 \(z\) 是离散的,所以我们可以用期望的定义把它拆开,再结合:

\[\overline{Q}_t^{\theta}(x,z)=Q_t(z,x)\left[\frac{p_t(x)}{p_t(z)}\right]_{\theta}=Q_t(z,x)[s_{\theta}(z,t)]_{x}\tag{4.6}\]

和:

\[\overline{Q}^\theta_{t}(x) =\sum_{y\ne x}\overline{Q}^\theta_{t}(y,x)=\sum_{y\ne x}Q_t(x,y)[s_{\theta}(x,t)]_{y}\tag{4.7}\]

有:

\[\begin{align} &\int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] \\ \leq& \int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\int r_{t}(z\mid x)\mathrm dz \Bigg\{ \overline{Q}^\theta_{t}(x) - Q_t(x)\log \overline{Q}_t^{\theta}(x,z) \Bigg\} + C \\ =&\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\Bigg\{\sum_{y\ne x}Q_t(x,y)[s_{\theta}(x,t)]_{y}-\sum_{z\ne x}\frac{Q_t(z,x)}{Q_t(x)}Q_t(x)\log \left\{Q_t(z,x)[s_{\theta}(z,t)]_{x} \right\} \Bigg\} +C\\ =&\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\Bigg\{\sum_{y\ne x}Q_t(x,y)[s_{\theta}(x,t)]_{y}-\sum_{z\ne x}Q_t(z,x)\log [s_{\theta}(z,t)]_{x} \Bigg\} +C'\\ =&\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\Bigg\{\sum_{y\ne x}Q_t(x,y)[s_{\theta}(x,t)]_{y}-\sum_{y\ne x}Q_t(y,x)\log [s_{\theta}(y,t)]_{x} \Bigg\} +C'\\ =&\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\sum_{y\ne x}\Bigg\{Q_t(x,y)[s_{\theta}(x,t)]_{y}-Q_t(y,x)\log [s_{\theta}(y,t)]_{x} \Bigg\} +C' \end{align}\tag{4.8}\]

如果记 \(w_{xy}=Q_t(x,y)\) ,就得到了原论文的Proposition 3.3 (Implicit Score Entropy)。由于上式最后出现了两个对称的项 \([s_{\theta}(x)]_{y}\) 和 \([s_{\theta}(y)]_{x}\) ,所以接下来就想尝试把它们统一,不然一次训练得过两次神经网络了。注意到:

\[\begin{align} \mathbb E_{x}\sum_{y\ne x}f(y,x) &=\sum_xp(x)\sum_{y\ne x}f(y,x)\\ &=\sum_{x,y,x\ne y}p(x)f(y,x)\\ &=\sum_{x,y,x\ne y}p(y)f(x,y)\\ &=\sum_x\sum_{y\ne x}p(y)f(x,y)\\ &=\sum_{x}\sum_{y\ne x}p(x)\frac{p(y)}{p(x)}f(x,y)\\ &=\mathbb E_{x}\sum_{y\ne x}\frac{p(y)}{p(x)}f(x,y) \end{align}\tag{4.9}\]

所以如果令上式 \(f(y,x)=Q_t(y,x)\log [s_{\theta}(y)]_{x}\) , \(x\sim p_t(x)\) 与时间 \(t\) 有关,那么(4.8)就可以写成:

\[\begin{align} \int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] &\leq \int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\sum_{y\ne x}\Bigg\{Q_t(x,y)[s_{\theta}(x,t)]_{y}-Q_t(y,x)\log [s_{\theta}(y,t)]_{x} \Bigg\} +C'\\ &=\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\sum_{y\ne x}\Bigg\{Q_t(x,y)[s_{\theta}(x,t)]_{y}-\frac{p_t(y)}{p_t(x)}Q_t(x,y)\log [s_{\theta}(x,t)]_{y} \Bigg\} +C'\\ &=\int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\sum_{y\ne x}Q_t(x,y)\Bigg\{[s_{\theta}(x,t)]_{y}-\frac{p_t(y)}{p_t(x)}\log [s_{\theta}(x,t)]_{y} \Bigg\} +C' \end{align}\tag{4.10}\]

这就导出了原论文的Definition 3.1。而(4.10)其实就是 \([s_{\theta}(x,t)]_{y}\) 和 \(\frac{p_t(y)}{p_t(x)}\) 之间的Bregman Divergence \(D_F\left([s_{\theta}(x,t)]_{y},\frac{p_t(y)}{p_t(x)}\right)\) 。所以其实理论上,这篇博客的推导过程才比较符合逻辑,在选定好了连续时间马尔可夫链框架后,自然的从极大似然估计开始入手作为损失函数,然后尝试参数化Concrete Score,从原始的负对数似然期望(4.5)式,逐步推导得到(4.10)。而不是一上来说不考虑Concrete Score的MSE,直接定义 \(D_F\left([s_{\theta}(x,t)]_{y},\frac{p_t(y)}{p_t(x)}\right)\) 作为损失函数,最后再推导发现与极大似然估计存在关系,这就显得十分突兀。

当然,(4.10)中的真实Concrete Score也没法得到,所以去要转化成去噪形式的。注意到:

\[\begin{align} \mathbb E_{x_t}\sum_{y\ne x}\frac{p_t(y)}{p_t(x)}f_t(x,y) &=\sum_{y\ne x}p_t(y)f_t(x,y)\\ &=\sum_{y\ne x}\sum_{x_0}p_{t|0}(y\mid x_0)p_0(x_0)f_t(x,y)\\ &=\mathbb E_{x_0}\sum_{y\ne x}\frac{p_{t|0}(x\mid x_0)}{p_{t|0}(x\mid x_0)}p_{t|0}(y\mid x_0)f_t(x,y)\\ &=\mathbb E_{x_0,x_t}\sum_{y\ne x}\frac{p_{t|0}(y\mid x_0)}{p_{t|0}(x\mid x_0)}f_t(x,y)\\ \end{align}\tag{4.11}\]

带入到(4.10)就有:

\[\begin{align} \int p_{\text{data}}(\mathrm{d}x_0) \left[ -\log p_0^\theta(x_0) \right] &\leq \int_0^T \mathrm d t\int p_{t}(x)\mathrm{d}x\sum_{y\ne x}Q_t(x,y)\Bigg\{[s_{\theta}(x,t)]_{y}-\frac{p_t(y)}{p_t(x)}\log [s_{\theta}(x,t)]_{y} \Bigg\} +C' \\ &=\int_0^T \mathrm d t\int p_{0}(x_0)p_{t|0}(x|x_0)\mathrm{d}x\mathrm{d}x_0\sum_{y\ne x}Q_t(x,y)\Bigg\{[s_{\theta}(x,t)]_{y}-\frac{p_{t|0}(y\mid x_0)}{p_{t|0}(x\mid x_0)}\log [s_{\theta}(x,t)]_{y} \Bigg\} +C' \\ \end{align}\tag{4.12}\]

至此所有的元素都能计算得到,可以实际训练了,这也是原文的Theorem 3.4 (Denoising Score Entropy)

多维情况

对于多维情况,从第3节可以知道,如果我们定义了前向条件过程是独立的,如(3.12),那么逆向过程一个时刻的转移只会发生在一个元素上,这是由(3.17)决定的,所以多维的得分只需要考虑某个维度不同的情况 \(\left[ s_\theta(x^1 \ldots x^i \ldots x^d, t) \right]_{i, \widehat{x}^i} \approx \frac{p_t(x^1 \ldots \widehat{x}^i \ldots x^d)}{p_t(x^1 \ldots x^i \ldots x^d)} \) ,所以输入形状就是句子长度 \(d\) 和独热编码的空间大小 \(N\) 的乘积 \(d \times N\) ,输出形状也是 \(d \times N\) ,表示第 \(i\) 个token的Concrete Score。而对于多维情况下的损失函数,根据(4.12)有:

\[\begin{align} \mathcal L &=\int_0^T \mathrm d t\int p_{0}(\mathbf x_0)p_{t|0}(\mathbf x|\mathbf x_0)\mathrm{d}\mathbf x\mathrm{d}\mathbf x_0\sum_{\mathbf y\ne \mathbf x}Q_t(\mathbf x,\mathbf y)\Bigg\{[s_{\theta}(\mathbf x,t)]_{\mathbf y}-\frac{p_{t|0}(\mathbf y\mid \mathbf x_0)}{p_{t|0}(\mathbf x\mid \mathbf x_0)}\log [s_{\theta}(\mathbf x,t)]_{\mathbf y}\Bigg\}\\ &=\int_0^T \mathrm d t\int p_{0}(\mathbf x_0)p_{t|0}(\mathbf x|\mathbf x_0)\mathrm{d}\mathbf x\mathrm{d}\mathbf x_0 \sum_{i=1}^{d}\sum_{y^i\ne x^i}Q_t(x^i,y^i)\Bigg\{[s_{\theta}(\mathbf x,t)]_{y^i}-\frac{p_{t|0}(y^i\mid x_0^i)}{p_{t|0}(x^i\mid x_0^i)}\log [s_{\theta}(\mathbf x,t)]_{y^i}\Bigg\} \end{align}\tag{4.13}\]

其中第二个等号是只考虑了 \(\mathbf x, \mathbf y\) 只有一个元素不同的情况,因为前向过程一个时刻只有一个元素转移,由(3.16)决定。可以看到,相比第3节Continuous Time Framework转变到多维的损失函数(3.29)式,SEDD的多维损失函数(4.13)与一维的形式(4.12)保持了一致性,是一个非常好的性质。

这里再次想吐槽这篇文章的书写逻辑。除了先前一维情况下推导损失函数莫名其妙突然引入Bregman Divergence的概念之外,这里处理多维情况时说“只建模Hamming distnace为1之间的句子”,会非常让人不明所以。因为建模Hamming distnace为1之间的句子是由我们定义的前向过程条件独立((3.12))的情况下得来的,是被((3.17))式决定的。而且作者也没有给出多维情况下的损失函数(4.13)式的推导,这并不只是简单的把一维情况中的变量替换成多维,还需要深入理解前向过程条件独立扩散连续时间马尔可夫链的过程和正逆向转移速率矩阵的性质。

前向速率矩阵 \(Q_t\) 也选择 \(Q_t=\sigma(t)Q\) 的形式,其中 \(Q\) 的选取跟((3.31))类似。

从SDE的理论可以知道,除了用Euler的方式直接对SDE微分进行求解SDE,我们还可以考虑从分布 \(p_{\theta}(x_{t-\Delta t}\mid x_{t})\) 中进行采样(这里同样先考虑一个token更新的情况),跟之前的DDPM方式一样,考虑:

\[\begin{align} p_{\theta}(x_{t-\Delta t}\mid x_{t}) &\approx p_{\theta}(x_{t-\Delta t}\mid x_{t},x_0=\mu_\theta(x_{t})) \end{align}\tag{4.13}\]

而 \({\mu}_\theta(x_{t})\) 就是分布 \(p^{\theta}_{0|t}(x_{0}\mid x_{t})\) 的均值。由贝叶斯公式有:

\[p^{\theta}_{0|t}(x_{0}\mid x_{t})=p_{t|0}( x_{t}\mid x_{0})\left[\frac{p_0(x_0)}{p_t(x_t)}\right]_{\theta}\tag{4.14}\]

注意由于 \( x_0\) 是未知的, \(x_t\) 是给定的,所以 \(p_t(x_t)\) 是标量, \(p_0(x_0)\) 是 \(N\) 维向量,第维度 \(i\) 表示 \(0\) 时刻在 \(i\) 状态的概率, \(p^{\theta}_{0|t}(x_{0}\mid x_{t})\) 也是向量,第 \(i\) 维表示给定 \(x_t\) 后 \(x_0=i\) 的概率。根据(3.30)能得到一步前向转移速率 \(P_t\) ,而 \(p_0=P_t^{-1}p_t\) ,这里 \(p_t\) 是一个向量,表示 \(t\) 时刻各个状态的真实分布,所以有关系 \(p_0(i)=[P_t^{-1}p_t]_{i}\) 。考虑(4.14)第 \(i\) 维的元素,即 \(x_0=i\) ,则:

\[\left[p^{\theta}_{0|t}(x_{0}\mid x_{t})\right]_{i}=[P_t]_{x_t,i}\left[P_t^{-1}\frac{p_t}{p_t(x_t)}\right]^{\theta}_i=[P_t]_{x_t,i}[P_t^{-1}s_{\theta}(x_t,t)]_i \tag{4.15}\]

这就是离散形式的Tweedie's 定理。那么就能从(4.15)中选择最大的概率作为 \(x_0={\mu}_\theta(x_{t})\) ,带入(4.13)后再用贝叶斯公式(类似((1.2))式),就能知道下一时刻元素的分布了。知道每个元素如何更新后,那么也可以用tau-leaping来加速更新句子,即所有元素用共同的得分 \(s_{\theta}(\mathbf x_t,t)\) 同时更新。

最后来看看SEDD的效果:

PixPin_2026-01-20_14-43-48.png
SEDD量化结果

作为Diffusion来说,能赶上GPT-2我认为已经很不错了,还是有很大的进步空间的。

PixPin_2026-01-20_14-45-19.png
吸收态实验效果
PixPin_2026-01-20_14-46-04.png
均匀分布实验效果

明显感觉比D3PM正常一点,乍一看还像个句子。

总结一下,这篇ICML 2024的Best Paper并不是突然冒出来的,它的一切动机都是有迹可循的,主要的理论支撑来源于第3节,不过原文中一上来就是定义Bregman Divergence做为损失函数还有多维情况下建模Hamming distnace为1之间的句子也挺离谱的,逻辑非常奇怪,会让初读者根本摸不着头脑。而SEDD的点睛之笔在于参数化Concrete Score,推导出了去噪损失函数(4.12),使得训练流程变得十分自然,真正拉进了连续和离散扩散模型在训练和推理流程上的一致性。

Block Diffusion

Block Diffusion出自《Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models》是25年ICLR的oral。这篇文章相当于在自回归和扩散模型之间取了一个折中,同时吸收两种方式的有点。简单来说就是把一个句子拆成很多块,块跟块之间是自回归生成,由前前面所有块预测后一个块,块的内部用的是离散扩散模型,这样就同时兼具了自回归的质量和扩散模型的速度。在扩散模型对于文本生成来说是有速度上的优势,虽然需要一定的去噪步数,但如果生成较长的文本,比如2048个token的,对于自回归可能就需要2048次预测过程,但是对于扩散模型只需要128步或者1000步去噪过程即可。Block Diffusion的过程直接看github上的动画演示就非常好理解了:https://m-arriola.com/bd3lms/

PixPin_2026-01-20_15-16-53.png
PixPin_2026-01-20_15-17-30.png
PixPin_2026-01-20_15-17-57.png

整个方法的核心就是损失函数:

\[- \log p_\theta(\mathbf{x}) \leq \mathcal{L}_{\text{BD}}(\mathbf{x}; \theta) := \sum_{b=1}^B \mathbb{E}_{t \sim [0,1]} \mathbb{E}_q \left[ \frac{\alpha_t'}{1 - \alpha_t} \log p_\theta(\mathbf{x}^b \mid \mathbf{x}_t^b, \mathbf{x}^{<b}) \right]\tag{5.1}\]

因为Block Diffusion的扩散模型选用的是D3PM,所以它的损失函数形式跟离散时间的D3PM和DDPM那些一样。其中 \(\mathbf x^b\) 表示第 \(b\) 个块的所有tokens,块的长度为 \(L'\) , \(\mathbf x_t^b\) 表示第 \(t\) 步加噪后的块, \(\mathbf x^{<b}\) 表示 \(b\) 之前的所有块。所以非常明显的,损失函数为扩散模型的ELBO,而自回归的预测由条件生成控制,即前面所有块作为条件来控制当前第 \(b\) 个块的生成。

当然文章还有许多架构上和训练推理上的细节,但是核心的思想就是(5.1)式。另外,如果我们取 \(L '=1\) ,模型并不会退化为纯自回归模型,因为在一个token内还需要做扩散生成,所以效果反而还不如纯自回归;当然 \(L '\) 太大了也不好,这样就变成纯Diffusion模型了,影响了效果。所以需要折中的选取合适的 \(L '\) ,具体可以参考原文的消融实验。

最后Block Diffusion的效果还是很强的,相信有朝一日Diffusion能跟Transformer的架构掰掰手腕了:

PixPin_2026-01-20_14-48-16.png

不过Block Diffusion的扩散过程采用的是D3PM,不知道换成上述的连续时间马尔可夫链模型如SEDD效果会不会更好,我感觉应该会,毕竟SEDD效果比D3PM强太多,感兴趣的大佬们可以尝试一下。

总结

这篇博客主要介绍了离散扩散模型在文本上的生成。跟DDPM \(\rightarrow\) SDE的研究方式相似,一开始最经典的D3PM是定义在离散时间上的,接着推广到了连续时间上。但是连续时间框架Continuous Time Framework在训练上不太方便,因为需要采样一批batch来近似 \(p_t(x)\) 。而随后的SEDD从参数化Concrete Score开始,逐步推导出了类似conditional score matching去噪损失函数,这使得训练变得更加简单且有效,也真正开始将连续和离散扩散模型的训练和推理过程联系了起来。

Reference

Diffusion学习笔记(二十一)——离散扩散模型,文本生成