Training-Inference-Mismatch
版本依赖:ms-swift>=3.11
TL;DR: GRPO 引入 vLLM 加速采样过程的同时,也引入了训练-推理不一致(Training-Inference Mismatch)的问题,从而可能影响训练稳定性。本文将解释这个问题的背景、原因以及相应的解决方案。
Background
GRPO 的基本假设
GRPO (Group Relative Policy Optimization) 的训练目标可以表示为:
其中:
\(r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\theta_{\text{old}}}(y_t|x, y_{<t})}\) 是重要性采样比
\(\hat{A}_t\) 是优势函数(advantage),基于 reward 和 group baseline 计算
\(\epsilon\) 是 clipping 参数
核心假设:样本 \(y\) 是从策略 \(\pi_\theta\) 中采样得到的。在实际训练中,这意味着:
采样模型(rollout model)与训练模型(policy model)应当是同一个模型 \(\pi_\theta\)
两个模型的概率分布应当完全一致,即 \(\pi_{\text{rollout}} = \pi_\theta\)
引入 vLLM 后的假设偏离
GRPO 的训练速度很大程度上受到采样过程(rollout)的速度制约。为了加速,训练框架引入高效推理引擎(如 vLLM)来执行采样。理想假设是:通过权重同步,vLLM 与训练模型保持一致,即 \(\pi_{\text{vLLM}} \equiv \pi_\theta\)。
然而,在实践中,即使权重完全同步,由于算子实现等差异,两者的概率分布仍然存在偏差:
此时,实际的训练目标变为:
其中样本来自 \(\pi_{\text{vLLM}}\),但梯度是基于 \(\pi_\theta\) 计算的,这破坏了算法的 on-policy 假设,引入了训推不一致的问题。
Solution
针对训推不一致问题,可以引入**重要性采样(Importance Sampling, IS)**的校正机制。
重要性采样校正
重要性采样的基本思想是:当样本来自分布 \(q\) 而非目标分布 \(p\) 时,可以通过引入权重来修正期望的计算:
应用到 GRPO 的场景,修正后的损失函数为:
其中 \(w(x, y)\) 是重要性采样权重,用于校正 vLLM 与训练模型之间的分布偏差
重要性采样权重可以在不同粒度上计算和应用:
Token-Level
每个 token 上计算重要性采样比:
Sequence-Level
计算序列级别的重要性采样比,然后广播到每个 token:
稳定性控制:Truncate vs. Mask
过大的重要性采样权重会导致梯度爆炸,破坏训练稳定性。因此需要对权重进行控制:
1. Truncate(截断)
将重要性采样权重截断到 \([0, \tau]\) 区间:
该方法保留所有样本,但限制其影响范围。
2. Mask(屏蔽)
舍弃权重超过阈值的 token/sequence 数据
四种校正模式
结合粒度和控制策略,共设置四种校正模式(通过 --rollout_importance_sampling_mode 参数选择):
| 模式 | 说明 |
|---|---|
token_truncate |
Token 级截断 |
token_mask |
Token 级屏蔽 |
sequence_truncate |
Sequence 级截断 |
sequence_mask |
Sequence 级屏蔽 |
其中阈值通过 --rollout_importance_sampling_threshold 参数设置。
Metrics
为了监控训练中训推不一致的程度,我们在log中加入以下指标(前缀为 rollout_correction/):
1. KL 散度(KL Divergence)
KL 散度衡量 rollout 策略与训练策略之间的偏离程度。两个指标都估计 \(\text{KL}(\pi_{\text{vLLM}} \| \pi_\theta)\)
直接估计器 kl:
K3 估计器 k3_kl:
K3 估计器在 KL 值较小时数值更稳定,且始终非负。
2. Perplexity (PPL)
困惑度衡量模型对序列的预测不确定性:
相关指标:
training_ppl/training_log_ppl:训练策略的 PPL 及其对数rollout_ppl/rollout_log_ppl:rollout 策略的 PPL 及其对数log_ppl_diff:log PPL 差异,正值表示训练策略分配的概率更低log_ppl_abs_diff:log PPL 绝对差异log_ppl_diff_max/log_ppl_diff_min:log PPL 差异的最大/最小值ppl_ratio:PPL 比率 \(\frac{\text{PPL}_{\text{training}}}{\text{PPL}_{\text{rollout}}}\)
3. χ² 散度(Chi-squared Divergence)
χ² 散度衡量重要性采样权重的方差:
chi2_token:Token 级别 χ² 散度,\(\mathbb{E}[\rho_t^2] - 1\)chi2_seq:Sequence 级别 χ² 散度(基于几何平均),\(\mathbb{E}[\rho_{\text{geo}}^2] - 1\),其中 \(\rho_{\text{geo}} = \exp(\frac{1}{T}\sum_t \log \rho_t)\)
χ² 散度越大,表示 IS 权重方差越大,训练越不稳定。chi2_seq 使用几何平均而非乘积,使其与 chi2_token 在量级上可比较。
4. Effective Sample Size (ESS)
有效样本大小衡量重要性采样后实际起作用的样本数量:
ESS 值越大(接近1),表示重要性采样权重分布越均匀,样本的有效利用率越高。当所有权重相等时(on-policy),ESS = 1;当权重差异很大时(严重 off-policy),ESS 会很小。
5. IS 权重统计
is_weight_mean:平均重要性采样权重,理想值为 1.0clipped_frac:被截断或屏蔽的样本比例
使用方式
仅记录诊断指标
如果只想监控训推不一致的程度,而不启用重要性采样校正,可以设置:
--log_rollout_offpolicy_metrics true
这将记录上述所有诊断指标(KL、PPL、χ² 等),但不会对损失函数进行任何修正。
启用重要性采样校正
在GRPO训练中,设置以下参数启用校正机制:
--rollout_importance_sampling_mode (默认为None)
--rollout_importance_sampling_threshold (默认为2)
当设置了 rollout_importance_sampling_mode 时,诊断指标会自动记录,无需额外设置 log_rollout_offpolicy_metrics。
Off-Policy Sequence Masking
除了重要性采样校正外,还可以使用 Off-Policy Sequence Masking 技术来处理训推不一致问题。该技术来自 DeepSeek-V3.2 论文。
原理
Off-Policy Sequence Masking 的核心思想是:当当前策略相对于旧策略(rollout 或 old policy)发生较大偏移时,直接丢弃(mask)该序列,不参与损失计算。这种方法特别针对优势为负的序列,因为这些序列在策略偏移较大时更容易导致训练不稳定。
具体来说,对于每个序列,计算:
当满足以下条件时,序列 \(i\) 将被 mask 掉(平均只在 completion token 上计算,即 completion_mask=1 的位置):
\(\delta_i > \tau\)
且 \(\hat{A}_i < 0\)
其中:
\(\pi_{\text{old}}\) 优先使用
rollout_per_token_logps(rollout/行为策略的 logprobs),若不存在则使用old_per_token_logps\(\tau\) 是用户设置的阈值(
--off_policy_sequence_mask_delta,默认 None 表示关闭)
参考资料
https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
https://fengyao.notion.site/off-policy-rl
https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/rollout_corr_helper.py
DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models