# GKD

The GKD (Generalized Knowledge Distillation) training algorithm was proposed in the paper [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://arxiv.org/pdf/2306.13649). It transfers knowledge from the teacher model to the student model by combining offline and on-policy learning strategies.

## Loss Function

Given an input sequence $x$ and output sequence $y$, the GKD loss function can be written as:

$$
\mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D\left(P_{\text{teacher}}(\cdot | x, y_{<t}),\ P_{\text{student}}(\cdot | x, y_{<t})\right)
$$

where $y_{<t}$ denotes the tokens before position $t$ and $D$ is the generalized Jensen-Shannon divergence with interpolation coefficient $\beta$. Writing $P_T$ and $P_S$ for the teacher and student distributions, it is defined through the mixture distribution $M$:

$$
D_{\text{JSD}(\beta)}(P_T, P_S) = \beta \cdot \text{KL}(P_T \| M) + (1-\beta) \cdot \text{KL}(P_S \| M), \quad M = \beta \cdot P_T + (1-\beta) \cdot P_S
$$

> For extreme cases ($\beta = 0$ or $\beta = 1$), a single KL divergence is computed directly:
> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
> - When $0 < \beta < 1$: use the above mixture distribution formula for interpolation

Adjusting the $\beta$ parameter interpolates between these divergences. When $\beta = 0.5$, the divergence is the standard symmetric JSD.

## Two Training Modes

GKD training has two modes, distinguished by the source of the output sequence $y$.
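Both modes score the chosen sequence with the same per-token divergence from the Loss Function section. The following is a minimal pure-Python sketch of that divergence, not the trainer's implementation. It assumes the mixture $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$ from the paper's generalized JSD, and the helper names `kl`, `generalized_jsd`, and `gkd_loss` are hypothetical.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities.

    Assumes q is positive wherever p is positive.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(p_teacher, p_student, beta=0.5):
    """Per-token divergence D between teacher and student distributions."""
    if beta == 0.0:
        return kl(p_teacher, p_student)  # forward KL (mode-covering)
    if beta == 1.0:
        return kl(p_student, p_teacher)  # reverse KL (mode-seeking)
    # Assumed mixture: M = beta * P_teacher + (1 - beta) * P_student
    m = [beta * pt + (1 - beta) * ps for pt, ps in zip(p_teacher, p_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(p_student, m)


def gkd_loss(teacher_dists, student_dists, beta=0.5):
    """Sum the per-token divergence over all positions t of the output sequence."""
    return sum(
        generalized_jsd(pt, ps, beta)
        for pt, ps in zip(teacher_dists, student_dists)
    )


# Two output positions over a three-token vocabulary (made-up numbers).
teacher = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
student = [[0.5, 0.3, 0.2], [0.3, 0.4, 0.3]]
for beta in (0.0, 0.5, 1.0):
    print(beta, round(gkd_loss(teacher, student, beta), 4))
```

In a real trainer the distributions come from softmaxed logits over the full vocabulary and the sum is usually masked and normalized over completion tokens; those details are left out here.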
### Mode Selection Logic

During training, each sample selects a mode according to the following logic:

```python
# Illustrative sketch of the mode selection logic (not the trainer's actual code)
import random


def select_completion(x, y_ground_truth, student, lmbda):
    if random.random() < lmbda:
        # Mode 1: On-Policy learning, output sequence sampled by the student model
        y = student.generate(x)
        source = "student"
    else:
        # Mode 2: Offline learning, use the output sequence from the dataset
        y = y_ground_truth
        source = "dataset"
    return y, source


# Both modes use the same loss function:
# loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
```

### Mode 1: On-Policy Learning

Controlled by the `lmbda` parameter: triggered with probability $\lambda$, in which case the output sequence is sampled from the student model, $y \sim P_{\text{student}}(\cdot | x)$.

- The student model learns from **sequences generated by itself**
- It is exposed to errors it might make, learning to **self-correct and recover from errors**
- Aligns the training distribution with the inference distribution
- Improves model robustness and practical application performance

**Applicable Scenarios**:
- The student model already has certain generation capabilities
- You want to improve model performance in real inference scenarios

### Mode 2: Offline Learning

Used when `lmbda=0` or when on-policy sampling is not triggered for a sample.

**Data Source**: $y = y^* \sim \text{Dataset}$

- The student model learns from **annotated sequences in the dataset**

## Parameter Settings

GKD training is configured with the following parameters:

### Basic Parameters

| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | None | - | Teacher model path or model ID<br>(can be omitted when using `teacher_model_server`) |
| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Offline<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
| `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
| `--sft_alpha` | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |

### Top-K KL Computation

By default, GKD computes KL divergence over the full vocabulary. For models with large vocabularies, you can use **Top-K** mode to reduce memory usage and computation.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--gkd_logits_topk` | int | None | Number of Top-K logits<br>• None: Use full vocabulary (default)<br>• Positive integer: Only use the K tokens with highest teacher probability for KL computation |

**Top-K Mode Principle**:

In Top-K mode, the top-K token indices are selected from the **teacher model**. These indices are used to gather logits from both models, which are then renormalized over the top-K subset before computing the JSD.

$$
D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M})
$$

Here the Top-K indices come from the teacher model, $\text{Top-K} = \text{argtop}_K(P_T)$, and $\tilde{P}_T$ and $\tilde{P}_S$ are the probability distributions **renormalized** over the Top-K subset:

$$
\tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K}
$$

**Usage Example**:

```bash
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B-Instruct \
    --teacher_model Qwen/Qwen2.5-14B-Instruct \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    ...
```

> **Note**: Top-K mode cannot be used with liger kernel (`--use_liger_kernel`).

### External Teacher Model API

When `gkd_logits_topk` is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process.

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--teacher_model_server` | str | None | Teacher model service URL<br>e.g., `http://localhost:8000` |
| `--gkd_logits_topk` | int | **Required** | Must be set when using external API; corresponds to the top_logprobs returned by the API |

**Step 1: Deploy Teacher Model Service**

Using `swift deploy` (specify `--infer_backend vllm`):

```bash
CUDA_VISIBLE_DEVICES=0 \
swift deploy \
    --model Qwen/Qwen2.5-14B-Instruct \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64 \
    --max_length 4096 \
    --vllm_max_model_len 4096
```

**Step 2: Start GKD Training**

```bash
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    --lmbda 1.0 \
    --beta 1.0 \
    ...
```

A script example is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/teacher_server.sh).

## Sampling Acceleration

In GKD training, online sampling occurs in the following scenario:

1. **Student model sampling** (when `lmbda > 0`): triggered with probability $\lambda$

Because sampling significantly slows down training, you can use the following acceleration schemes:

### Solution 1: Student Model Sampling Acceleration

Use vLLM as the inference backend to accelerate student model sampling. Two deployment modes are supported, consistent with GRPO; see the [GRPO documentation](./GRPO/GetStarted/GRPO.md#cluster-support).

> **Note**: vLLM acceleration only applies to student model on-policy sampling (`lmbda > 0`).

A training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/vllm_server.sh). For related parameters, please refer to [GRPO vLLM Parameters](./Command-line-parameters.md#vllm_mode).

### Solution 2: Teacher Model Pre-sampling

Use **pre-sampling**: first use the teacher model to generate high-quality data offline, then train on it as a dataset.
**Step 1: Generate data using teacher model**

```bash
export teacher_model='OpenGVLab/InternVL3-8B'

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
    --model $teacher_model \
    --infer_backend vllm \
    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
    --vllm_gpu_memory_utilization 0.9 \
    --vllm_max_model_len 8192 \
    --max_new_tokens 2048 \
    --write_batch_size 1000 \
    --result_path teacher_generated_data.jsonl
```

**Step 2: Train using pre-generated data**

```bash
swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model $teacher_model \
    --dataset 'teacher_generated_data.jsonl' \
    ...
```

A training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh).

## On-Policy Distillation

The [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training described in the Thinking Machines Lab blog can be reproduced by setting the following parameters:

```bash
--lmbda 1 # on-policy
--beta 1 # reverse KL
```

For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).

## OPSD (On-Policy Self-Distillation)

OPSD ([On-Policy Self-Distillation](https://arxiv.org/abs/2601.18734)) is a method that requires no separate teacher model. The key idea: the same model serves as both teacher and student, where the teacher receives **privileged information** (e.g., reference solutions) to guide student learning.
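The idea of one model playing both roles can be illustrated with a toy sketch. Everything here is hypothetical: `toy_lm` is a stand-in for a forward pass of the shared model (it simply favors tokens that appear in its prompt), and the numbers carry no meaning beyond showing that the same function yields a sharper distribution when given privileged context.

```python
import math

VOCAB = ["4", "5", "7"]  # toy vocabulary of candidate answer tokens


def toy_lm(prompt):
    """Hypothetical stand-in for one forward pass of the shared model.

    Tokens that already appear in the prompt get a higher logit, so extra
    context changes the predicted distribution while the "weights" stay fixed.
    """
    logits = [2.0 if tok in prompt else 0.0 for tok in VOCAB]
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions with q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Symmetric JSD (beta = 0.5) between two discrete distributions."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


problem = "What is 2 + 2?"
teacher_prompt = f"{problem}\n\nReference solution:\n4\n\nNow articulate your own reasoning."

p_student = toy_lm(problem)         # student: sees only the problem
p_teacher = toy_lm(teacher_prompt)  # teacher: same model, privileged context
loss = jsd(p_teacher, p_student)    # divergence the student is trained to reduce

print("student:", [round(v, 3) for v in p_student])
print("teacher:", [round(v, 3) for v in p_teacher])
print("jsd:", round(loss, 4))
```

In actual OPSD training the distributions are compared token by token along a student-generated completion rather than for a single next token.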
### Core Mechanism

- **Student**: sees only the problem and reasons normally
- **Teacher**: sees the problem + reference solution (privileged info via the `teacher_prompt` column), producing a better probability distribution
- **Training objective**: align student and teacher output distributions via JSD divergence

### Two Self-Distillation Modes

| Mode | Configuration | Teacher Weights | Description |
|------|--------------|----------------|-------------|
| **Dynamic** | No `--teacher_model` | Student's current weights | Teacher updates with training |
| **Fixed** | `--teacher_model` = same as student | Initial base weights | Teacher weights stay fixed |

### Data Format

OPSD requires a `teacher_prompt` column in the dataset to provide privileged information for the teacher. Use `--external_plugins` to load a data preprocessing plugin that constructs this column.

Example with `open-r1/OpenThoughts-114k-math`:

```python
from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset


class OpenThoughtsOPSDPreprocessor(RowPreprocessor):

    def preprocess(self, row):
        # Skip rows whose reference solution is marked as incorrect
        if not row.get('correct', True):
            return None
        problem = row.get('problem', '')
        solution = row.get('solution', '')
        # Teacher prompt: problem plus the reference solution (privileged information)
        teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
        # Student messages: problem only
        messages = [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user', 'content': problem},
        ]
        return {'messages': messages, 'teacher_prompt': teacher_prompt}


register_dataset(DatasetMeta(
    ms_dataset_id='open-r1/OpenThoughts-114k-math',
    preprocess_func=OpenThoughtsOPSDPreprocessor(),
    tags=['math', 'opsd'],
))
```

### Parameters

OPSD reuses all GKD parameters. The key difference is the `--teacher_model` configuration:

| Parameter | Dynamic Mode | Fixed Mode |
|-----------|-------------|-----------|
| `--teacher_model` | Not set | Same model as `--model` |

Full scripts are available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/opsd/).

A Megatron example is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/gkd/opsd.sh).