# Knowledge Distillation

Knowledge distillation is a training method that transfers capabilities from a teacher model to a student model. The core idea is to have the student align with the teacher's output distribution at each token position, yielding richer supervision than simply imitating labeled answers: the teacher tells the student not only which token is correct, but also how good or bad the other tokens are.

This document introduces, top-down: why distillation works (Section 1), a unified design framework for distillation methods (Section 2), and finally three concrete distillation training methods in swift: GKD / OPD-RL / OPSD (Section 3).

---

## 1. Why Distillation: From Sparse to Dense Signals

A language model's capabilities are typically built through a stack of training stages:

- **Pre-training**: Acquire language, world knowledge, basic reasoning, and other general capabilities.
- **Mid-training**: Inject domain knowledge, such as code, medicine, or internal company documents.
- **Post-training**: Elicit target behaviors, such as instruction following, mathematical reasoning, or dialogue style.

Distillation mainly happens during **post-training**. To understand its value, consider post-training methods along two independent dimensions:

1. **Sampling mode (where data comes from)**: Whether training sequences are generated by the student itself (on-policy) or come from external fixed data (off-policy).
2. **Feedback density (how much each sequence teaches)**: Whether the entire sequence has a single reward (per-sequence, sparse) or each token has a signal (per-token, dense).

**SFT / offline distillation** (off-policy + dense): Align with labels or the teacher distribution on fixed data. The signal is dense, but training only sees states from the teacher/labels, which differ from states the student enters during its own inference.
Once the student makes an early mistake the teacher would not make, it enters unseen states and errors accumulate; this is called **exposure bias**.

**RL** (on-policy + sparse): The student samples trajectories and receives rewards based on final outcomes. The distribution matches student inference, but rewards are typically **sequence-level** scalars that do not pinpoint which token went wrong.

**On-policy distillation** (on-policy + dense): The student samples trajectories, and the teacher scores **every token** on those trajectories. The training distribution matches student inference, and feedback is per-token.

### SFT as a Special Case of Distillation

A natural entry point to understanding distillation is the SFT loss. SFT's cross-entropy loss is equivalent to the KL divergence against a one-hot "teacher" distribution $\delta_{y^*}$ at the labeled token (a one-hot distribution has zero entropy, so the KL divergence reduces to the cross-entropy):

$$-\log P_S(y^*_t) = \text{KL}(\delta_{y^*} \,\|\, P_S)$$

Knowledge distillation simply replaces this deterministic one-hot "teacher" with the teacher model's soft distribution $P_T$, optimizing $\text{KL}(P_T \,\|\, P_S)$ at each token to provide richer supervision than a one-hot label.

| Method | Sampling Mode | Feedback Density | Teacher Distribution |
|------|----------|----------|----------|
| SFT | off-policy (fixed data) | dense (per-token) | one-hot $\delta_{y^*}$ |
| Offline (off-policy) distillation | off-policy (fixed or teacher-generated data) | dense (per-token) | teacher soft distribution |
| RL | on-policy (student sampling) | sparse (per-sequence) | none |
| **On-policy distillation** | **on-policy (student sampling)** | **dense (per-token)** | **teacher soft distribution** |

---

## 2. Two Core Choices in Distillation

Differences between distillation methods almost always boil down to two questions. Once you understand these dimensions, the methods below are just different combinations.
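Before going through the two choices, the SFT identity from Section 1 can be checked numerically at a single token position. The sketch below is only an illustration: the 4-token vocabulary, the probabilities, and the helper `kl` are made up for this example and are not part of swift.

```python
import math


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists.

    Terms with p_i = 0 contribute 0 by the usual convention.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Toy 4-token vocabulary; every number here is invented for illustration.
student = [0.70, 0.20, 0.05, 0.05]  # P_S at one token position
label = 1                           # index of the labeled token y*
one_hot = [1.0 if i == label else 0.0 for i in range(len(student))]

sft_loss = -math.log(student[label])  # SFT cross-entropy at this position
kl_one_hot = kl(one_hot, student)     # KL(delta_{y*} || P_S)

teacher = [0.30, 0.60, 0.09, 0.01]    # hypothetical soft teacher P_T
kl_soft = kl(teacher, student)        # KL(P_T || P_S)

print(f"SFT loss           = {sft_loss:.4f}")
print(f"KL(one-hot || P_S) = {kl_one_hot:.4f}")
print(f"KL(P_T || P_S)     = {kl_soft:.4f}")
```

The first two printed values agree, which is the identity above. The third depends on every entry of the teacher distribution, so it also penalizes the student for over-weighting token 0, information that the one-hot label does not carry.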
### 2.1 How to Compute the Teacher Signal

At each token position, quantify the difference between teacher distribution $P_T$ and student distribution $P_S$ (we call this **Teacher KL**). There are two sub-choices.

**(a) Divergence direction**

| Divergence | Definition | Optimization behavior (information-theoretic meaning) |
|------|------|------|
| Forward KL | $\text{KL}(P_T \,\|\, P_S)$ | Mode-covering: the student must assign enough probability to regions where the teacher has high probability |
| Reverse KL | $\text{KL}(P_S \,\|\, P_T)$ | Mode-seeking: the student mainly fits the teacher's mode (high-probability) regions |
| Generalized JSD($\beta$) | $\beta\,\text{KL}(P_T\|M) + (1-\beta)\,\text{KL}(P_S\|M)$, where $M=\beta P_T+(1-\beta)P_S$ | Interpolates between the two |

> The endpoints are limiting cases: the formula itself vanishes at $\beta=0$ and $\beta=1$, but it behaves like Forward KL as $\beta\to 0$ and like Reverse KL as $\beta\to 1$, so $\beta=0$ is treated as Forward KL and $\beta=1$ as Reverse KL. SFT is equivalent to Forward KL (teacher is one-hot).

In swift:

- GKD defaults to $\beta=0.5$ (JSD); use `--beta` to choose among Forward / JSD / Reverse;
- OPD-RL uses the Reverse KL k1 estimator $\log\pi_{\text{teacher}}(y_t)-\log\pi_{\text{student}}(y_t)$ as per-token advantage.

**(b) Computation granularity**

| Granularity | Teacher information needed | Notes |
|----------|---------------|------|
| Full vocabulary | Teacher's complete next-token distribution | Exact divergence; high memory cost |
| Top-K | K tokens with highest teacher probability | Approximation after renormalization over top-K; suitable for external APIs (limited by `max_logprobs`) |
| Sampled token | Single logp of the teacher on the student's sampled token | Single-sample Monte Carlo estimate of Reverse KL; lowest communication cost |

> **Accuracy vs cost**: Full vocabulary requires materializing complete logits; sampled token only needs the teacher logp on the sampled token (can use a remote API).
> The [DeepSeek-V4](https://arxiv.org/abs/2606.19348) technical report notes that using only the sampled-token log-ratio as advantage yields high gradient variance, so its full-vocabulary OPD uses complete logit distillation.

### 2.2 How to Pass the Signal to the Student

| | **Path A: GKD (direct loss)** | **Path B: OPD-RL (RL advantage)** |
|---|---|---|
| Training paradigm | `--rlhf_type gkd` | `--rlhf_type grpo` + teacher |
| Signal delivery | Use signal as loss | Use signal as advantage via policy gradient |
| Gradient flows through | Student **full-vocabulary** logits (or top-k) | Only student **sampled token** $\nabla\log\pi(y_t)$ |
| Teacher information needed | Full distribution (or top-k logits) | Single logp on sampled token |
| Divergence choice | Forward / Reverse / JSD (`--beta`) | Reverse KL (k1 log-ratio) |
| Combine with task reward | Mix SFT loss via `sft_alpha` | Can stack with GRPO reward as advantage |

Both paths **share the same teacher infrastructure** (see below); they differ only in how teacher KL is used.

> **Common uses of distillation**
>
> 1. **Capability fusion**: Distill multiple expert models into one unified model.
> 2. **Strong-to-weak**: Transfer capabilities from a large model to a smaller one.
> 3. **Forgetting prevention**: Use an old checkpoint as teacher to recover prior capabilities after multi-stage training.

---

## 3. Distillation Methods in swift

swift provides three distillation training methods.
They share the same teacher infrastructure; differences fall within the framework in Section 2:

| Method | Signal path | How to enable | One-liner |
|------|--------------|----------|--------|
| **GKD** | Direct loss (Path A) | `--rlhf_type gkd` | Teacher-student divergence is backpropagated as the loss; supports full-vocab / top-k divergence |
| **OPD-RL** | RL advantage (Path B) | `--rlhf_type grpo` + teacher | Teacher log-ratio injected into GRPO advantage; can combine with task rewards |
| **OPSD** | Path A or B | Provide `teacher_prompt` on top of the above | Single-model self-distillation: teacher input includes privileged info (e.g. reference solution) |

**Three teacher sources** (shared by GKD and OPD-RL):

- `--teacher_model`: Load a separate frozen teacher model in the training process.
- `--teacher_model_server`: Connect to an external teacher service (`swift deploy --infer_backend vllm`) without loading the teacher on training GPUs. When using the API with GKD, also set `--gkd_logits_topk`.
- **Self-distillation**: Teacher and student share the same source. For LoRA training when `--teacher_model` equals `--model`, the base model is used as a fixed teacher via `disable_adapter()` without extra loading; for GKD without `--teacher_model`, the student's current weights serve as a dynamic teacher.

**Teacher-related parameters** (shared by GKD and OPD-RL; full details in [command-line parameters](./Command-line-parameters.md)):

| Parameter | Default | Description |
|------|--------|------|
| `--teacher_model` | None | Teacher model path; omit for dynamic self-distillation in GKD |
| `--teacher_model_server` | None | Teacher API URL (mutually exclusive with `teacher_model`) |
| `--teacher_deepspeed` | None | DeepSpeed config for teacher model (e.g. `zero3`) |
| `--offload_teacher_model` | False | Offload teacher to CPU when not in forward pass |

---

### 3.1 GKD: Divergence as Direct Loss

GKD ([Generalized Knowledge Distillation](https://arxiv.org/pdf/2306.13649)) directly backpropagates the divergence between teacher and student as the loss function.

**Loss function**

$$
\mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D_{\text{JSD}(\beta)}\big(P_{\text{teacher}}(\cdot|x,y_{<t}) \,\|\, P_{\text{student}}(\cdot|x,y_{<t})\big)
$$

To use teacher-generated data, first offline-generate responses with the teacher and write them to the dataset, then train with `lmbda=0`.

**GKD parameters**

| Parameter | Type | Default | Description |
|------|------|--------|------|
| `--beta` | float | 0.5 | Divergence interpolation: 0=Forward KL, 0.5=JSD, 1=Reverse KL |
| `--lmbda` | float | 0.5 | Online sampling probability: 0=offline, 1=pure online |
| `--sft_alpha` | float | 0 | SFT loss mixing ratio; final `loss = gkd_loss + sft_alpha * sft_loss` (only for **non-student-generated** data) |
| `--gkd_logits_topk` | int | None | Compute KL using teacher top-K logits only; required when using `teacher_model_server` |

### Top-K KL Computation

By default KL is computed over the full vocabulary, which can run out of memory (OOM) on large vocabularies. Set `--gkd_logits_topk` to compute it over the teacher's top-K logits only.

**External teacher API**

When setting `--teacher_model_server`, also set `--gkd_logits_topk` (the API returns only top-k logprobs). Example:

```bash
# Step 1: Deploy teacher model (max_logprobs must be >= gkd_logits_topk)
CUDA_VISIBLE_DEVICES=0 swift deploy \
    --model Qwen/Qwen3.5-9B \
    --infer_backend vllm \
    --port 8000 \
    --max_logprobs 64

# Step 2: Start GKD training
CUDA_VISIBLE_DEVICES=1,2,3,4 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen3.5-2B \
    --teacher_model_server http://localhost:8000 \
    --gkd_logits_topk 64 \
    --lmbda 1.0 \
    --beta 1.0 \
    --dataset xxx
```

**Online sampling acceleration**

When `lmbda > 0`, the student must generate sequences online.
Use vLLM to accelerate sampling (colocate / server modes, same as GRPO). See [GRPO documentation](./GRPO/GetStarted/GRPO.md#cluster-support).

**Reference scripts**

- Basic training: [examples/train/rlhf/gkd/](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/gkd/)
- Multimodal: [examples/train/multimodal/rlhf/gkd/](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/)
- Megatron: [examples/megatron/rlhf/gkd/](https://github.com/modelscope/ms-swift/tree/main/examples/megatron/rlhf/gkd/)

---

### 3.2 OPD-RL: KL as RL Advantage

OPD (On-Policy Distillation) RL injects teacher KL into GRPO's **per-token advantage** and updates the student via policy gradient.

**Principle**

Standard GRPO advantage comes from group-normalized task rewards (per-sequence scalar). OPD-RL injects teacher signal token-by-token **after** advantage normalization:

$$ A_t = A_t^{\text{base}} + \alpha \cdot \big(\log \pi_{\text{teacher}}(y_t|x,y_{