Knowledge Distillation

Knowledge distillation is a training method that transfers capabilities from a teacher model to a student model. The core idea is to have the student align with the teacher’s output distribution at each token position, yielding richer supervision than simply imitating labeled answers—the teacher tells the student not only which token is correct, but also how good or bad other tokens are.

This document proceeds top-down: why distillation works (Section 1), a unified design framework for distillation methods (Section 2), and three concrete distillation training methods in swift, GKD / OPD-RL / OPSD (Section 3).


1. Why Distillation: From Sparse to Dense Signals

A language model’s capabilities are typically built through a stack of training stages:

  • Pre-training: Acquire language, world knowledge, basic reasoning, and other general capabilities.

  • Mid-training: Inject domain knowledge, such as code, medicine, or internal company documents.

  • Post-training: Elicit target behaviors, such as instruction following, mathematical reasoning, or dialogue style.

Distillation mainly happens during post-training. To understand its value, consider post-training methods along two independent dimensions:

  1. Sampling mode (where data comes from): Whether training sequences are generated by the student itself (on-policy) or come from external fixed data (off-policy).

  2. Feedback density (how much each sequence teaches): Whether the entire sequence has a single reward (per-sequence, sparse) or each token has a signal (per-token, dense).

SFT / offline distillation (off-policy + dense): Align with the labels or the teacher distribution on fixed data. The signal is dense, but training only sees states produced by the teacher or the labels, and these differ from the states the student visits during its own inference. Once the student makes an early mistake the teacher would not make, it enters states never seen in training, and errors accumulate. This is called exposure bias.

RL (on-policy + sparse): The student samples trajectories and receives rewards based on final outcomes. The distribution matches student inference, but rewards are typically sequence-level scalars that do not pinpoint which token went wrong.

On-policy distillation (on-policy + dense): The student samples trajectories, and the teacher scores every token on those trajectories. Training distribution matches student inference, and feedback is per-token.

SFT as a Special Case of Distillation

A natural entry point to understanding distillation is the SFT loss. SFT’s cross-entropy loss is equivalent to KL divergence against a one-hot “teacher” distribution \(\delta_{y^*}\) at the labeled token:

\[-\log P_S(y^*_t) = \text{KL}(\delta_{y^*} \,\|\, P_S)\]

Knowledge distillation simply replaces this deterministic one-hot “teacher” with the teacher model’s soft distribution \(P_T\), optimizing \(\text{KL}(P_T \,\|\, P_S)\) at each token to provide richer supervision than one-hot.
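A quick numerical check of this identity on a toy 4-token vocabulary. This is a minimal NumPy sketch; the probabilities are made up for illustration.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions; entries with p == 0 contribute 0."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))))

# Student next-token distribution at one position (toy values).
p_student = np.array([0.6, 0.2, 0.15, 0.05])
label = 0                                   # index of the labeled token y*
one_hot = np.eye(4)[label]                  # delta_{y*}

sft_loss = float(-np.log(p_student[label]))  # cross-entropy with the hard label
kd_one_hot = kl(one_hot, p_student)          # KL(delta_{y*} || P_S): same value

# A soft teacher also says how plausible the alternative tokens are.
p_teacher = np.array([0.5, 0.3, 0.15, 0.05])
kd_soft = kl(p_teacher, p_student)           # KL(P_T || P_S)

print(sft_loss, kd_one_hot, kd_soft)
```

The one-hot case has zero entropy, so the KL reduces to exactly the cross-entropy term; the soft teacher additionally penalizes the student for under-weighting token 1.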

| Method | Sampling mode | Feedback density | Teacher distribution |
|---|---|---|---|
| SFT | off-policy (fixed data) | dense (per-token) | one-hot $\delta_{y^*}$ |
| Offline (off-policy) distillation | off-policy (fixed or teacher-generated data) | dense (per-token) | teacher soft distribution |
| RL | on-policy (student sampling) | sparse (per-sequence) | none |
| On-policy distillation | on-policy (student sampling) | dense (per-token) | teacher soft distribution |

2. Two Core Choices in Distillation

Differences between distillation methods almost always boil down to two questions: how the teacher signal is computed (2.1) and how it is passed to the student (2.2). Once these two choices are clear, the methods below are just different combinations of them.

2.1 How to Compute the Teacher Signal

At each token position, quantify the difference between teacher distribution \(P_T\) and student distribution \(P_S\) (we call this Teacher KL). There are two sub-choices.

(a) Divergence direction

| Divergence | Definition | Optimization behavior (information-theoretic meaning) |
|---|---|---|
| Forward KL | $\text{KL}(P_T \,\Vert\, P_S)$ | Mode-covering: the student must assign enough probability to regions where the teacher has high probability |
| Reverse KL | $\text{KL}(P_S \,\Vert\, P_T)$ | Mode-seeking: the student mainly fits the teacher's mode (high-probability) regions |
| Generalized JSD($\beta$) | $\beta\,\text{KL}(P_T \Vert M) + (1-\beta)\,\text{KL}(P_S \Vert M)$, where $M=\beta P_T+(1-\beta)P_S$ | Interpolates between the two |

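\(\beta=0\) corresponds to Forward KL and \(\beta=1\) to Reverse KL. These endpoints are limits: substituting \(\beta=0\) or \(\beta=1\) literally into the formula gives 0, because \(M\) collapses onto \(P_S\) or \(P_T\), but \(\text{JSD}(\beta)/\beta \to \text{KL}(P_T \,\|\, P_S)\) as \(\beta\to 0\) and \(\text{JSD}(\beta)/(1-\beta) \to \text{KL}(P_S \,\|\, P_T)\) as \(\beta\to 1\). SFT is equivalent to Forward KL with a one-hot teacher.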
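A minimal NumPy sketch of the three divergences in the table, on arbitrary toy distributions. It also checks the endpoint behavior of the JSD(β) formula by comparing the rescaled value near β=0 and β=1 with Forward and Reverse KL. This illustrates the math only; it is not swift's implementation.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for strictly positive discrete distributions."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def generalized_jsd(p_t, p_s, beta):
    """beta*KL(P_T||M) + (1-beta)*KL(P_S||M) with M = beta*P_T + (1-beta)*P_S."""
    m = beta * p_t + (1.0 - beta) * p_s
    return beta * kl(p_t, m) + (1.0 - beta) * kl(p_s, m)

p_t = np.array([0.7, 0.2, 0.1])   # toy teacher distribution
p_s = np.array([0.4, 0.4, 0.2])   # toy student distribution

forward = kl(p_t, p_s)            # Forward KL(P_T || P_S): mode-covering
reverse = kl(p_s, p_t)            # Reverse KL(P_S || P_T): mode-seeking
jsd_mid = generalized_jsd(p_t, p_s, 0.5)

# Literal endpoints: M collapses onto P_S (beta=0) or P_T (beta=1), so the formula gives 0.
at_0 = generalized_jsd(p_t, p_s, 0.0)
at_1 = generalized_jsd(p_t, p_s, 1.0)

# Near the endpoints, the rescaled divergence approaches Forward / Reverse KL.
eps = 1e-4
near_0 = generalized_jsd(p_t, p_s, eps) / eps
near_1 = generalized_jsd(p_t, p_s, 1.0 - eps) / eps

print(forward, reverse, jsd_mid)
print(at_0, at_1, near_0, near_1)
```

Note that Forward and Reverse KL differ even on this tiny example, which is why the choice of direction changes what the student learns.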

In swift:

  • GKD defaults to \(\beta=0.5\) (JSD); use --beta to choose among Forward / JSD / Reverse;

  • OPD-RL uses the teacher log-ratio \(\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t)\), i.e. the negated k1 estimator of Reverse KL, as the per-token advantage.

(b) Computation granularity

| Granularity | Teacher information needed | Notes |
|---|---|---|
| Full vocabulary | Teacher's complete next-token distribution | Exact divergence; high memory cost |
| Top-K | The K tokens with the highest teacher probability | Approximation after renormalizing over the top-K; suitable for external APIs (limited by max_logprobs) |
| Sampled token | Teacher's logp on the student's sampled token | Single-sample Monte Carlo estimate of Reverse KL; lowest communication cost |

Accuracy vs. cost: the full-vocabulary divergence requires materializing the complete logits, while the sampled-token variant needs only the teacher's logp on the sampled token (which a remote API can provide). The DeepSeek-V4 technical report notes that using only the sampled-token log-ratio as the advantage yields high gradient variance, so its OPD uses full-vocabulary (complete-logit) distillation.
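A toy NumPy comparison of the three granularities on random 1000-token distributions. The top-K renormalization shown is one simple variant and may differ from swift's. The gap between the top-K and full values depends on how much probability mass falls outside the top K, and the spread of the per-sample k1 values around their mean illustrates the variance concern above.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kl(p, q):
    """KL(p || q) for strictly positive discrete distributions."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

rng = np.random.default_rng(0)
V, K = 1000, 20                                  # toy vocabulary size and top-K
p_t = softmax(2.0 * rng.normal(size=V))          # toy teacher next-token distribution
p_s = softmax(2.0 * rng.normal(size=V))          # toy student next-token distribution

# (1) Full vocabulary: exact Forward KL; needs all V teacher probabilities.
full_forward = kl(p_t, p_s)

# (2) Top-K: keep the teacher's K most likely tokens and renormalize both
#     distributions over that subset (an approximation).
top = np.argsort(p_t)[-K:]
topk_forward = kl(p_t[top] / p_t[top].sum(), p_s[top] / p_s[top].sum())

# (3) Sampled token: for y ~ P_S, log P_S(y) - log P_T(y) is a one-sample
#     (k1) Monte Carlo estimate of Reverse KL(P_S || P_T).
exact_reverse = kl(p_s, p_t)
ys = rng.choice(V, size=100_000, p=p_s)
k1 = np.log(p_s[ys]) - np.log(p_t[ys])

print(f"full forward KL   {full_forward:.3f}")
print(f"top-{K} forward KL {topk_forward:.3f}")
print(f"reverse KL {exact_reverse:.3f}; k1 mean {k1.mean():.3f}, per-sample std {k1.std():.3f}")
```

Averaged over many samples, k1 recovers the exact Reverse KL, but any single sample (which is all one token position provides) can be far from it.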

2.2 How to Pass the Signal to the Student

| | Path A: GKD (direct loss) | Path B: OPD-RL (RL advantage) |
|---|---|---|
| Training paradigm | --rlhf_type gkd | --rlhf_type grpo + teacher |
| Signal delivery | Signal used directly as the loss | Signal used as the advantage in a policy gradient |
| Gradient flows through | Student full-vocabulary logits (or top-k) | Only the student's sampled token, $\nabla\log\pi(y_t)$ |
| Teacher information needed | Full distribution (or top-k logits) | Single logp on the sampled token |
| Divergence choice | Forward / Reverse / JSD (--beta) | Reverse KL (k1 log-ratio) |
| Combine with task reward | Mix in SFT loss via sft_alpha | Can stack with GRPO reward in the advantage |

Both paths share the same teacher infrastructure (see below); they differ only in how teacher KL is used.
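To make the "Gradient flows through" row concrete, here is a NumPy sketch for a single position with toy logits, checked against finite differences. It illustrates the two gradient forms and is not swift code; all function names are hypothetical.

```python
import numpy as np

def log_softmax(z):
    shifted = z - z.max()
    return shifted - np.log(np.exp(shifted).sum())

def softmax(z):
    return np.exp(log_softmax(z))

def grad_forward_kl(z, p_t):
    """Path A: gradient of KL(P_T || softmax(z)) w.r.t. the student logits z.

    Equals softmax(z) - P_T, so every vocabulary logit gets a direct teacher signal.
    """
    return softmax(z) - p_t

def grad_policy_loss(z, y, advantage):
    """Path B: gradient of -advantage * log softmax(z)[y], advantage held constant.

    The loss depends on the logits only through log pi(y) of the sampled token;
    the softmax normalizer spreads it as -advantage * (onehot(y) - softmax(z)).
    """
    onehot = np.zeros_like(z)
    onehot[y] = 1.0
    return -advantage * (onehot - softmax(z))

def numeric_grad(f, z, h=1e-6):
    """Central finite differences, used only to check the analytic gradients."""
    g = np.zeros_like(z)
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = h
        g[i] = (f(z + step) - f(z - step)) / (2 * h)
    return g

z = np.array([1.0, 0.5, -0.3, 0.2])                 # toy student logits
p_t = np.array([0.5, 0.3, 0.1, 0.1])                 # toy teacher distribution
y = 1                                                # token the student sampled
adv = float(np.log(p_t[y]) - log_softmax(z)[y])      # teacher log-ratio as advantage

def forward_kl(zz):
    return float(np.sum(p_t * (np.log(p_t) - log_softmax(zz))))

def policy_loss(zz):
    return float(-adv * log_softmax(zz)[y])

err_a = np.abs(grad_forward_kl(z, p_t) - numeric_grad(forward_kl, z)).max()
err_b = np.abs(grad_policy_loss(z, y, adv) - numeric_grad(policy_loss, z)).max()
print(err_a, err_b)
```

Path A needs the whole teacher vector p_t; Path B needs only the scalar p_t[y], which is why it works with a sampled-token teacher API.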

Common uses of distillation

  1. Capability fusion: Distill multiple expert models into one unified model.

  2. Strong-to-weak: Transfer capabilities from a large model to a smaller one.

  3. Forgetting prevention: Use an old checkpoint as teacher to recover prior capabilities after multi-stage training.


3. Distillation Methods in swift

swift provides three distillation training methods. They share the same teacher infrastructure and differ along the two choices described in Section 2:

| Method | Signal path | How to enable | One-liner |
|---|---|---|---|
| GKD | Direct loss (Path A) | --rlhf_type gkd | Teacher divergence backpropagated as the loss; supports full-vocabulary or top-k divergence |
| OPD-RL | RL advantage (Path B) | --rlhf_type grpo + teacher | Teacher log-ratio injected into the GRPO advantage; can be combined with task rewards |
| OPSD | Path A or B | Provide teacher_prompt on top of either setup | Single-model self-distillation: the teacher input includes privileged information (e.g. a reference solution) |

Three teacher sources (shared by GKD and OPD-RL):

  • --teacher_model: Load a separate frozen teacher model in the training process.

  • --teacher_model_server: Connect to an external teacher service (swift deploy --infer_backend vllm) without loading the teacher on training GPUs. When using the API with GKD, also set --gkd_logits_topk.

  • Self-distillation: Teacher and student share the same source. For LoRA training when --teacher_model equals --model, the base model is used as a fixed teacher via disable_adapter() without extra loading; for GKD without --teacher_model, the student’s current weights serve as a dynamic teacher.

Teacher-related parameters (shared by GKD and OPD-RL; full details in command-line parameters):

| Parameter | Default | Description |
|---|---|---|
| --teacher_model | None | Teacher model path; omit for dynamic self-distillation in GKD |
| --teacher_model_server | None | Teacher API URL (mutually exclusive with --teacher_model) |
| --teacher_deepspeed | None | DeepSpeed config for the teacher model (e.g. zero3) |
| --offload_teacher_model | False | Offload the teacher to CPU when it is not running a forward pass |

3.1 GKD: Divergence as Direct Loss

GKD (Generalized Knowledge Distillation) directly backpropagates the divergence between teacher and student as the loss function.

Loss function

\[ \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D_{\text{JSD}(\beta)}\big(P_{\text{teacher}}(\cdot|x,y_{<t}),\, P_{\text{student}}(\cdot|x,y_{<t})\big) \]

The divergence \(D_{\text{JSD}(\beta)}\) is chosen via --beta (see 2.1): \(\beta=0\) gives Forward KL, \(\beta=1\) gives Reverse KL, and \(0<\beta<1\) gives generalized JSD (default \(0.5\)).

On-Policy vs Off-Policy: lmbda

GKD uses lmbda as the probability that a batch is generated on-policy by the student:

import random

if random.random() < lmbda:   # true with probability lmbda
    y = student.generate(x)   # on-policy: the student samples the response
else:
    y = y_ground_truth        # off-policy: use the dataset label
# D is the divergence selected by --beta, evaluated per token on the same (x, y)
loss = D(P_teacher(·|x, y), P_student(·|x, y))

  • lmbda=0: pure offline (traditional distillation on fixed data).

  • lmbda=1: pure online (student learns from its own mistakes, i.e. on-policy distillation).

  • 0<lmbda<1: mixed.
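The selection rule can be simulated with a stand-in for student sampling to confirm what the endpoints mean. This is a sketch; pick_response and fake_generate are hypothetical names, not swift APIs.

```python
import random

def pick_response(prompt, label, lmbda, generate, rng):
    """Choose the response for one GKD batch: on-policy with probability lmbda.

    `generate` is a hypothetical stand-in for student sampling.
    Returns (response, is_on_policy).
    """
    if rng.random() < lmbda:               # random() is in [0, 1): lmbda=0 never, lmbda=1 always
        return generate(prompt), True      # on-policy: student-generated response
    return label, False                    # off-policy: dataset label

def fake_generate(prompt):
    return prompt + " [student sample]"

rng = random.Random(0)
for lmbda in (0.0, 0.5, 1.0):
    flags = [pick_response("q", "gold", lmbda, fake_generate, rng)[1] for _ in range(10_000)]
    print(lmbda, sum(flags) / len(flags))
```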

To train on teacher-generated data, first generate responses offline with the teacher and write them into the dataset, then train with lmbda=0.

GKD parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| --beta | float | 0.5 | Divergence interpolation: 0 = Forward KL, 0.5 = JSD, 1 = Reverse KL |
| --lmbda | float | 0.5 | Online sampling probability: 0 = offline, 1 = pure online |
| --sft_alpha | float | 0 | SFT loss mixing ratio; final loss = gkd_loss + sft_alpha * sft_loss (applied only to non-student-generated data) |
| --gkd_logits_topk | int | None | Compute KL using only the teacher's top-K logits; required when using --teacher_model_server |
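Putting the loss formula and the sft_alpha row together, here is a minimal NumPy sketch of the per-response objective. It assumes the per-token divergences are summed as in the formula above; the reduction (sum or mean) and the normalization of the SFT term in swift may differ. The function names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    m = logits.max(axis=-1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))

def token_divergence(logp_t, logp_s, beta):
    """Per-position divergence over the vocabulary (last axis), following the --beta convention."""
    p_t, p_s = np.exp(logp_t), np.exp(logp_s)
    if beta == 0.0:
        return np.sum(p_t * (logp_t - logp_s), axis=-1)   # Forward KL(P_T || P_S)
    if beta == 1.0:
        return np.sum(p_s * (logp_s - logp_t), axis=-1)   # Reverse KL(P_S || P_T)
    logm = np.logaddexp(np.log(beta) + logp_t, np.log1p(-beta) + logp_s)  # log M
    return (beta * np.sum(p_t * (logp_t - logm), axis=-1)
            + (1.0 - beta) * np.sum(p_s * (logp_s - logm), axis=-1))

def gkd_loss(teacher_logits, student_logits, labels, beta=0.5, sft_alpha=0.0, on_policy=False):
    """Sketch of gkd_loss + sft_alpha * sft_loss for one response.

    teacher_logits, student_logits: [T, V] logits at the T response positions.
    labels: [T] response token ids. The SFT term is skipped for student-generated responses.
    """
    logp_t = log_softmax(teacher_logits)
    logp_s = log_softmax(student_logits)
    loss = token_divergence(logp_t, logp_s, beta).sum()          # sum over t, as in the formula
    if sft_alpha > 0 and not on_policy:
        sft_loss = -logp_s[np.arange(len(labels)), labels].sum()  # cross-entropy on the labels
        loss += sft_alpha * sft_loss
    return float(loss)

rng = np.random.default_rng(0)
T, V = 5, 32
teacher_logits = rng.normal(size=(T, V))
student_logits = rng.normal(size=(T, V))
labels = rng.integers(0, V, size=T)

for beta in (0.0, 0.5, 1.0):
    print(beta, gkd_loss(teacher_logits, student_logits, labels, beta=beta))
print(gkd_loss(teacher_logits, student_logits, labels, sft_alpha=0.1))
```

Passing on_policy=True mirrors the table's note that the SFT term applies only to data the student did not generate.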

Top-K KL Computation

By default, the divergence is computed over the full vocabulary, which can run out of memory (OOM) with large vocabularies. Set --gkd_logits_topk to compute it over the teacher's top-K tokens only.

External teacher API

When setting --teacher_model_server, also set --gkd_logits_topk, because the API returns only the top-k logprobs. Example:

# Step 1: Deploy teacher model (max_logprobs must be >= gkd_logits_topk)
CUDA_VISIBLE_DEVICES=0 swift deploy \
    --model Qwen/Qwen3.5-9B \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64

# Step 2: Start GKD training
CUDA_VISIBLE_DEVICES=1,2,3,4 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen3.5-2B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --lmbda 1.0 \
    --beta 1.0 \
    --dataset xxx

Online sampling acceleration

When lmbda > 0, the student must generate sequences online. Use vLLM to accelerate sampling (colocate or server mode, same as GRPO); see the GRPO documentation.

Reference scripts


3.2 OPD-RL: KL as RL Advantage

OPD (On-Policy Distillation) RL injects teacher KL into GRPO’s per-token advantage and updates the student via policy gradient.

Principle

The standard GRPO advantage comes from group-normalized task rewards (one scalar per sequence). OPD-RL adds the teacher signal token by token after advantage normalization:

\[ A_t = A_t^{\text{base}} + \alpha \cdot \big(\log \pi_{\text{teacher}}(y_t|x,y_{<t}) - \log \pi_{\text{student}}(y_t|x,y_{<t})\big) \]
  • \(A_t^{\text{base}}\): GRPO-normalized task reward advantage (0 when no reward function).

  • \(\alpha\): --teacher_kl_coef, teacher signal strength.

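  • \(\log\pi_{\text{teacher}}(y_t) - \log\pi_{\text{student}}(y_t)\): the teacher log-ratio, i.e. the negated k1 estimate of Reverse KL \(\text{KL}(\pi_{\text{student}}\,\|\,\pi_{\text{teacher}})\). It equals the coefficient of \(\nabla_\theta\log\pi_\theta(y_t)\) in the negative Reverse KL gradient, so using it as the advantage pushes the student toward lower Reverse KL (see compute_teacher_logratio).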
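The link between the log-ratio and Reverse KL follows from the standard score-function identity. For one position with context \(c_t=(x,y_{<t})\) held fixed, treating each position independently (a per-token simplification that ignores effects on later tokens):

\[
\begin{aligned}
\nabla_\theta\,\text{KL}\big(\pi_\theta(\cdot|c_t)\,\|\,\pi_{\text{teacher}}(\cdot|c_t)\big)
&= \sum_{y}\nabla_\theta\pi_\theta(y|c_t)\,\big(\log\pi_\theta(y|c_t)-\log\pi_{\text{teacher}}(y|c_t)\big) + \sum_{y}\pi_\theta(y|c_t)\,\nabla_\theta\log\pi_\theta(y|c_t) \\
&= \mathbb{E}_{y\sim\pi_\theta(\cdot|c_t)}\Big[\nabla_\theta\log\pi_\theta(y|c_t)\,\big(\log\pi_\theta(y|c_t)-\log\pi_{\text{teacher}}(y|c_t)\big)\Big]
\end{aligned}
\]

The second sum vanishes because \(\sum_y \pi_\theta\nabla_\theta\log\pi_\theta = \nabla_\theta\sum_y\pi_\theta = \nabla_\theta 1 = 0\), and the first is rewritten with \(\nabla_\theta\pi_\theta = \pi_\theta\nabla_\theta\log\pi_\theta\). Descending this gradient is therefore a policy-gradient step whose advantage, evaluated at the sampled token \(y_t\), is \(\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t)\).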

Pure distillation mode: Without --reward_funcs, base advantage is 0 and teacher signal is the only driver: \(A_t = \alpha\cdot(\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t))\).

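Monitoring metric: teacher_kl in the logs is the k3 estimator \(e^{d}-d-1\) (\(d=\log\pi_{\text{teacher}}-\log\pi_{\text{student}}\)). It is always non-negative and, on student-sampled tokens, is an unbiased estimate of Reverse KL \(\text{KL}(\pi_{\text{student}}\,\|\,\pi_{\text{teacher}})\), so it tracks how far the student is from the teacher.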
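A NumPy sketch of the advantage formula and the monitoring metric on made-up log-probabilities for a four-token response. opd_advantages and teacher_kl_k3 are hypothetical helpers, not swift's compute_teacher_logratio, and averaging k3 over tokens for the logged value is an assumption.

```python
import numpy as np

def opd_advantages(base_adv, logp_teacher, logp_student, teacher_kl_coef=1.0):
    """Per-token A_t = A_base + alpha * (log pi_T(y_t) - log pi_S(y_t)).

    base_adv: GRPO-normalized sequence advantage (0.0 without reward functions).
    logp_teacher, logp_student: [T] log-probs of the student's sampled tokens.
    """
    log_ratio = np.asarray(logp_teacher) - np.asarray(logp_student)
    return base_adv + teacher_kl_coef * log_ratio

def teacher_kl_k3(logp_teacher, logp_student):
    """Mean over tokens of k3 = exp(d) - d - 1 with d = log pi_T - log pi_S (each term >= 0)."""
    d = np.asarray(logp_teacher) - np.asarray(logp_student)
    return float(np.mean(np.expm1(d) - d))

logp_s = np.log([0.9, 0.5, 0.2, 0.6])    # student log-probs of its own sampled tokens (toy)
logp_t = np.log([0.9, 0.7, 0.05, 0.6])   # teacher log-probs of the same tokens (toy)

print(opd_advantages(0.0, logp_t, logp_s))                       # pure distillation mode
print(opd_advantages(0.8, logp_t, logp_s, teacher_kl_coef=0.5))  # with a task-reward advantage
print(teacher_kl_k3(logp_t, logp_s))
```

Tokens the teacher rates as the student does get zero teacher signal, tokens the teacher prefers get a positive push, and tokens the teacher disfavors (the third one here) get a negative push.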

How to enable: Under --rlhf_type grpo, set --teacher_model or --teacher_model_server to automatically enable OPD-RL—no extra switch needed. See shared teacher parameters above.

OPD-RL-specific parameters

| Parameter | Default | Description |
|---|---|---|
| --teacher_kl_coef | 1.0 | Coefficient $\alpha$ for injecting the teacher log-ratio into the advantage |

Reference scripts


3.3 OPSD: On-Policy Self-Distillation

OPSD (On-Policy Self-Distillation) is a single-model self-distillation method: the same model constructs student and teacher inputs separately; the teacher side additionally receives privileged information (e.g. reference solution), then aligns output distributions on the student’s sampled response.

Core mechanism

  • Student: sees only the problem and reasons normally.

  • Teacher: sees problem + reference solution (privileged info via teacher_prompt column).

  • Training objective: align student and teacher output distributions on the same student-sampled response via divergence (JSD / KL).
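Because the teacher input is longer than the student input, the two forward passes place the same response tokens at different offsets, so the logits must be aligned before computing the divergence. A minimal sketch, assuming a causal LM whose logits at position i predict token i+1 and a response that directly follows each prompt. Real chat templates add special tokens, which this toy ignores; response_logits is a hypothetical helper and random arrays stand in for the model's outputs.

```python
import numpy as np

def response_logits(logits, prompt_len, response_len):
    """Rows of a causal LM's [L, V] logits that predict the response tokens.

    In [prompt tokens | response tokens], position i predicts token i + 1,
    so response token j (0-based) is predicted at position prompt_len + j - 1.
    """
    start = prompt_len - 1
    return logits[start:start + response_len]

V = 8
response = [5, 2, 7]                   # student-sampled response token ids (toy)
student_prompt = [1, 3]                # problem only
teacher_prompt = [1, 3, 4, 4, 6]       # problem + privileged reference solution

rng = np.random.default_rng(0)
# Stand-ins for the same model's forward pass on the two inputs.
student_logits = rng.normal(size=(len(student_prompt) + len(response), V))
teacher_logits = rng.normal(size=(len(teacher_prompt) + len(response), V))

s = response_logits(student_logits, len(student_prompt), len(response))
t = response_logits(teacher_logits, len(teacher_prompt), len(response))
print(s.shape, t.shape)                # both (3, 8): one row per response token
```

Row j of s and row j of t now describe the same next-token decision, so a per-row divergence (JSD / KL) can be computed directly.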

OPSD can follow either GKD or OPD-RL path:

  • GKD + OPSD: --rlhf_type gkd, teacher KL as direct loss.

  • OPD-RL + OPSD: --rlhf_type grpo + --teacher_model (same as --model), teacher KL as advantage.

Two self-distillation weight modes

| Mode | Configuration | Teacher weights | Description |
|---|---|---|---|
| Dynamic | Omit --teacher_model | Student's current weights | Teacher updates along with training |
| Fixed | Set --teacher_model to the same path as --model | Initial model weights | Teacher weights stay fixed |

Data format

OPSD datasets need a teacher_prompt column. Load a data preprocessing plugin via --external_plugins to build it. Example with math reasoning dataset open-r1/OpenThoughts-114k-math:

from swift.dataset import DatasetMeta, RowPreprocessor, register_dataset

class OpenThoughtsOPSDPreprocessor(RowPreprocessor):
    def preprocess(self, row):
        # Drop rows marked as incorrect (returning None filters the row out)
        if not row.get('correct', True):
            return None
        problem = row.get('problem', '')
        solution = row.get('solution', '')
        # Teacher input: the problem plus the privileged reference solution
        teacher_prompt = f'{problem}\n\nReference solution:\n{solution}\n\nNow articulate your own reasoning.'
        # Student input: the problem only
        messages = [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user', 'content': problem},
        ]
        return {'messages': messages, 'teacher_prompt': teacher_prompt}

# Register the preprocessor for this dataset ID so it is applied when the dataset is loaded
register_dataset(DatasetMeta(
    ms_dataset_id='open-r1/OpenThoughts-114k-math',
    preprocess_func=OpenThoughtsOPSDPreprocessor(),
    tags=['math', 'opsd'],
))

Reference scripts


Reference