Overview
1.PPO 执行流程
2.PPO 计算流程模拟
3.PPO 训练的成功三要素
Recap: KL Divergence 在 RLHF 中的近似
KL 散度衡量两个分布之间的差异中的近似 KL 计算
在 LLM 的上下文中,我们对生成的 token 维度计算 (例如 step 为 $t$) 计算 policy 模型和 reference 模型下的 KL 散度
对于生成的句子,包括多个 token,我们计算平均 KL
PPO 执行流程
PPO 关键要素是 4 个模型 + 3 个损失:
准备 4 个 模型:
- policy model $\pi_{\theta}$, 从 SFT 模型初始化
- reference model $\pi_{ref}$, 从 SFT 模型初始化
- reward model $R_{\theta}$, 提前训练好的 SFT 模型
- value model $V_{\theta}$, value model 预测给定 [prompt + 上文] 情况下的未来累计奖励
计算总损失: policy 损失 + value 损失 + KL 损失
PPO 计算流程模拟
用一个具体的例子来完整演示 RLHF PPO 中 Advantage 和 Loss 的计算过程
- 提示词 (Prompt): $x$ = “请解释一下机器学习:”
- 生成的回答 (Response): $y$ = “它是一种人工智能的方法。”
- 假设这句话的 token 序列是 [它, 是, 一种, 人工智能, 的, 方法, 。],共 7 个 token)
- 演员模型 (Actor Model, $\pi_{\theta}$): 我们正在用 PPO 训练的模型
- 参考模型 (Reference Model, $π_{ref}$): 初始的 SFT 模型,固定不变
- 奖励模型 (Reward Model, $R_{\phi}$ or RM): 给 (x, y) 打分的模型
- 评论家模型 (Critic Model, $V_{\phi}(s)$): 用于估计状态值 $V(s)$ 的模型,与演员模型一起训练
超参数:
- KL 惩罚系数 $\beta=0.1$
- GAE 权重参数 $λ=0.95$
- 折扣因子 $γ=1$ 或者 $\gamma=0.99$
- PPO 裁剪范围 $ϵ=0.2$
第一步:前向生成与数据收集
模型处理 prompt $x$,并自回归地生成回答 $y$。在这个过程中,我们需要为每个时间步 t (每个 token 位置) 记录以下信息:
| $t$ | 状态 $s_t$ 上文 | 动作 $a_t$ 单步生成的 token | $π_θ(a_t \text{given} s_t)$ | $π_{ref} (a_t \text{given} s_t)$ | $V(s_t)$ |
|---|---|---|---|---|---|
| 1 | “请解释一下机器学习:” | 它 | 0.6 | 0.5 | 2.1 |
| 2 | “请解释一下机器学习:它” | 是 | 0.8 | 0.7 | 2.3 |
| 3 | “请解释一下机器学习:它是” | 一种 | 0.9 | 0.9 | 2.5 |
| 4 | “请解释一下机器学习:它是一种” | 人工智能 | 0.7 | 0.8 | 2.6 |
| 5 | “请解释一下机器学习:它是一种人工智能” | 的 | 0.95 | 0.9 | 2.7 |
| 6 | “请解释一下机器学习:它是一种人工智能的” | 方法 | 0.8 | 0.75 | 2.8 |
| 7 | “请解释一下机器学习:它是一种人工智能的方法” | 。 | 0.99 | 0.95 | 2.9 |
注意:
- $π_θ(a_t|s_t)$ 是演员模型在生成第 $t$ 个 token 时,选择真实动作 $a_t$ 的概率
- $V(s_t)$ 是评论家模型对当前状态 $s_t$ 的估值,它预测的是从当前状态开始到回合结束所能获得的预期总奖励
第二步:计算最终奖励 $R_{total}$
- 生成结束后,我们将完整的 $(x,y)$ 输入奖励模型 RM, 假设 RM 给出的分数 $R(x,y)=3.0$
- 计算平均 KL 散度, 为什么需要增加一个 KL 惩罚项? 我们算出来的这个 reward 就要增加一个惩罚, 惩罚当前的策略对于旧策略的偏移的程度, 如果当前策略相比旧策略便宜太多, 我们就让奖励小一点; 如果当前策略相比就策略偏移不大,那我们就让奖励和原始值差不多
- KL 散度是逐 token 计算再取平均:
| $t$ | $π_θ(a_t\text{given}s_t)$ | $π_{ref}(a_t\text{given}s_t)$ | $KL_t$ |
|---|---|---|---|
| 1 | 0.6 | 0.5 | $\ln(0.6)-\ln(0.5)≈-0.511 + 0.693 = 0.182$ |
| 2 | 0.8 | 0.7 | $\ln(0.8)-\ln(0.7)≈-0.223 + 0.357 = 0.134$ |
| 3 | 0.9 | 0.9 | $\ln(0.9)-\ln(0.9)=0.0$ |
| 4 | 0.7 | 0.8 | $\ln(0.7)-\ln(0.8)≈-0.357+0.223 = -0.134$ |
| 5 | 0.95 | 0.9 | $\ln(0.95)-\ln(0.9)≈-0.051+0.105 = 0.054$ |
| 6 | 0.8 | 0.75 | $\ln(0.8)-\ln(0.75)≈-0.223+0.288=0.065$ |
| 7 | 0.99 | 0.95 | $\ln(0.99)-\ln(0.95)≈-0.010+0.051=0.041$ |
最终奖励
这个奖励 $R_total≈2.995$ 就是整个回合 (整个句子) 的实际回报 Return
第三步:计算 Advantage $A_t$
我们需要计算每个 Timestamp 的 Advantage $A_t$,Advantage 的定义是: 在状态 $s_t$ 下采取动作 $a_t$ 相比平均情况要好多少
但是 $Q(s,a)$ 无法直接估计,只能用采样的回报信号来进行模拟, 其中一种方法就是使用 GAE (Generalized Advantage Estimator) 的方式来计算 advantage: 不用单步估计/也不用全回报,取二者的一个折中: 用多步 TD 残差做指数衰减加权平均
其中,
1.$\gamma=0.99$ 或者 $\gamma=1$ 控制未来奖励的折扣
2.$\lambda=0.95$ 控制奖励偏差和方差之间的权衡, $\lambda=0$ 的时候,只考虑当前的单步残差, 方差最小,偏差最大,只看眼前的一步;$\lambda=1$ 的时候,看完整的轨迹,回报无偏,相当于全回报的 MC 估计, 但是方差最大
3.$0<\lambda<1$, 平衡偏差和方差,部分利用未来奖励
4.其中细拆来看, 单步的优势估计是依赖 Temporal Difference Residual (TD 残差) 来估计单步的优势, 时间步 $s$ 下 TD 残差的公式如下:
4.直观理解 $\delta_t$ 的作用: 衡量单步 [现实的回报] 和 [当前价值预测] 之间的差距, 如果 $\delta_t>0$, 说明这一个动作比预期好, $\delta_t<0$, 说明当前动作比预期差; 其中 $r_t$ 是即时奖励,在文本生成任务中除最后一步外均为 0, 在上面的例子中
其余的
在我们的例子中 $γ=1$, $r_t$ 只在最后一步 $t=7$ 有值, 其他都是 $0$, 计算 TD 残差 $δ_t$:
采用递推方式从后往前计算 $A_t$ 公示如下:
计算过程
现在我们得到了每个时间步的 Advantage 值:
| $t$ | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
|---|---|---|---|---|---|---|---|
| $A_t$ | 0.795 | 0.626 | 0.448 | 0.366 | 0.280 | 0.190 | 0.095 |
解读: $A_1=0.795$ 意味着在第一步生成 “它” 这个 token 是一个非常好的决策,比 critic 模型当时的预估要好很多, 而最后一步生成句号的行为只是一个符合预期的操作 $A7=0.095$,贡献相对较小
第四步:计算总损失 (Loss)
PPO 的总损失 $\mathcal L_{\text{PPO}}$ 包含三部分:
- 策略梯度损失 $\mathcal L_{\text{policy gradient}}$
- 价值函数损失 $\mathcal L_{value}$
- KL 损失 $\mathcal L_{KL}$
我们以第一个时间步 $t=1$ 为例演示计算
公式:
1.策略梯度损失
其中概率比
这是怎么来的, 回顾下策略梯度的推导, 策略梯度的核心目标是最大化未来累计回报, 优化策略参数 $\theta$
对策略求梯度并更新梯度,根据策略梯度推导得到
其中 $\log\pi(a_t|s_t)$: 策略在 $s_t$ 状态下选择 $a_t$ 动作的概率, 梯度的方向是提升该动作的概率; $A_t$ 判断动作 $a_t$ 是好动作 ($A_t>0$) 还是差动作 ($A_t<0$)
PPO 思想是, 我们用 “概率比” 近似替代原始的 “概率选择” 这一项, 概率比一定是一个正数, 控制变化相对幅度:
结合动作好坏的指示信号 $A_t$, 因此得到
2.价值网络损失 value loss
Value net 预测当前价值的期望奖励,通过最小化均方误差损失实现
3.KL loss
防止模型偏离原始模型太远
但是实际词表很大,通常不会按照 KL 的标准定义在全词表上计算, 而是对生成 token 计算 log-prob 差值:
逐个 token 计算后, 然后对序列取平均
PPO 训练的成功三要素
在 LLM 场景中,PPO 的成功与否,不仅取决于算法参数(ε、γ、λ、β)的调优,更依赖于
(i). 高质量的 RM
(ii). 合理的奖励设计
(iii). 多样的 Prompt 池
Reference
[1]. Secrets of RLHF in Large Language Models Part I: PPO.
转载请注明来源 goldandrabbit.github.io