GKD

The GKD (Generalized Knowledge Distillation) training algorithm was proposed in the paper On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. It transfers knowledge from a teacher model to a student model by combining offline and on-policy learning strategies.

Loss Function

Given an input sequence \(x\) and output sequence \(y\), the GKD loss function can be written as:

\[ \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D(P_{\text{teacher}}(\cdot | x, y_{<t}), P_{\text{student}}(\cdot | x, y_{<t})) \]

Where:

  • \(y_{<t} = (y_1, y_2, \ldots, y_{t-1})\): sequence of the first \(t-1\) tokens

  • \(P_{\text{teacher}}(\cdot | x, y_{<t})\): output probability distribution of the teacher model given context \(x, y_{<t}\)

  • \(P_{\text{student}}(\cdot | x, y_{<t})\): output probability distribution of the student model given context \(x, y_{<t}\)

  • \(D(\cdot, \cdot)\): divergence function to measure the difference between two probability distributions
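
As a rough illustration of the sum above, the following sketch accumulates a per-token divergence over a completion. The names `gkd_sequence_loss` and `forward_kl` and the toy distributions are assumptions made for this example and are not part of the library; forward KL is used only as a placeholder for \(D\).

```python
import math
from typing import Callable, Sequence

Dist = Sequence[float]


def forward_kl(p: Dist, q: Dist) -> float:
    """KL(p || q) for strictly positive distributions over the same vocabulary."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q))


def gkd_sequence_loss(
    teacher_dists: Sequence[Dist],
    student_dists: Sequence[Dist],
    divergence: Callable[[Dist, Dist], float],
) -> float:
    """Sum D(P_teacher(.|x, y_<t), P_student(.|x, y_<t)) over positions t."""
    if len(teacher_dists) != len(student_dists):
        raise ValueError("teacher and student must cover the same positions")
    return sum(divergence(pt, ps) for pt, ps in zip(teacher_dists, student_dists))


# Toy completion of two tokens over a three-token vocabulary.
teacher = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
student = [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2]]
loss = gkd_sequence_loss(teacher, student, forward_kl)
```

The loss is zero only when the student matches the teacher at every position; any other divergence function with the same signature can be passed in place of `forward_kl`.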

Divergence Metrics

KL Divergence (Kullback-Leibler Divergence)

KL divergence is an asymmetric measure of the difference between two probability distributions \(P\) and \(Q\):

\[ \text{KL}(P \| Q) = \sum_v P(v) \log \frac{P(v)}{Q(v)} = \mathbb{E}_{v \sim P}\left[\log \frac{P(v)}{Q(v)}\right] \]

Forward KL and Reverse KL

In knowledge distillation, there are two choices depending on the order of the two distributions in the KL divergence:

Forward KL

\[ \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)} \]

Characteristics: Mode-covering

  • Expectation is computed under the teacher distribution

  • The student model tends to cover the entire teacher distribution (including low-probability regions)

Reverse KL

\[ \text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)} \]

Characteristics: Mode-seeking

  • Expectation is computed under the student distribution

  • The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model
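
The difference between the two orderings can be seen on a small example. The sketch below uses made-up distributions (a teacher with two modes and a tiny tail, one student that spreads its mass and one that commits to a single mode); the numbers are illustrative assumptions, not measurements.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes both distributions are strictly positive."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q))


# Toy teacher: two equally likely modes and a very unlikely third token.
teacher = [0.499, 0.499, 0.002]
# "Covering" student: keeps both modes but leaks mass onto the unlikely token.
student_cover = [0.35, 0.35, 0.30]
# "Seeking" student: commits to the first mode and nearly drops the second.
student_seek = [0.98, 0.018, 0.002]

# Forward KL: expectation under the teacher.
forward_cover = kl(teacher, student_cover)
forward_seek = kl(teacher, student_seek)

# Reverse KL: expectation under the student.
reverse_cover = kl(student_cover, teacher)
reverse_seek = kl(student_seek, teacher)
```

In this toy setup forward KL scores the covering student better, because dropping a teacher mode is penalized heavily, while reverse KL scores the seeking student better, because mass placed where the teacher assigns almost none is penalized heavily.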

Generalized Jensen-Shannon Divergence (Generalized JSD)

GKD uses generalized JSD as the core metric, performing smooth interpolation between Forward KL and Reverse KL through parameter \(\beta \in [0, 1]\).

For two probability distributions \(P\) and \(Q\), generalized JSD is defined as:

\[ D_{\text{JSD}(\beta)}(P, Q) = \beta \cdot \text{KL}(P \| M) + (1-\beta) \cdot \text{KL}(Q \| M) \]

Where the mixture distribution \(M\) is defined as:

\[ M = \beta \cdot P + (1-\beta) \cdot Q \]

  • When \(\beta = 0.5\), it reduces to the standard symmetric JSD

  • By adjusting \(\beta\), one can trade off between Mode-seeking and Mode-covering

In GKD, we set \(P = P_{\text{teacher}}\) and \(Q = P_{\text{student}}\), therefore:

\[ D_{\text{JSD}(\beta)}(P_{\text{teacher}}, P_{\text{student}}) = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M) \]

Where \(M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}\)

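At the endpoints (\(\beta = 0\) or \(\beta = 1\)), the mixture \(M\) coincides with one of the two distributions and the formula above evaluates to zero, so a single KL divergence is computed directly instead: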

  • When \(\beta = 0\): directly define \(D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})\) (Forward KL, Mode-covering)

  • When \(\beta = 1\): directly define \(D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})\) (Reverse KL, Mode-seeking)

  • When \(0 < \beta < 1\): use the above mixture distribution formula for interpolation

By adjusting the \(\beta\) parameter, interpolation can be performed between different divergence metrics. When \(\beta = 0.5\), the divergence is the standard symmetric JSD.
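
A minimal sketch of the generalized JSD with the endpoint handling described above, written in plain Python for readability. The function name `generalized_jsd` is an assumption for this example; an actual training implementation would operate on batched logits rather than Python lists.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes both distributions are strictly positive."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q))


def generalized_jsd(
    teacher: Sequence[float], student: Sequence[float], beta: float = 0.5
) -> float:
    """Generalized JSD with P = teacher and Q = student."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must be in [0, 1]")
    if beta == 0.0:
        # Endpoint: forward KL (mode-covering).
        return kl(teacher, student)
    if beta == 1.0:
        # Endpoint: reverse KL (mode-seeking).
        return kl(student, teacher)
    # Interior: mixture M = beta * P_teacher + (1 - beta) * P_student.
    mixture = [beta * pt + (1.0 - beta) * ps for pt, ps in zip(teacher, student)]
    return beta * kl(teacher, mixture) + (1.0 - beta) * kl(student, mixture)


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
symmetric = generalized_jsd(teacher, student, beta=0.5)
```

At \(\beta = 0.5\) swapping the two arguments leaves the value unchanged, which is the symmetry of the standard JSD.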

Two Training Modes

GKD training has two training modes, distinguished by the source of the output sequence \(y\).

Mode Selection Logic

During training, each sample selects a mode through a random draw compared against lmbda:

# Pseudocode: mode selection logic
if random() < lmbda:
    # Mode 1: On-Policy learning, output sequence sampled by student model
    y = student.generate(x)
    source = "student"
else:
    # Mode 2: Offline learning, use output sequence from dataset
    y = y_ground_truth
    source = "dataset"

# Same loss function
loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
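
The pseudocode above can be turned into a small runnable sketch. Everything here is a stand-in chosen for illustration: `select_completion` and `fake_student` are hypothetical names, and a real student would sample tokens from a language model instead of returning a fixed string.

```python
import random
from typing import Callable, Tuple


def select_completion(
    prompt: str,
    dataset_completion: str,
    student_generate: Callable[[str], str],
    lmbda: float,
    rng: random.Random,
) -> Tuple[str, str]:
    """Return (completion, source) for one sample."""
    if rng.random() < lmbda:
        # Mode 1: on-policy, the completion is produced by the student.
        return student_generate(prompt), "student"
    # Mode 2: offline, the completion comes from the dataset.
    return dataset_completion, "dataset"


def fake_student(prompt: str) -> str:
    # Stand-in for student.generate(x).
    return f"<student answer to: {prompt}>"


rng = random.Random(0)
sources = [
    select_completion("2+2=?", "4", fake_student, 0.5, rng)[1] for _ in range(1000)
]
student_fraction = sources.count("student") / len(sources)
```

With lmbda set to 0.0 the draw never succeeds and every sample uses the dataset; with 1.0 every sample is generated by the student.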

Mode 1: On-Policy Learning

Data Source: \(y \sim P_{\text{student}}(\cdot | x)\), sampled from the student model; triggered with probability \(\lambda\) (the lmbda parameter)

  • The student model learns from sequences generated by itself

  • It is exposed to the errors it might make, and learns to self-correct and recover from them

  • Aligns training distribution with inference distribution

  • Improves model robustness and practical application performance

Applicable Scenarios:

  • The student model already has certain generation capabilities

  • The goal is to improve model performance in real inference scenarios

Mode 2: Offline Learning (lmbda=0 or on-policy not triggered)

Data Source: \(y = y^* \sim \text{Dataset}\)

  • The student model learns from annotated sequences in the dataset

Parameter Settings

We can perform GKD training by setting the following parameters:

Basic Parameters

| Parameter | Type | Default | Range | Description |
|---|---|---|---|---|
| --teacher_model | str | None | - | Teacher model path or model ID. Can be omitted when using teacher_model_server |
| --beta | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient. 0.0: Forward KL; 0.5: JSD (balanced); 1.0: Reverse KL |
| --lmbda | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability. 0.0: Pure Offline; 0.5: Mixed strategy (recommended); 1.0: Pure On-Policy |
| --temperature | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
| --sft_alpha | float | 0 | >= 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| --max_completion_length | int | 512 | > 0 | Maximum number of tokens during generation |
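
The ranges in the table above can be checked before launching a run. The `GKDParams` dataclass below is a hypothetical helper written for this example, not a class provided by the library; its defaults simply mirror the table.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GKDParams:
    """Hypothetical container mirroring the basic parameters table."""

    teacher_model: Optional[str] = None
    beta: float = 0.5
    lmbda: float = 0.5
    temperature: float = 0.9
    sft_alpha: float = 0.0
    max_completion_length: int = 512

    def validate(self) -> None:
        """Raise ValueError if any value falls outside its documented range."""
        if not 0.0 <= self.beta <= 1.0:
            raise ValueError("beta must be in [0.0, 1.0]")
        if not 0.0 <= self.lmbda <= 1.0:
            raise ValueError("lmbda must be in [0.0, 1.0]")
        if self.temperature <= 0:
            raise ValueError("temperature must be > 0")
        if self.sft_alpha < 0:
            raise ValueError("sft_alpha must be >= 0")
        if self.max_completion_length <= 0:
            raise ValueError("max_completion_length must be > 0")


# Pure on-policy training with reverse KL.
params = GKDParams(beta=1.0, lmbda=1.0)
params.validate()
```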

Top-K KL Computation

By default, GKD computes KL divergence over the full vocabulary. For models with large vocabularies, you can use Top-K mode to reduce memory usage and computation.

| Parameter | Type | Default | Description |
|---|---|---|---|
| --gkd_logits_topk | int | None | Number of Top-K logits. None: Use full vocabulary (default); Positive integer: Only use the K tokens with highest teacher probability for KL computation |

Top-K Mode Principle:

In Top-K mode, the top-K token indices are selected according to the teacher model's distribution. These indices are used to gather logits from both models, and each model's logits are renormalized over the top-K subset before the JSD is computed:

\[ D_{\text{JSD}(\beta)}^{\text{top-k}}(P_T, P_S) = \beta \cdot \text{KL}(\tilde{P}_T \| \tilde{M}) + (1-\beta) \cdot \text{KL}(\tilde{P}_S \| \tilde{M}) \]

Where the Top-K indices come from the teacher model: \(\text{Top-K} = \text{argtop}_K(P_T)\), and \(\tilde{P}_T\) and \(\tilde{P}_S\) are the probability distributions renormalized over the Top-K subset:

\[ \tilde{P}_T(v) = \frac{P_T(v)}{\sum_{v' \in \text{Top-K}} P_T(v')}, \quad \tilde{P}_S(v) = \frac{P_S(v)}{\sum_{v' \in \text{Top-K}} P_S(v')}, \quad v \in \text{Top-K} \]
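
The selection and renormalization steps can be sketched on plain probability lists. The function name `topk_renormalize` is an assumption for this example. Dividing probabilities by their top-K mass, as done here, gives the same result as applying a softmax to the gathered logits, because the full-vocabulary normalizer cancels.

```python
from typing import List, Sequence, Tuple


def topk_renormalize(
    teacher: Sequence[float], student: Sequence[float], k: int
) -> Tuple[List[int], List[float], List[float]]:
    """Select the teacher's top-k indices and renormalize both distributions."""
    if not 0 < k <= len(teacher):
        raise ValueError("k must be in [1, vocabulary size]")
    # Indices of the k most probable tokens under the teacher.
    indices = sorted(range(len(teacher)), key=lambda v: teacher[v], reverse=True)[:k]
    teacher_mass = sum(teacher[v] for v in indices)
    student_mass = sum(student[v] for v in indices)
    teacher_tilde = [teacher[v] / teacher_mass for v in indices]
    student_tilde = [student[v] / student_mass for v in indices]
    return indices, teacher_tilde, student_tilde


teacher = [0.5, 0.3, 0.15, 0.05]
student = [0.4, 0.2, 0.3, 0.1]
indices, teacher_tilde, student_tilde = topk_renormalize(teacher, student, k=2)
```

Note that the student's third token, which it considers more likely than the second, is ignored because the indices are chosen by the teacher alone.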

Usage Example:

swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B-Instruct \
    --teacher_model Qwen/Qwen2.5-14B-Instruct \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    ...

Note: Top-K mode cannot be used with liger kernel (--use_liger_kernel).

External Teacher Model API

When gkd_logits_topk is set, you can use an external teacher model API service to fetch logprobs, which avoids loading the teacher model in the training process.

| Parameter | Type | Default | Description |
|---|---|---|---|
| --teacher_model_server | str | None | Teacher model service URL, e.g., http://localhost:8000 |
| --gkd_logits_topk | int | Required | Must be set when using external API; corresponds to the top_logprobs returned by the API |
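
To make the link between returned logprobs and the Top-K distribution concrete, the sketch below converts one position's top logprob entries into renormalized probabilities. The entry shape (a list of dicts with `token` and `logprob` keys) follows the OpenAI-style convention and is an assumption here; the exact response format used between the trainer and the service is not specified in this document, and `renormalize_top_logprobs` is a hypothetical helper.

```python
import math
from typing import Dict, List


def renormalize_top_logprobs(entries: List[dict]) -> Dict[str, float]:
    """Map [{'token': str, 'logprob': float}, ...] to probabilities summing to 1."""
    if not entries:
        raise ValueError("at least one entry is required")
    # Subtract the maximum logprob before exponentiating for numerical stability.
    max_logprob = max(entry["logprob"] for entry in entries)
    weights = {
        entry["token"]: math.exp(entry["logprob"] - max_logprob) for entry in entries
    }
    total = sum(weights.values())
    return {token: weight / total for token, weight in weights.items()}


# Assumed example payload for a single token position.
entries = [
    {"token": "4", "logprob": math.log(0.6)},
    {"token": "four", "logprob": math.log(0.3)},
]
teacher_topk = renormalize_top_logprobs(entries)
```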

Step 1: Deploy Teacher Model Service

Using swift deploy (specify --infer_backend vllm):

CUDA_VISIBLE_DEVICES=0 \
swift deploy \
    --model Qwen/Qwen2.5-14B-Instruct \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64 \
    --max_length 4096 \
    --vllm_max_model_len 4096

Step 2: Start GKD Training

swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen2.5-7B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --dataset your_dataset \
    --lmbda 1.0 \
    --beta 1.0 \
    ...

Script example available here

Sampling Acceleration

In GKD training, online sampling occurs in one scenario: student model sampling (when lmbda > 0), triggered with probability \(\lambda\).

Because sampling significantly slows down training, consider the following acceleration schemes:

Solution 1: Student Model Sampling Acceleration

Use vLLM as the inference backend to accelerate student model sampling. Two deployment modes are supported, consistent with GRPO; refer to the GRPO documentation.

Note: vLLM acceleration only applies to student model on-policy sampling (lmbda > 0).

Training script reference here. For related parameters, refer to GRPO vLLM Parameters.

Solution 2: Teacher Model Pre-sampling

Use pre-sampling: first generate high-quality data offline with the teacher model, then train on it as a dataset.

Step 1: Generate data using teacher model

export teacher_model='OpenGVLab/InternVL3-8B'

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
    --model $teacher_model \
    --infer_backend vllm \
    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
    --vllm_gpu_memory_utilization 0.9 \
    --vllm_max_model_len 8192 \
    --max_new_tokens 2048 \
    --write_batch_size 1000 \
    --result_path teacher_generated_data.jsonl

Step 2: Train using pre-generated data

swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model $teacher_model \
    --dataset 'teacher_generated_data.jsonl' \
    ...

Training script reference here

On-Policy Distillation

We can achieve the On-Policy Distillation training described in the Thinking Machines Lab blog by setting the following parameters:

--lmbda 1 # on-policy
--beta 1 # reverse

For a complete implementation, refer to the example script here.

OPSD (On-Policy Self-Distillation)

OPSD (On-Policy Self-Distillation) is a method that requires no separate teacher model. The key idea: the same model serves as both teacher and student, and the teacher receives privileged information (e.g., reference solutions) to guide student learning.

Core Mechanism

  • Student: sees only the problem and reasons normally

  • Teacher: sees the problem + reference solution (privileged info via teacher_prompt column), producing a better probability distribution

  • Training objective: align student and teacher output distributions via JSD
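
The mechanism can be sketched with a single callable playing both roles. `toy_model` is a hypothetical stand-in for one set of model weights, and the contexts and probabilities are invented for illustration; the symmetric JSD (\(\beta = 0.5\)) is used as the divergence.

```python
import math
from typing import Callable, List, Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes both distributions are strictly positive."""
    return sum(pv * math.log(pv / qv) for pv, qv in zip(p, q))


def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Symmetric JSD, i.e. the generalized JSD at beta = 0.5."""
    mixture = [0.5 * (pv + qv) for pv, qv in zip(p, q)]
    return 0.5 * kl(p, mixture) + 0.5 * kl(q, mixture)


def toy_model(context: str) -> List[float]:
    # Hypothetical model: more confident when the context holds a reference solution.
    if "Reference solution" in context:
        return [0.9, 0.05, 0.05]
    return [0.5, 0.3, 0.2]


def opsd_token_divergence(
    model: Callable[[str], List[float]], student_context: str, teacher_context: str
) -> float:
    """Same weights, two contexts: the teacher sees privileged information."""
    return jsd(model(teacher_context), model(student_context))


problem = "What is 2 + 2?"
teacher_context = f"{problem}\n\nReference solution:\n4"
divergence = opsd_token_divergence(toy_model, problem, teacher_context)
```

When both roles see the same context the divergence is zero, so the training signal comes entirely from the privileged information given to the teacher.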

Two Self-Distillation Modes

| Mode | Configuration | Teacher Weights | Description |
|---|---|---|---|
| Dynamic | No --teacher_model | Student's current weights | Teacher updates with training |
| Fixed | --teacher_model = same as student | Initial base weights | Fixed teacher weight |

Data Format

OPSD requires a teacher_prompt column in the dataset to provide privileged information for the teacher. Use --external_plugins to load a data preprocessing plugin that constructs this column.

Example with open-r1/OpenThoughts-114k-math:

from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset

class OpenThoughtsOPSDPreprocessor(RowPreprocessor):
    def preprocess(self, row):
        # Rows explicitly marked as incorrect return None; a missing 'correct' field defaults to True
        if not row.get('correct', True):
            return None
        problem = row.get('problem', '')
        solution = row.get('solution', '')
        # Teacher context: the problem plus the reference solution (privileged information)
        teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
        # Student context: the problem only
        messages = [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user', 'content': problem},
        ]
        return {'messages': messages, 'teacher_prompt': teacher_prompt}

register_dataset(DatasetMeta(
    ms_dataset_id='open-r1/OpenThoughts-114k-math',
    preprocess_func=OpenThoughtsOPSDPreprocessor(),
    tags=['math', 'opsd'],
))

Parameters

OPSD reuses all GKD parameters. The key difference is --teacher_model configuration:

| Parameter | Dynamic Mode | Fixed Mode |
|---|---|---|
| --teacher_model | Not set | Same model as --model |

Full scripts available here

Megatron available here