知识蒸馏(Knowledge Distillation)

知识蒸馏是一种将教师模型(teacher model)的能力迁移到学生模型(student model)的训练方法。其核心思想是:让学生在每个 token 位置上向教师的输出分布靠拢,从而获得比单纯模仿标注答案更丰富的监督信号——教师不仅告诉学生「哪个 token 是对的」,还告诉学生「其他 token 有多好 / 多差」。

本文档自顶向下地介绍:蒸馏为什么有效(第一节)、蒸馏方法的统一设计框架(第二节),最后落到 swift 中三种具体的蒸馏训练方法 GKD / OPD-RL / OPSD(第三节)。


一、为什么需要蒸馏:从稀疏信号到稠密信号

一个语言模型的能力,通常由一连串训练阶段堆叠而成:

  • 预训练(Pre-training):习得语言、世界知识、基础推理等通用能力。

  • 中训练(Mid-training):注入领域知识,如代码、医学、公司内部文档等。

  • 后训练(Post-training):激发目标行为,如指令遵循、数学推理、对话风格等。

蒸馏主要发生在后训练阶段。要理解它的价值,需要从两个彼此独立的维度来看待后训练方法:

  1. 采样方式(数据从哪来):训练序列是由学生自己生成(on-policy),还是来自外部固定数据(off-policy)。

  2. 反馈密度(每条序列能学到多少):是整条序列只有一个奖励(per-sequence,稀疏),还是每个 token 都有信号(per-token,稠密)。

SFT / 离线蒸馏(off-policy + 稠密):在固定数据上对齐标注或教师分布。信号稠密,但训练时见到的都是教师/标注的状态,和学生推理时自己会进入的状态不一致。学生一旦在推理早期犯了教师不会犯的错,就会进入训练中从未见过的状态,误差不断累积,这被称为 exposure bias(曝光偏差)

RL(on-policy + 稀疏):学生自己采样,按最终结果给奖励。分布与学生推理一致,但奖励通常是序列级标量,一般不指明具体哪个 token 出错。

On-policy 蒸馏(on-policy + 稠密):学生自己采样轨迹,再由教师对轨迹的每一个 token 打分。训练分布与学生推理分布一致,且反馈为 per-token 级别。

SFT 是蒸馏的一个特例

理解蒸馏的一个自然切入点是 SFT 的损失函数。SFT 的 cross-entropy loss,等价于以标注 token 的 one-hot 分布 \(\delta_{y^*}\) 为「教师」的 KL 散度:

\[-\log P_S(y^*_t) = \text{KL}(\delta_{y^*} \,\|\, P_S)\]

知识蒸馏只是把这个确定性的 one-hot「教师」换成了真实教师模型的软分布 \(P_T\),在每个 token 上优化 \(\text{KL}(P_T \,\|\, P_S)\),从而提供比 one-hot 丰富的监督信号。

方法 采样方式 反馈密度 教师分布
SFT off-policy(固定数据) 稠密(per-token) one-hot $\delta_{y^*}$
离线(off-policy)蒸馏 off-policy(固定数据或教师生成) 稠密(per-token) 教师软分布
RL on-policy(学生采样) 稀疏(per-sequence)
On-policy 蒸馏 on-policy(学生采样) 稠密(per-token) 教师软分布

二、蒸馏的两个核心选择

不同蒸馏方法的差异,几乎都可以归结为两个问题。理解了这两个维度,后面的方法都只是它们的不同组合。

2.1 教师信号怎么算

在每个 token 位置上,量化教师分布 \(P_T\) 与学生分布 \(P_S\) 的差异(我们称之为 Teacher KL),有两个子选择。

(a) 散度方向

散度 定义 优化时的行为(信息论含义)
Forward KL $\text{KL}(P_T \,|\, P_S)$ Mode-covering:学生需对教师概率较高的区域都赋予足够概率
Reverse KL $\text{KL}(P_S \,|\, P_T)$ Mode-seeking:学生主要拟合教师的众数(高概率)区域
广义 JSD($\beta$) $\beta\,\text{KL}(P_T|M) + (1-\beta)\,\text{KL}(P_S|M)$,其中 $M=\beta P_T+(1-\beta)P_S$ 在两者之间插值

其中 \(\beta=0\) 退化为 Forward KL,\(\beta=1\) 退化为 Reverse KL。SFT 等价于 Forward KL(教师为 one-hot)。

在 swift 中:

  • GKD 默认 \(\beta=0.5\)(JSD),可通过 --beta 在 Forward / JSD / Reverse 之间选择;

  • OPD-RL 的实现固定使用 Reverse KL 的 k1 估计量 \(\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t)\) 作为 per-token advantage。

(b) 计算粒度

计算粒度 需要的教师信息 说明
全词表 教师完整的 next-token 分布 可计算散度的精确值;显存开销大
Top-K 教师概率最高的 K 个 token 在 top-K 子集上重新归一化后的近似;适合外部 API(受 max_logprobs 限制)
采样 token 教师在学生实际采样 token 上的单个 logp Reverse KL 的单样本蒙特卡洛估计;通信开销最低

精度 vs 开销:全词表需要物化完整 logits;采样 token 只需教师在已采样 token 上的 logp(可走远程 API)。DeepSeek-V4技术报告指出,仅用采样 token 的 log-ratio 作 advantage 时梯度估计方差较大,因此其全词表 OPD 采用完整 logit 蒸馏

2.2 信号怎么传给学生

路径 A:GKD(直接损失) 路径 B:OPD-RL(RL Advantage)
训练范式 --rlhf_type gkd --rlhf_type grpo + 教师
信号传递 把信号作为 loss 把信号当 advantage,走 policy gradient
梯度流经 学生全词表 logits(或 top-k) 仅学生采样 token 的 $\nabla\log\pi(y_t)$
教师信息需求 全词表分布(或 top-k logits) 采样 token 上的单个 logp
散度选择 Forward / Reverse / JSD(--beta Reverse KL(k1 log-ratio)
与任务奖励组合 通过 sft_alpha 混合 SFT loss 可与 GRPO reward 叠加为 advantage

两者共享同一套教师基础设施(见下文),区别只在如何使用 teacher KL 信号。

蒸馏的常见用法

  1. 能力融合:多个专家模型蒸馏到统一模型。

  2. 强到弱:大模型向小模型传递能力。

  3. 防遗忘:用旧 checkpoint 作教师,在多阶段训练后恢复先前能力。


三、swift 中的蒸馏方法

swift 提供三种蒸馏训练方法,它们共享同一套教师基础设施,差异落在第二节的框架里:

方法 信号传递路径 启用方式 一句话
GKD 直接损失(路径 A) --rlhf_type gkd 教师散度作为 loss 反向传播;支持全词表 / top-k 散度
OPD-RL RL advantage(路径 B) --rlhf_type grpo + 教师 教师 log-ratio 注入 GRPO advantage,可与任务奖励叠加
OPSD A 或 B 均可 在上面基础上提供 teacher_prompt 单模型自蒸馏:教师输入含特权信息(如参考解答)

教师的三种来源(GKD 与 OPD-RL 通用):

  • --teacher_model:在训练进程中加载一个独立的冻结教师模型。

  • --teacher_model_server:连接一个外部教师服务(swift deploy --infer_backend vllm),不在训练卡上加载教师。GKD 使用 API 时需同时设置 --gkd_logits_topk

  • 自蒸馏:教师与学生同源。LoRA 训练且 --teacher_model--model 相同时,自动用 disable_adapter() 以基座为固定教师,无需额外加载;GKD 下不传 --teacher_model 则以「学生当前权重」为动态教师。

教师相关参数(GKD 与 OPD-RL 共享,完整说明见命令行参数):

参数 默认值 说明
--teacher_model None 教师模型路径;GKD 下不传则为动态自蒸馏
--teacher_model_server None 教师 API 地址(与 teacher_model 互斥)
--teacher_deepspeed None 教师模型的 DeepSpeed 配置(如 zero3
--offload_teacher_model False 非前向阶段将教师卸载到 CPU

3.1 GKD:散度作为直接损失

GKD(Generalized Knowledge Distillation)直接把教师-学生间的散度作为损失函数反向传播。

损失函数

\[ \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D_{\text{JSD}(\beta)}\big(P_{\text{teacher}}(\cdot|x,y_{<t}),\, P_{\text{student}}(\cdot|x,y_{<t})\big) \]

其中散度 \(D\)--beta 选择(见 2.1):\(\beta=0\) 为 Forward KL,\(\beta=1\) 为 Reverse KL,\(0<\beta<1\) 为广义 JSD(默认 \(0.5\))。

On-Policy vs Off-Policy:lmbda

GKD 通过 lmbda 控制每个 batch 用学生在线采样的概率:

if random() <= lmbda:
    y = student.generate(x)   # on-policy:学生自己采样
else:
    y = y_ground_truth        # off-policy:使用数据集标注
loss = D(P_teacher(·|x, y), P_student(·|x, y))
  • lmbda=0:纯离线(传统 SFT 蒸馏)。

  • lmbda=1:纯在线(学生从自身错误中学习,即 on-policy 蒸馏)。

  • 0<lmbda<1:混合。

若希望走teacher生成数据路径,需先用教师离线生成响应写入数据集,再设 lmbda=0 训练

GKD参数

参数 类型 默认值 说明
--beta float 0.5 散度插值:0=Forward KL,0.5=JSD,1=Reverse KL
--lmbda float 0.5 在线采样概率:0=离线,1=纯在线
--sft_alpha float 0 混合 SFT loss 比例,最终 loss = gkd_loss + sft_alpha * sft_loss(仅对非学生生成的数据生效)
--gkd_logits_topk int None 仅用教师 top-K logits 计算 KL;使用 teacher_model_server 时为必填

Top-K 蒸馏(省显存)

默认用完整词表计算 KL,词表很大时容易 OOM,可使用--gkd_logits_topk参数。

外部教师 API

设置 --teacher_model_server 时需同时设置 --gkd_logits_topk(API 仅返回 top-k logprobs)。示例如下:

# 步骤 1:部署教师模型(max_logprobs 需 >= gkd_logits_topk)
CUDA_VISIBLE_DEVICES=0 swift deploy \
    --model Qwen/Qwen3.5-9B \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64

# 步骤 2:启动 GKD 训练
CUDA_VISIBLE_DEVICES=1,2,3,4 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen3.5-2B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --lmbda 1.0 \
    --beta 1.0 \
    --dataset xxx

在线采样加速

lmbda > 0 时学生需在线生成序列,建议用 vLLM 加速采样(colocate / server 两种模式,与 GRPO 一致),参考 GRPO 文档

参考脚本


3.2 OPD-RL:KL 作为 RL Advantage

OPD(On-Policy Distillation)RL 把教师 KL 信号注入 GRPO 的 per-token advantage,通过 policy gradient 更新学生。

原理

标准 GRPO 的 advantage 来自组内归一化的任务奖励(per-sequence 标量)。OPD-RL 在 advantage 归一化之后逐 token 注入教师信号:

\[ A_t = A_t^{\text{base}} + \alpha \cdot \big(\log \pi_{\text{teacher}}(y_t|x,y_{<t}) - \log \pi_{\text{student}}(y_t|x,y_{<t})\big) \]
  • \(A_t^{\text{base}}\):GRPO 归一化后的任务奖励 advantage(无奖励函数时为 0)。

  • \(\alpha\)--teacher_kl_coef,教师信号强度。

  • \(\log\pi_{\text{teacher}}(y_t) - \log\pi_{\text{student}}(y_t)\):教师 log-ratio,即 Reverse KL 梯度中 \(\nabla_\theta\log\pi_\theta(y_t)\) 系数对应的 k1 估计量(见 compute_teacher_logratio)。

纯蒸馏模式:不设 --reward_funcs 时 base advantage 为 0,教师信号是唯一驱动力,此时 \(A_t = \alpha\cdot(\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t))\)

监控指标:日志中的 teacher_kl 是 k3 估计量 \(e^{d}-d-1\)\(d=\log\pi_{\text{teacher}}-\log\pi_{\text{student}}\)),衡量学生与教师的距离。

启用方式:在 --rlhf_type grpo 下设置 --teacher_model--teacher_model_server 即自动启用 OPD-RL,无需额外开关。教师相关参数见上文共享参数表。

OPD-RL 特有参数

参数 默认值 说明
--teacher_kl_coef 1.0 教师 log-ratio 注入 advantage 的系数 $\alpha$

参考脚本


3.3 OPSD:On-Policy Self-Distillation

OPSD(On-Policy Self-Distillation)是一种单模型自蒸馏方法:同一模型分别构造学生输入与教师输入,教师侧额外接收特权信息(如参考解答),再对齐两者在学生采样响应上的输出分布。

核心机制

  • 学生:仅看到问题,正常推理。

  • 教师:看到问题 + 参考解答(通过 teacher_prompt 列提供特权信息)。

  • 训练目标:用散度(JSD / KL)对齐学生与教师在同一份学生采样响应上的输出分布。

OPSD 既可走 GKD 路径,也可走 OPD-RL 路径:

  • GKD + OPSD--rlhf_type gkd,教师 KL 作为直接损失。

  • OPD-RL + OPSD--rlhf_type grpo + --teacher_model(与 --model 相同),教师 KL 作为 advantage。

两种自蒸馏权重模式

模式 参数配置 教师权重 说明
Dynamic(动态) 不传 --teacher_model 学生当前权重 教师随训练同步更新
Fixed(固定) --teacher_model 设为与 --model 相同 初始教师权重 教师权重固定

数据格式

OPSD 数据集需包含 teacher_prompt 列,可通过 --external_plugins 加载数据处理插件来构建。以数学推理数据集 open-r1/OpenThoughts-114k-math 为例:

from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset

class OpenThoughtsOPSDPreprocessor(RowPreprocessor):
    def preprocess(self, row):
        if not row.get('correct', True):
            return None
        problem = row.get('problem', '')
        solution = row.get('solution', '')
        teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
        messages = [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user', 'content': problem},
        ]
        return {'messages': messages, 'teacher_prompt': teacher_prompt}

register_dataset(DatasetMeta(
    ms_dataset_id='open-r1/OpenThoughts-114k-math',
    preprocess_func=OpenThoughtsOPSDPreprocessor(),
    tags=['math', 'opsd'],
))

参考脚本


Reference