Secrets of RLHF in Large Language Models Part I PPO

Overview

1.PPO 执行流程
2.PPO 计算流程模拟
3.PPO 训练的成功三要素

Recap: KL Divergence 在 RLHF 中的近似

KL 散度衡量两个分布之间的差异中的近似 KL 计算

在 LLM 的上下文中,我们对生成的 token 维度计算 (例如 step 为 $t$) 计算 policy 模型和 reference 模型下的 KL 散度

对于生成的句子,包括多个 token,我们计算平均 KL

PPO 执行流程

PPO 关键要素是 4 个模型 + 3 个损失:

  1. 准备 4 个 模型:

    1. policy model $\pi_{\theta}$, 从 SFT 模型初始化
    2. reference model $\pi_{ref}$, 从 SFT 模型初始化
    3. reward model $R_{\theta}$, 提前训练好的 SFT 模型
    4. value model $V_{\theta}$, value model 预测给定 [prompt + 上文] 情况下的未来累计奖励
  2. 计算总损失: policy 损失 + value 损失 + KL 损失

PPO 计算流程模拟

用一个具体的例子来完整演示 RLHF PPO 中 Advantage 和 Loss 的计算过程

  • 提示词 (Prompt): $x$ = “请解释一下机器学习:”
  • 生成的回答 (Response): $y$ = “它是一种人工智能的方法。”
  • 假设这句话的 token 序列是 [它, 是, 一种, 人工智能, 的, 方法, 。],共 7 个 token)
  • 演员模型 (Actor Model, $\pi_{\theta}$): 我们正在用 PPO 训练的模型
  • 参考模型 (Reference Model, $π_{ref}$): 初始的 SFT 模型,固定不变
  • 奖励模型 (Reward Model, $R_{\phi}$ or RM): 给 (x, y) 打分的模型
  • 评论家模型 (Critic Model, $V_{\phi}(s)$): 用于估计状态值 $V(s)$ 的模型,与演员模型一起训练

超参数:

  • KL 惩罚系数 $\beta=0.1$
  • GAE 权重参数 $λ=0.95$
  • 折扣因子 $γ=1$ 或者 $\gamma=0.99$
  • PPO 裁剪范围 $ϵ=0.2$

第一步:前向生成与数据收集

模型处理 prompt $x$,并自回归地生成回答 $y$。在这个过程中,我们需要为每个时间步 t (每个 token 位置) 记录以下信息:

$t$ 状态 $s_t$ 上文 动作 $a_t$ 单步生成的 token $π_θ(a_t \text{given} s_t)$ $π_{ref} (a_t \text{given} s_t)$ $V(s_t)$
1 “请解释一下机器学习:” 0.6 0.5 2.1
2 “请解释一下机器学习:它” 0.8 0.7 2.3
3 “请解释一下机器学习:它是” 一种 0.9 0.9 2.5
4 “请解释一下机器学习:它是一种” 人工智能 0.7 0.8 2.6
5 “请解释一下机器学习:它是一种人工智能” 0.95 0.9 2.7
6 “请解释一下机器学习:它是一种人工智能的” 方法 0.8 0.75 2.8
7 “请解释一下机器学习:它是一种人工智能的方法” 0.99 0.95 2.9

注意:

  1. $π_θ(a_t|s_t)$ 是演员模型在生成第 $t$ 个 token 时,选择真实动作 $a_t$ 的概率
  2. $V(s_t)$ 是评论家模型对当前状态 $s_t$ 的估值,它预测的是从当前状态开始到回合结束所能获得的预期总奖励

第二步:计算最终奖励 $R_{total}$

  1. 生成结束后,我们将完整的 $(x,y)$ 输入奖励模型 RM, 假设 RM 给出的分数 $R(x,y)=3.0$
  2. 计算平均 KL 散度, 为什么需要增加一个 KL 惩罚项? 我们算出来的这个 reward 就要增加一个惩罚, 惩罚当前的策略对于旧策略的偏移的程度, 如果当前策略相比旧策略便宜太多, 我们就让奖励小一点; 如果当前策略相比就策略偏移不大,那我们就让奖励和原始值差不多
  3. KL 散度是逐 token 计算再取平均:
$t$ $π_θ(a_t\text{given}s_t)$ $π_{ref}(a_t\text{given}s_t)$ $KL_t$
1 0.6 0.5 $\ln(0.6)-\ln(0.5)≈-0.511 + 0.693 = 0.182$
2 0.8 0.7 $\ln(0.8)-\ln(0.7)≈-0.223 + 0.357 = 0.134$
3 0.9 0.9 $\ln(0.9)-\ln(0.9)=0.0$
4 0.7 0.8 $\ln(0.7)-\ln(0.8)≈-0.357+0.223 = -0.134$
5 0.95 0.9 $\ln(0.95)-\ln(0.9)≈-0.051+0.105 = 0.054$
6 0.8 0.75 $\ln(0.8)-\ln(0.75)≈-0.223+0.288=0.065$
7 0.99 0.95 $\ln(0.99)-\ln(0.95)≈-0.010+0.051=0.041$

最终奖励

这个奖励 $R_total≈2.995$ 就是整个回合 (整个句子) 的实际回报 Return

第三步:计算 Advantage $A_t$

我们需要计算每个 Timestamp 的 Advantage $A_t$,Advantage 的定义是: 在状态 $s_t$ 下采取动作 $a_t$ 相比平均情况要好多少

但是 $Q(s,a)$ 无法直接估计,只能用采样的回报信号来进行模拟, 其中一种方法就是使用 GAE (Generalized Advantage Estimator) 的方式来计算 advantage: 不用单步估计/也不用全回报,取二者的一个折中: 用多步 TD 残差做指数衰减加权平均

其中,
1.$\gamma=0.99$ 或者 $\gamma=1$ 控制未来奖励的折扣
2.$\lambda=0.95$ 控制奖励偏差和方差之间的权衡, $\lambda=0$ 的时候,只考虑当前的单步残差, 方差最小,偏差最大,只看眼前的一步;$\lambda=1$ 的时候,看完整的轨迹,回报无偏,相当于全回报的 MC 估计, 但是方差最大
3.$0<\lambda<1$, 平衡偏差和方差,部分利用未来奖励
4.其中细拆来看, 单步的优势估计是依赖 Temporal Difference Residual (TD 残差) 来估计单步的优势, 时间步 $s$ 下 TD 残差的公式如下:

4.直观理解 $\delta_t$ 的作用: 衡量单步 [现实的回报] 和 [当前价值预测] 之间的差距, 如果 $\delta_t>0$, 说明这一个动作比预期好, $\delta_t<0$, 说明当前动作比预期差; 其中 $r_t$ 是即时奖励,在文本生成任务中除最后一步外均为 0, 在上面的例子中

其余的

在我们的例子中 $γ=1$, $r_t$ 只在最后一步 $t=7$ 有值, 其他都是 $0$, 计算 TD 残差 $δ_t$:

采用递推方式从后往前计算 $A_t$ 公示如下:

计算过程

现在我们得到了每个时间步的 Advantage 值:

$t$ 1 2 3 4 5 6 7
$A_t$ 0.795 0.626 0.448 0.366 0.280 0.190 0.095

解读: $A_1=0.795$ 意味着在第一步生成 “它” 这个 token 是一个非常好的决策,比 critic 模型当时的预估要好很多, 而最后一步生成句号的行为只是一个符合预期的操作 $A7=0.095$,贡献相对较小

第四步:计算总损失 (Loss)

PPO 的总损失 $\mathcal L_{\text{PPO}}$ 包含三部分:

  1. 策略梯度损失 $\mathcal L_{\text{policy gradient}}$
  2. 价值函数损失 $\mathcal L_{value}$
  3. KL 损失 $\mathcal L_{KL}$

我们以第一个时间步 $t=1$ 为例演示计算

公式:

1.策略梯度损失

其中概率比

这是怎么来的, 回顾下策略梯度的推导, 策略梯度的核心目标是最大化未来累计回报, 优化策略参数 $\theta$

对策略求梯度并更新梯度,根据策略梯度推导得到

其中 $\log\pi(a_t|s_t)$: 策略在 $s_t$ 状态下选择 $a_t$ 动作的概率, 梯度的方向是提升该动作的概率; $A_t$ 判断动作 $a_t$ 是好动作 ($A_t>0$) 还是差动作 ($A_t<0$)

PPO 思想是, 我们用 “概率比” 近似替代原始的 “概率选择” 这一项, 概率比一定是一个正数, 控制变化相对幅度:

结合动作好坏的指示信号 $A_t$, 因此得到

2.价值网络损失 value loss
Value net 预测当前价值的期望奖励,通过最小化均方误差损失实现

3.KL loss
防止模型偏离原始模型太远

但是实际词表很大,通常不会按照 KL 的标准定义在全词表上计算, 而是对生成 token 计算 log-prob 差值:

逐个 token 计算后, 然后对序列取平均

PPO 训练的成功三要素

在 LLM 场景中,PPO 的成功与否,不仅取决于算法参数(ε、γ、λ、β)的调优,更依赖于
(i). 高质量的 RM
(ii). 合理的奖励设计
(iii). 多样的 Prompt 池

Reference

[1]. Secrets of RLHF in Large Language Models Part I: PPO.


转载请注明来源 goldandrabbit.github.io