RLHF in LLM

Mar 12, 2025
2 views
Reinforcement Learning

引言

大语言模型(LLMs)在近年来取得了显著进展,展现出上下文学习、指令跟随和逐步推理等突出特性。然而,由于这些模型是在包含高质量和低质量数据的预训练语料库上训练的,它们可能会表现出编造事实、生成有偏见或有毒文本等意外行为。因此,将LLMs与人类价值观对齐变得至关重要,特别是在帮助性、诚实性和无害性(3H)方面。

基于人类反馈的强化学习(RLHF)已被验证为有效的对齐方法,但训练过程复杂且不稳定。本文深入分析了RLHF框架,特别是PPO算法的内部工作原理,并提出了PPO-max算法,以提高策略模型训练的稳定性和效果。

RLHF的基本框架

RLHF训练过程包括三个主要阶段:

  1. 监督微调(SFT):模型通过模仿人类标注的对话示例来学习一般的人类对话方式, 优化模型的指令跟随能力
  2. 奖励模型(RM)训练:模型学习基于人类反馈比较不同回复的偏好
  3. 近端策略优化(PPO):模型基于奖励模型的反馈进行更新,通过探索和利用来发现优化的策略

奖励建模(Reward Model)

奖励模型使用预训练的基于Transformer的语言模型,移除最后的反嵌入层,并在最终Transformer层添加一个额外的线性层。给定任何文本,奖励模型会为最后一个标记分配一个标量奖励值,奖励值越大,样本越好。

奖励模型的训练损失函数为:

\[ L(\psi) = \log \sigma(r(x, y_w) - r(x, y_l)) \]

其中 \(\sigma\) 是sigmoid函数,\(r\) 表示参数为 \(\psi\) 的奖励模型,\(r(x, y)\) 是输入提示 \(x\) 和回复 \(y\) 的预测奖励。\(y_w\) 和 \(y_l\) 分别是较好和较差的模型response

此外,还引入了自回归LM损失,使模型模仿每个比较的pair对中的首选回复

\[ L(\psi) = -\lambda E_{(x,y_w,y_l)\sim D_{rm}}[\log \sigma(r(x, y_w) - r(x, y_l))] + \beta_{rm}E_{(x,y_w)\sim D_{rm}}[\log(r'(x, y_w)] \]

其中\(D_{rm}\) 是训练集的经验分布,\(r'\)\(r\) 相同,只是顶部线性层的维度对应于词汇表大小。

奖励函数还包括基于学习的 RL 策略 \(\pi_{RL}^{\phi}\) 和初始监督模型 \(\pi^{SFT}\)之间的KL散度的惩罚项:

\[ r_{total} = r(x, y) - \eta KL(\pi_{RL}^{\phi}(y|x), \pi^{SFT}(y|x)) \]

其中 \(\eta\) 是KL奖励系数,控制KL惩罚的强度。这里 KL项起着两个重要作用。首先,它起到熵奖励的作用,促进对策略环境的探索,并防止策略过早地收敛到单一模式。其次,它致力于确保 RL 策略的输出不会与奖励模型在训练阶段遇到的样本大幅偏差。

💡 *为什么这里的KL散度项可以促进多样性和创新性*

LLM 强化学习

在对话生成中应用RL面临着巨大的状态-动作空间挑战。在这种情况下,我们将人类交互视为"环境"。在每个时间步 \(t\),代理(即AI助手)从环境(即对话历史)接收状态\(s_t\),基于其策略 \(\pi\),代理的动作 \(a_t\) 是生成下一个标记。环境返回奖励 \(r(s_t, a_t)\),代理然后转换到下一个状态 \(s_{t+1}\)

RL的目标是找到代理的最优行为策略,以最大化轨迹\(\tau = \{s_1, a_1, ..., s_T, a_T\}\)上的累积奖励(即回报)。

策略梯度方法

策略梯度方法是一类直接优化代理策略(状态到动作的映射)的RL技术,而不是学习价值函数。这些方法的核心思想是使用梯度上升算法来改进策略。

策略 \(\pi\) 通常由 \(\theta\) 参数化,表示为\(\pi(a|s, \theta)\),即在状态 \(s\) 下采取动作 \(a\) 的概率。策略梯度的更新规则为:

\[ \theta \leftarrow \theta + \alpha\nabla_{\theta}J(\theta) \]

其中\(\alpha\)是学习率,\(J(\theta)\) 表示遵循策略 \(\pi_{\theta}\) 时的期望回报。

策略梯度的一般形式可以表示为:

\[ \nabla_{\theta}J(\theta) = E_{\tau \sim\pi_{\theta}}\left[\sum_{t=0}^{T}\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)\Phi_t\right] \]

其中 \(\Phi_t\) 可以是 \(\Phi_t = R(\tau)\)\(\Phi_t = \sum_{t'=t}^{T}R(s_{t'}, a_{t'})\)\(\Phi_t = \sum_{t'=t}^{T}R(s_{t'}, a_{t'}) - b(s_t)\),带有基线 \(b\)

为了减少方差,常见的策略是在策略梯度更新规则中使用优势函数估计代替原始回报:

\[ \Phi_t = A(s_t, a_t) \]

数学上,\(A(s_t, a_t) = Q(s_t, a_t) - V(s_t)\),其中 \(Q(s_t, a_t)\) 是动作价值函数,\(V(s_t)\)是价值函数。

广义优势估计(GAE)

广义优势估计(Generalized Advantage Estimation, GAE)是一种在强化学习中计算优势函数的方法,旨在平衡估计的偏差和方差。在理解GAE之前,我们需要先回顾几个关键概念。

  • 时序差分(TD)误差
  • n步回报估计
  • n步优势估计
    我们可以用TD误差重写n步优势。将TD误差的公式代入n步优势公式,我们可以得到:

$$

\[\begin{aligned} \hat{A}_t^{(n)} &= r_t + \gamma r_{t+1} + ... + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t) \\ &= [\delta_t + V(s_t) - \gamma V(s_{t+1})] + \gamma [\delta_{t+1} + V(s_{t+1}) - \gamma V(s_{t+2})] + ... \\ &\quad + \gamma^{n-1} [\delta_{t+n-1} + V(s_{t+n-1}) - \gamma V(s_{t+n})] + \gamma^n V(s_{t+n}) - V(s_t)\\ &= \delta_t + \gamma \delta_{t+1} + \gamma^2 \delta_{t+2} + ... + \gamma^{n-1} \delta_{t+n-1} \end{aligned}\]

$$

这表明n步优势可以表示为一系列折扣TD误差的和。

GAE的定义与推导

GAE将不同长度的n步优势进行加权平均,权重为\((1-\lambda)\lambda^{k-1}\)

$$

\[\begin{aligned} \hat{A}_t^{GAE(\gamma, \lambda)} &= (1-\lambda)(\hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + ...) \\ &= (1-\lambda)(\delta_t + \lambda(\delta_t + \gamma \delta_{t+1}) + \lambda^2(\delta_t + \gamma \delta_{t+1} + \gamma^2 \delta_{t+2}) + ...)\\ &= (1-\lambda)[\delta_t(1 + \lambda + \lambda^2 + ...) + \gamma \delta_{t+1}(\lambda + \lambda^2 + ...) + \gamma^2 \delta_{t+2}(\lambda^2 + ...) + ...] \\ &= (1-\lambda)[\delta_t \frac{1}{1-\lambda} + \gamma \delta_{t+1} \frac{\lambda}{1-\lambda} + \gamma^2 \delta_{t+2} \frac{\lambda^2}{1-\lambda} + ...] \end{aligned}\]

$$

注意到几何级数 \(1 + \lambda + \lambda^2 + ... = \frac{1}{1-\lambda}\)(当\(|\lambda| < 1\)),以及 \(\lambda + \lambda^2 + ... = \frac{\lambda}{1-\lambda}\),等等。

简化后得到:

\[ \begin{aligned} \hat{A}_t^{GAE(\gamma, \lambda)} &= \delta_t + \lambda \gamma \delta_{t+1} + (\lambda \gamma)^2 \delta_{t+2} + ... \\ &= \sum_{l=0}^{\infty} (\lambda \gamma)^l \delta_{t+l} \end{aligned} \]

这就是GAE的最终表达式。

GAE有两个重要的边界情况:

  1. \(\lambda = 0\) 时:
  2. \(\lambda = 1\) 时:

实际应用

在实际应用中,我们通常在有限长度的轨迹上计算GAE。假设轨迹长度为T,则:

\[ \hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{T-t-1} (\lambda \gamma)^l \delta_{t+l} \]

这可以通过一个高效的反向迭代过程计算:

  1. 初始化 \(\hat{A}_T = 0\)
  2. 对于 \(t = T-1, T-2, ..., 0\)
    在PPO算法中,GAE用于计算策略梯度中的优势估计:
\[ \nabla_{\theta}J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t^i|s_t^i) \hat{A}_t^{i, GAE(\gamma, \lambda)} \]

这里:

  • \(\pi_{\theta}\) 是参数化策略
  • \(\hat{A}_t^{i, GAE(\gamma, \lambda)}\) 是使用GAE计算的优势估计
  • \(N\) 是收集的轨迹数量
    GAE在实践中表现出色的原因有以下几点:
  1. 平滑的偏差-方差权衡:通过调整 \(\lambda\) 参数,可以在不同任务上找到最佳平衡点
  2. 稳定的学习:减少了优势估计的方差,使策略梯度更加稳定
  3. 适应不同时间尺度:能够有效处理短期和长期奖励信号
  4. 与价值函数学习协同:随着价值函数估计改进,GAE的性能也会提高

数值示例

假设我们有以下轨迹数据:

  • 状态价值:\(V(s_1) = 0.5, V(s_2) = 0.6, V(s_3) = 0.7, V(s_4) = 0.8\)
  • 奖励:\(r_1 = 0.1, r_2 = 0.2, r_3 = 0.3\)
  • 参数:\(\gamma = 0.9, \lambda = 0.8\)
    计算TD误差:

  • \(\delta_1 = r_1 + \gamma V(s_2) - V(s_1) = 0.1 + 0.9 \times 0.6 - 0.5 = 0.14\)

  • \(\delta_2 = r_2 + \gamma V(s_3) - V(s_2) = 0.2 + 0.9 \times 0.7 - 0.6 = 0.23\)
  • \(\delta_3 = r_3 + \gamma V(s_4) - V(s_3) = 0.3 + 0.9 \times 0.8 - 0.7 = 0.32\)
    使用反向迭代计算GAE:

  • \(\hat{A}_3 = \delta_3 = 0.32\)

  • \(\hat{A}_2 = \delta_2 + \lambda \gamma \hat{A}_3 = 0.23 + 0.8 \times 0.9 \times 0.32 \approx 0.46\)
  • \(\hat{A}_1 = \delta_1 + \lambda \gamma \hat{A}_2 = 0.14 + 0.8 \times 0.9 \times 0.46 \approx 0.47\)
    这样,我们得到了状态1、2、3的GAE优势估计值。

近端策略优化(PPO)

近端策略优化(Proximal Policy Optimization, PPO)是一种策略梯度算法,旨在解决策略优化中的两个关键问题:

  1. 如何确定适当的步长以避免性能崩溃
  2. 如何有效利用采样数据进行多次参数更新
    PPO的核心思想是在保证性能单调改进的同时,允许较大的策略更新步长。

TRPO的约束优化问题

信任区域策略优化(Trust Region Policy Optimization, TRPO)提出了以下约束优化问题:

$$

\[\begin{aligned} \max_{\theta} \mathbb{E}_{s,a \sim \pi_{\theta_{old}}} \left[ \frac{\pi_{\theta}(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \right] \\ \text{subject to } D_{KL}(\pi_{\theta_{old}} || \pi_{\theta}) \leq \delta \end{aligned}\]

$$

其中:

  • \(\pi_{\theta}\) 是当前策略
  • \(\pi_{\theta_{old}}\) 是旧策略
  • \(A^{\pi_{\theta_{old}}}(s,a)\) 是在旧策略下的优势函数
  • \(D_{KL}\) 是KL散度
  • \(\delta\) 是信任区域大小
    TRPO的主要缺点是计算复杂,难以与架构(如共享参数或循环网络)集成。

PPO的简化方法