# GKD GKD(Generalized Knowledge Distillation,广义知识蒸馏)训练算法由论文 [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://arxiv.org/pdf/2306.13649) 提出。该算法通过结合离线(off-policy)和在线(on-policy)学习策略,将教师模型的知识迁移到学生模型中。 ## 损失函数 当给定输入序列 $x$ 与输出序列 $y$,GKD 的损失函数可以写为: $$ \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D(P_{\text{teacher}}(\cdot | x, y_{ 对极端情况($\beta = 0$ 或 $\beta = 1$),直接计算单个 KL 散度: > - 当 $\beta = 0$ 时:直接定义 $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$(Forward KL,Mode-covering) > - 当 $\beta = 1$ 时:直接定义 $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$(Reverse KL,Mode-seeking) > - 当 $0 < \beta < 1$ 时:使用上述混合分布公式进行插值 通过调节 $\beta$ 参数,可以在不同的散度度量之间进行插值,当 $\beta = 0.5$ 时,散度为标准的对称 JSD。 ## 三种训练模式 GKD训练具有三种训练模式,区别在于输出序列 $y$ 的来源。 ### 模式选择逻辑 训练时,每个样本按照以下优先级选择模式: ```python # 伪代码:模式选择逻辑 if random() < lmbda: # Mode 1: On-Policy 学习,由学生模型采样输出序列 y = student.generate(x) source = "student" elif seq_kd: # Mode 2: Sequential KD,由教师模型采样输出序列 y = teacher.generate(x) source = "teacher" else: # Mode 3: 使用数据集中的输出序列 y = y_ground_truth source = "dataset" # 相同的损失函数 loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y)) ``` ### Mode 1: On-Policy 学习 设置参数`lambda`, 以概率 $\lambda$ 触发,使用学生模型采样 $y \sim P_{\text{student}}(\cdot | x)$ - 学生模型从**自己生成的序列**中学习 - 暴露在自己可能犯的错误中,学会**自我纠正和错误恢复** - 对齐训练分布与推理分布 - 提升模型的鲁棒性和实际应用表现 **适用场景**: - 学生模型已有一定生成能力 - 希望提升模型在真实推理场景下的表现 ### Mode 2: Sequential KD(`seq_kd=True` 且未触发 on-policy) 设置参数 `seq_kd=True`, 当未触发 on-policy 时,使用教师模型采样 **数据来源**:$y \sim P_{\text{teacher}}(\cdot | x)$ ### Mode 3: 离线学习(其他情况) **数据来源**:$y = y^* \sim \text{Dataset}$ - 学生模型从**数据集的标注序列**中学习 ## 参数设置 我们可以通过设置以下参数进行 GKD 训练: ### 基础参数 | 参数 | 类型 | 默认值 | 取值范围 | 说明 | |------|------|--------|---------|------| | `--teacher_model` | str | None | - | 教师模型路径或模型 ID | | `--beta` | float | 0.5 | [0.0, 1.0] | 散度插值系数
• 0.0: Forward KL
• 0.5: JSD (平衡)
• 1.0: Reverse KL | | `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率
• 0.0: 离线学习
• 0.5: 混合策略
• 1.0: 纯 On-Policy | | `--seq_kd` | bool | False | True/False | 是否使用教师生成序列
• False: 非 on-policy 时使用数据集
• True: 非 on-policy 时使用教师生成 | | `--temperature` | float | 0.9 | > 0 | 生成采样温度,控制随机性 | | `--sft_alpha` | float | 0 | >= 0 | 混合一定比例的sft loss,对非student生成结果生效 | | `--max_completion_length` | int | 512 | > 0 | 生成时的最大 token 数 | ### Top-K KL 计算 默认情况下,GKD 使用完整词表计算 KL 散度,容易造成 OOM,这种情况下可以使用 **Top-K** 模式来减少显存占用和计算量。 | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `--gkd_logits_topk` | int | None | Top-K logits 数量
• None: 使用完整词表(默认)
• 正整数: 仅使用教师模型概率最高的 K 个 token 计算 KL | **Top-K 模式原理**: 在 Top-K 模式下,选取**教师模型**输出概率最高的 K 个 token,在这个子集上计算两个模型分布的 KL 散度。 $$ D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}) $$ 其中 Top-K 索引来自教师模型:$\text{Top-K} = \text{argtop}_K(P_T)$,$\tilde{P}_T$ 和 $\tilde{P}_S$ 是在 Top-K 子集上**重新归一化**的概率分布: $$ \tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K} $$ **使用示例**: ```bash swift rlhf \ --rlhf_type gkd \ --model Qwen/Qwen2.5-7B-Instruct \ --teacher_model Qwen/Qwen2.5-72B-Instruct \ --gkd_logits_topk 64 \ --dataset your_dataset \ ... ``` > **注意**:Top-K 模式不能与 liger kernel 同时使用(`--use_liger_kernel`)。 ### 外部教师模型 API 当设置 `gkd_logits_topk` 时,可以使用外部教师模型 API 服务来获取 logprobs,这样可以避免在训练进程中加载教师模型。 | 参数 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `--teacher_model_server` | str | None | 教师模型服务地址
如:`http://localhost:8000` | | `--gkd_logits_topk` | int | **必需** | 使用外部 API 时必须设置,对应 API 返回的 top_logprobs 数量 | **步骤 1:部署教师模型服务** ```bash # 使用 vllm serve 部署教师模型 CUDA_VISIBLE_DEVICES=0 vllm serve Qwen/Qwen2.5-14B-Instruct \ --port 8000 \ --max-logprobs 64 \ --gpu-memory-utilization 0.9 ``` **步骤 2:启动 GKD 训练** ```bash swift rlhf \ --rlhf_type gkd \ --model Qwen/Qwen2.5-7B \ --teacher_model_server http://localhost:8000 \ --gkd_logits_topk 64 \ --dataset your_dataset \ --lmbda 1.0 \ --beta 1.0 \ ... ``` > **vLLM max_logprobs 限制**: > - vLLM 默认 `max_logprobs=20`,可通过 `--max-logprobs N` 参数调整 > - `gkd_logits_topk` 不能超过服务端的 `max_logprobs` 设置 ## 采样加速 在 GKD 训练中,涉及到两种在线采样的情况: 1. **学生模型采样**(当 `lmbda > 0`):以 $\lambda$ 概率触发学生模型采样 2. **教师模型采样**(当 `seq_kd=True`):以 $1-\lambda$ 概率触发教师模型采样 由于采样过程会显著减慢训练速度,可参考以下两种加速方案: ### 方案 1:学生模型采样加速 使用 vLLM 作为推理后端来加速学生模型采样,支持两种部署模式,与 GRPO 一致,参考[GRPO文档](./GRPO/GetStarted/GRPO.md#集群支持), 相关参数参考[GRPO vLLM 参数](./Command-line-parameters.md#vllm_mode) > **注意**:vLLM 加速仅适用于学生模型的 on-policy 采样(`lmbda > 0`)。教师模型的 sequential KD 采样(`seq_kd=True`)目前仍使用 Transformers,建议使用预采样方案。 训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh) 使用 Teacher Server 的训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh) ### 方案 2:教师模型预采样 对于教师模型采样(`seq_kd=True`),推荐使用 **预采样** 方式:先用教师模型离线生成高质量数据,再进行训练。 **步骤 1:使用教师模型生成数据** ```bash export teacher_model='OpenGVLab/InternVL3-8B' NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ swift infer \ --model $teacher_model \ --infer_backend vllm \ --val_dataset 'modelscope/coco_2014_caption:validation#5000' \ --vllm_gpu_memory_utilization 0.9 \ --vllm_max_model_len 8192 \ --max_new_tokens 2048 \ --write_batch_size 1000 \ --result_path teacher_generated_data.jsonl ``` **步骤 2:使用预生成数据训练** ```bash swift rlhf \ --rlhf_type gkd \ --model OpenGVLab/InternVL3-2B-Pretrained \ --teacher_model $teacher_model \ --dataset 'teacher_generated_data.jsonl' \ --seq_kd false \ ... ``` 训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh) ## On-Policy Distillation 我们可以通过设置以下参数实现 Thinking Machine Lab blog 中的[On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/)训练。 ```bash --lmbda 1 # on-policy --beta 1 # reverse ``` 相关脚本可以参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh) ## OPSD(On-Policy Self-Distillation) OPSD([On-Policy Self-Distillation](https://arxiv.org/abs/2601.18734)) 是一种**单模型自蒸馏**方法,无需额外的教师模型。核心思想是:同一个模型同时扮演教师和学生,教师通过接收**特权信息**(如参考解答)来引导学生学习。 ### 核心机制 - **学生**:仅看到问题,正常推理 - **教师**:看到问题 + 参考解答(通过 `teacher_prompt` 列提供特权信息),产出更优的概率分布 - **训练目标**:用 JSD 散度对齐学生和教师的输出分布 ### 两种自蒸馏模式 | 模式 | 参数配置 | 教师权重 | 说明 | |------|---------|---------|------| | **Dynamic**(动态) | 不传 `--teacher_model` | 学生当前权重 | 教师随训练同步更新 | | **Fixed**(固定) | `--teacher_model` 设为与学生相同的模型 | 初始教师权重 | 教师权重固定 | ### 数据格式 OPSD 需要数据集包含 `teacher_prompt` 列来提供教师的特权信息。可通过 `--external_plugins` 加载数据处理插件来构建该列。 以数学推理数据集 `open-r1/OpenThoughts-114k-math` 为例: ```python from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset class OpenThoughtsOPSDPreprocessor(RowPreprocessor): def preprocess(self, row): if not row.get('correct', True): return None problem = row.get('problem', '') solution = row.get('solution', '') # 教师看到问题 + 参考解答 teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.' messages = [ {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'}, {'role': 'user', 'content': problem}, ] return {'messages': messages, 'teacher_prompt': teacher_prompt} register_dataset(DatasetMeta( ms_dataset_id='open-r1/OpenThoughts-114k-math', preprocess_func=OpenThoughtsOPSDPreprocessor(), tags=['math', 'opsd'], )) ``` ### 参数设置 OPSD 复用 GKD 的所有参数,核心区别在于 `--teacher_model` 的配置: | 参数 | Dynamic 模式 | Fixed 模式 | |------|-------------|-----------| | `--teacher_model` | 不设置 | 设为与 `--model` 相同的模型 | 参考脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/opsd/) Megatron脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/gkd/opsd.sh)