GKD
The GKD (Generalized Knowledge Distillation) training algorithm was proposed in the paper On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. It transfers knowledge from a teacher model to a student model by combining offline and on-policy learning strategies.
Loss Function
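Given an input sequence \(x\) and output sequence \(y\), the GKD loss function can be written as:
\[
\mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D\left(P_{\text{teacher}}(\cdot | x, y_{<t}),\ P_{\text{student}}(\cdot | x, y_{<t})\right)
\]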
Where:
\(y_{<t} = (y_1, y_2, \ldots, y_{t-1})\): sequence of the first \(t-1\) tokens
\(P_{\text{teacher}}(\cdot | x, y_{<t})\): output probability distribution of the teacher model given context \(x, y_{<t}\)
\(P_{\text{student}}(\cdot | x, y_{<t})\): output probability distribution of the student model given context \(x, y_{<t}\)
\(D(\cdot, \cdot)\): divergence function to measure the difference between two probability distributions
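A minimal pure-Python sketch of the per-sequence loss above, assuming the teacher and student next-token distributions for every position of \(y\) have already been computed. The helper names `kl` and `gkd_sequence_loss` are hypothetical, and forward KL is plugged in as \(D\) purely for illustration; GKD's actual choice of \(D\) is the generalized JSD described below.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p(v) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def gkd_sequence_loss(teacher_dists, student_dists, divergence):
    """Sum D(P_teacher(.|x, y_<t), P_student(.|x, y_<t)) over positions t of y.

    teacher_dists[t] and student_dists[t] are next-token distributions that are
    already conditioned on x and y_<t (hypothetical precomputed inputs).
    """
    if len(teacher_dists) != len(student_dists):
        raise ValueError("teacher and student must cover the same positions")
    return sum(divergence(p_t, q_t) for p_t, q_t in zip(teacher_dists, student_dists))


# Toy example: a 2-token output sequence over a 3-token vocabulary.
teacher = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
student = [[0.6, 0.3, 0.1], [0.2, 0.6, 0.2]]
loss = gkd_sequence_loss(teacher, student, divergence=kl)
print(loss)
```

Swapping `divergence=kl` for a generalized JSD function changes only the per-token term; the sum over positions stays the same.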
Divergence Metrics
KL Divergence (Kullback-Leibler Divergence)
KL divergence is an asymmetric measure of the difference between two probability distributions \(P\) and \(Q\):
\[
\text{KL}(P \| Q) = \sum_{v} P(v) \log \frac{P(v)}{Q(v)}
\]
Forward KL and Reverse KL
In knowledge distillation, there are two choices depending on the order of the two distributions in the KL divergence:
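Forward KL
\[
\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_{v} P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
\]
Characteristics: Mode-covering
Expectation is computed under the teacher distribution
The student model tends to cover the entire teacher distribution (including low-probability regions)
Reverse KL
\[
\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_{v} P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
\]
Characteristics: Mode-seeking
Expectation is computed under the student distribution
The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model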
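The asymmetry is easiest to see on a toy example. The sketch below, with made-up numbers, pairs a bimodal teacher with a student that has collapsed onto one of the teacher's two modes, and evaluates both KL directions with a small hypothetical `kl` helper.

```python
import math


def kl(p, q):
    """KL(p || q) = sum_v p(v) * log(p(v) / q(v)); terms with p(v) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Bimodal teacher: two plausible tokens with equal mass.
teacher = [0.49, 0.49, 0.02]
# Student that has collapsed onto only the first teacher mode.
student = [0.96, 0.02, 0.02]

forward_kl = kl(teacher, student)  # expectation under the teacher
reverse_kl = kl(student, teacher)  # expectation under the student

print(f"forward KL = {forward_kl:.4f}, reverse KL = {reverse_kl:.4f}")
```

Forward KL comes out roughly twice as large here. The teacher puts 0.49 of its mass on a token the student almost ignores, and forward KL penalizes that heavily (mode-covering pressure). Reverse KL weighs errors only where the student itself places mass, so a student that commits to a single teacher mode is penalized less (mode-seeking).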
Generalized Jensen-Shannon Divergence (Generalized JSD)
GKD uses generalized JSD as the core metric, performing smooth interpolation between Forward KL and Reverse KL through parameter \(\beta \in [0, 1]\).
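For two probability distributions \(P\) and \(Q\), generalized JSD is defined as:
\[
D_{\text{JSD}(\beta)}(P, Q) = \beta \cdot \text{KL}(P \| M) + (1-\beta) \cdot \text{KL}(Q \| M)
\]
Where the mixture distribution \(M\) is defined as:
\[
M = \beta \cdot P + (1-\beta) \cdot Q
\]
When \(\beta = 0.5\), it reduces to the standard symmetric JSD
By adjusting \(\beta\), one can trade off between Mode-seeking and Mode-covering
In GKD, we set \(P = P_{\text{teacher}}\) and \(Q = P_{\text{student}}\), therefore:
\[
D_{\text{JSD}(\beta)}(P_{\text{teacher}}, P_{\text{student}}) = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M)
\]
Where \(M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}\)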
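For the extreme cases (\(\beta = 0\) or \(\beta = 1\)), the mixture formula degenerates to zero (at \(\beta = 0\), \(M = P_{\text{student}}\) and both terms vanish; at \(\beta = 1\), \(M = P_{\text{teacher}}\) and both terms vanish), so a single KL divergence is computed directly: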
When \(\beta = 0\): directly define \(D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})\) (Forward KL, Mode-covering)
When \(\beta = 1\): directly define \(D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})\) (Reverse KL, Mode-seeking)
When \(0 < \beta < 1\): use the above mixture distribution formula for interpolation
By adjusting the \(\beta\) parameter, interpolation can be performed between different divergence metrics. When \(\beta = 0.5\), the divergence is the standard symmetric JSD.
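The sketch below implements the generalized JSD with the endpoint handling described above, in plain Python over probability lists. The function names are hypothetical. A real implementation would more likely work on log-probabilities computed from logits for numerical stability; this version favors readability.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p(v) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(p_teacher, p_student, beta):
    """Generalized JSD between teacher and student distributions.

    beta = 0 -> forward KL(P_teacher || P_student)
    beta = 1 -> reverse KL(P_student || P_teacher)
    otherwise -> beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M),
                 with M = beta * P_teacher + (1 - beta) * P_student
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must be in [0, 1]")
    if beta == 0.0:
        return kl(p_teacher, p_student)
    if beta == 1.0:
        return kl(p_student, p_teacher)
    m = [beta * t + (1 - beta) * s for t, s in zip(p_teacher, p_student)]
    return beta * kl(p_teacher, m) + (1 - beta) * kl(p_student, m)


teacher = [0.49, 0.49, 0.02]
student = [0.96, 0.02, 0.02]
for beta in (0.0, 0.5, 1.0):
    print(beta, generalized_jsd(teacher, student, beta))
```

For a small \(\beta\), `generalized_jsd(teacher, student, beta) / beta` approaches the forward KL. The mixture form therefore behaves like forward KL near \(\beta = 0\), but its raw value shrinks toward zero. This is why the endpoints are special-cased instead of being evaluated through the mixture formula.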
Two Training Modes
GKD training has two training modes, distinguished by the source of the output sequence \(y\).
Mode Selection Logic
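During training, the mode for each sample is chosen at random as follows:

```python
import random


# Mode selection logic for one training sample (a sketch: student.generate,
# y_ground_truth and D_JSD are placeholders, not a real API)
def select_output_sequence(x, y_ground_truth, student, lmbda):
    if random.random() < lmbda:
        # Mode 1: On-Policy learning, output sequence sampled by the student model
        y = student.generate(x)
        source = "student"
    else:
        # Mode 2: Offline learning, use the output sequence from the dataset
        y = y_ground_truth
        source = "dataset"
    return y, source

# Both modes use the same loss function:
# loss = D_JSD(P_teacher(·|x, y), P_student(·|x, y))
```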
Mode 1: On-Policy Learning
Data Source: triggered with probability \(\lambda\) (set via --lmbda); the output sequence is sampled by the student model, \(y \sim P_{\text{student}}(\cdot | x)\)
The student model learns from sequences it generates itself
It is exposed to the mistakes it actually makes, and learns to self-correct and recover from them
Aligns the training distribution with the inference distribution
Improves model robustness and performance in practical applications
Applicable Scenarios:
The student model already has some generation capability
Want to improve model performance in real inference scenarios
Mode 2: Offline Learning (lmbda=0 or on-policy not triggered)
Data Source: \(y = y^* \sim \text{Dataset}\)
The student model learns from annotated sequences in the dataset
Parameter Settings
We can perform GKD training by setting the following parameters:
Basic Parameters
| Parameter | Type | Default | Range | Description |
|---|---|---|---|---|
| --teacher_model | str | None | - | Teacher model path or model ID. Can be omitted when using --teacher_model_server |
| --beta | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient: 0.0 = Forward KL; 0.5 = JSD (balanced); 1.0 = Reverse KL |
| --lmbda | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability: 0.0 = pure Offline; 0.5 = mixed strategy (recommended); 1.0 = pure On-Policy |
| --temperature | float | 0.9 | > 0 | Generation sampling temperature; controls randomness |
| --sft_alpha | float | 0 | >= 0 | Proportion of SFT loss to mix in; applied to completions not generated by the student |
| --max_completion_length | int | 512 | > 0 | Maximum number of tokens generated |
Top-K KL Computation
By default, GKD computes KL divergence over the full vocabulary. For models with large vocabularies, you can use Top-K mode to reduce memory usage and computation.
| Parameter | Type | Default | Description |
|---|---|---|---|
| --gkd_logits_topk | int | None | Number of Top-K logits. None: use the full vocabulary (default); positive integer: use only the K tokens with the highest teacher probability for KL computation |
Top-K Mode Principle:
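In Top-K mode, the top-K token indices are selected from the teacher model, and the divergence is computed on both models' logits at these positions. The teacher's top-K indices are used to gather logits from both models, which are then renormalized over the top-K subset before computing JSD (writing \(P_T = P_{\text{teacher}}\) and \(P_S = P_{\text{student}}\)):
\[
D_{\text{JSD}(\beta)}^{\text{Top-K}} = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}), \quad \tilde{M} = \beta \cdot \tilde{P}_T + (1-\beta) \cdot \tilde{P}_S
\]
Where the Top-K indices come from the teacher model: \(\text{Top-K} = \text{argtop}_K(P_T)\), and \(\tilde{P}_T\) and \(\tilde{P}_S\) are the probability distributions renormalized over the Top-K subset:
\[
\tilde{P}_T(i) = \frac{P_T(i)}{\sum_{j \in \text{Top-K}} P_T(j)}, \quad \tilde{P}_S(i) = \frac{P_S(i)}{\sum_{j \in \text{Top-K}} P_S(j)}, \quad i \in \text{Top-K}
\]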
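The Top-K procedure can be sketched in plain Python on a single position's logits. The helpers `softmax`, `kl` and `topk_jsd` are hypothetical illustrations, not ms-swift functions. Applying softmax to the gathered subset of logits gives exactly the renormalized distribution \(\tilde{P}\) defined above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    z_max = max(logits)
    exps = [math.exp(z - z_max) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p(v) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def topk_jsd(teacher_logits, student_logits, k, beta=0.5):
    """Generalized JSD restricted to the teacher's top-K tokens (one position)."""
    # argtop_K(P_T): vocabulary ids sorted by teacher logit (same order as P_T).
    topk_idx = sorted(range(len(teacher_logits)),
                      key=lambda i: teacher_logits[i], reverse=True)[:k]
    # Gather both models' logits at the teacher's ids; softmax renormalizes over the subset.
    p_t = softmax([teacher_logits[i] for i in topk_idx])
    p_s = softmax([student_logits[i] for i in topk_idx])
    if beta == 0.0:
        return kl(p_t, p_s)
    if beta == 1.0:
        return kl(p_s, p_t)
    m = [beta * t + (1 - beta) * s for t, s in zip(p_t, p_s)]
    return beta * kl(p_t, m) + (1 - beta) * kl(p_s, m)


teacher_logits = [2.0, 1.0, 0.5, -1.0, -3.0]
student_logits = [1.5, 1.2, 0.0, -0.5, -2.0]
loss_top2 = topk_jsd(teacher_logits, student_logits, k=2)
loss_full = topk_jsd(teacher_logits, student_logits, k=len(teacher_logits))
print(loss_top2, loss_full)
```

With `k` equal to the vocabulary size, the result matches the full-vocabulary JSD. The same renormalization also applies to log-probabilities returned by a teacher API, because log-probabilities differ from logits only by a per-position constant, which softmax cancels.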
Usage Example:
```shell
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B-Instruct \
    --teacher_model Qwen/Qwen2.5-14B-Instruct \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    ...
```
Note: Top-K mode cannot be used with the liger kernel (--use_liger_kernel).
External Teacher Model API
When gkd_logits_topk is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process.
| Parameter | Type | Default | Description |
|---|---|---|---|
| --teacher_model_server | str | None | Teacher model service URL, e.g., http://localhost:8000 |
| --gkd_logits_topk | int | Required | Must be set when using the external API; corresponds to the top_logprobs returned by the API |
Step 1: Deploy Teacher Model Service
Using swift deploy (specify --infer_backend vllm, and set --max_logprobs to at least the --gkd_logits_topk value used in training; both are 64 here):
```shell
CUDA_VISIBLE_DEVICES=0 \
swift deploy \
    --model Qwen/Qwen2.5-14B-Instruct \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64 \
    --max_length 4096 \
    --vllm_max_model_len 4096
```
Step 2: Start GKD Training
```shell
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    --lmbda 1.0 \
    --beta 1.0 \
    ...
```
Script example available here
Sampling Acceleration
In GKD training, online sampling occurs in the following scenario:
Student model sampling (when lmbda > 0): triggered with probability \(\lambda\)
Because sampling significantly slows down training, the following acceleration schemes can be used:
Solution 1: Student Model Sampling Acceleration
Use vLLM as the inference backend to accelerate student model sampling. Two deployment modes are supported, consistent with GRPO; refer to the GRPO documentation.
Note: vLLM acceleration only applies to student model on-policy sampling (lmbda > 0).
A reference training script is available here; for the related parameters, see GRPO vLLM Parameters.
Solution 2: Teacher Model Pre-sampling
Use pre-sampling: first use the teacher model to generate high-quality data offline, then train on it as a regular dataset.
Step 1: Generate data using teacher model
```shell
export teacher_model='OpenGVLab/InternVL3-8B'

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
    --model $teacher_model \
    --infer_backend vllm \
    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
    --vllm_gpu_memory_utilization 0.9 \
    --vllm_max_model_len 8192 \
    --max_new_tokens 2048 \
    --write_batch_size 1000 \
    --result_path teacher_generated_data.jsonl
```
Step 2: Train using pre-generated data
```shell
swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model $teacher_model \
    --dataset 'teacher_generated_data.jsonl' \
    ...
```
Training script reference here
On-Policy Distillation
We can achieve the On-Policy Distillation training described in the Thinking Machines Lab blog by setting the following parameters:
```shell
--lmbda 1  # on-policy
--beta 1   # reverse KL
```
For a complete implementation, refer to the example script here.
OPSD (On-Policy Self-Distillation)
OPSD (On-Policy Self-Distillation) is a method that requires no separate teacher model. The key idea: the same model serves as both teacher and student, and the teacher receives privileged information (e.g., reference solutions) to guide the student's learning.
Core Mechanism
Student: sees only the problem and reasons normally
Teacher: sees the problem plus the reference solution (privileged information supplied via the teacher_prompt column), producing a better-informed probability distribution
Training objective: align student and teacher output distributions via JSD divergence
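In OPSD the teacher scores the same student-generated completion as the student, but behind a longer prompt, because the teacher_prompt adds the reference solution. In a causal LM the logit at position i predicts token i + 1, so the logits aligned with the completion start at a different offset in each sequence. The sketch below, using made-up token ids and a hypothetical helper `completion_logit_slice`, shows one way to line them up. It illustrates the idea and is not taken from the ms-swift implementation.

```python
def completion_logit_slice(prompt_len, completion_len):
    """Positions whose logits predict the completion tokens in a causal LM.

    The logit at position i predicts token i + 1, so completion tokens at
    positions prompt_len .. prompt_len + completion_len - 1 are predicted by
    the logits at positions prompt_len - 1 .. prompt_len + completion_len - 2.
    """
    start = prompt_len - 1
    return slice(start, start + completion_len)


# Hypothetical token ids: the teacher prompt carries extra privileged tokens.
student_prompt = [11, 12, 13]                  # problem only
teacher_prompt = [11, 12, 13, 90, 91, 92, 93]  # problem + reference solution
completion = [21, 22, 23, 24]                  # sampled once by the student, shared by both

student_input = student_prompt + completion
teacher_input = teacher_prompt + completion

s_slice = completion_logit_slice(len(student_prompt), len(completion))
t_slice = completion_logit_slice(len(teacher_prompt), len(completion))
print(s_slice, t_slice)
```

Given per-position logits from each forward pass, `student_logits[s_slice]` and `teacher_logits[t_slice]` then line up token by token over the completion, and the JSD is computed between matching rows.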
Two Self-Distillation Modes
| Mode | Configuration | Teacher Weights | Description |
|---|---|---|---|
| Dynamic | No --teacher_model | Student's current weights | Teacher updates with training |
| Fixed | --teacher_model = same as student | Initial base weights | Teacher weights stay fixed |
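The two modes differ only in which weights the teacher uses. The toy sketch below assumes a hypothetical `ToyModel` whose "weights" are a plain dict, and uses `copy.deepcopy` to stand in for loading the same checkpoint as a separate teacher. The dynamic teacher shares the student's weights and drifts with training, while the fixed teacher keeps the initial weights.

```python
import copy


class ToyModel:
    """Stand-in for a model: 'weights' is just a dict of floats (hypothetical)."""

    def __init__(self, weights):
        self.weights = dict(weights)

    def apply_update(self, grad, lr=0.1):
        # One plain gradient-descent step on every weight named in grad.
        for name, g in grad.items():
            self.weights[name] -= lr * g


student = ToyModel({"w": 1.0})

# Dynamic mode (no --teacher_model): the teacher is the student itself,
# so it always sees the student's current weights.
dynamic_teacher = student

# Fixed mode (--teacher_model = same as --model): a separate copy of the
# initial weights that is never updated.
fixed_teacher = copy.deepcopy(student)

for _ in range(3):
    student.apply_update({"w": 0.5})

print(student.weights, dynamic_teacher.weights, fixed_teacher.weights)
```

A fixed teacher gives a stable target throughout training. A dynamic teacher needs no extra model copy in memory, but its target moves as the student improves.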
Data Format
OPSD requires a teacher_prompt column in the dataset to provide privileged information for the teacher. Use --external_plugins to load a data preprocessing plugin that constructs this column.
Example with open-r1/OpenThoughts-114k-math:
```python
from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset


class OpenThoughtsOPSDPreprocessor(RowPreprocessor):

    def preprocess(self, row):
        # Skip rows marked as incorrect (returning None drops the row)
        if not row.get('correct', True):
            return None
        problem = row.get('problem', '')
        solution = row.get('solution', '')
        # Teacher input: the problem plus the reference solution (privileged information)
        teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
        # Student input: only the problem
        messages = [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user', 'content': problem},
        ]
        return {'messages': messages, 'teacher_prompt': teacher_prompt}


register_dataset(DatasetMeta(
    ms_dataset_id='open-r1/OpenThoughts-114k-math',
    preprocess_func=OpenThoughtsOPSDPreprocessor(),
    tags=['math', 'opsd'],
))
```
Parameters
OPSD reuses all GKD parameters. The key difference is the --teacher_model configuration:
| Parameter | Dynamic Mode | Fixed Mode |
|---|---|---|
| --teacher_model | Not set | Same model as --model |
Full scripts are available here
A Megatron example is available here