GKD
If you are new to GKD, please refer to the GKD Documentation first.
GKD (Generalized Knowledge Distillation) is a training method that transfers knowledge from a teacher model to a student model by computing the Jensen-Shannon Divergence (JSD) loss between their output distributions.
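The loss described above can be illustrated for a single token position. The following is a minimal sketch in plain Python, not the Megatron-SWIFT implementation: the function names (`softmax`, `kl`, `generalized_jsd`) are hypothetical, and the mixture weighting is an assumption chosen so that `beta=0.0` gives forward KL and `beta=1.0` gives reverse KL, matching the `--beta` parameter description.

```python
import math

def softmax(logits, temperature=1.0):
    # Temperature-scaled softmax, shifted by the max for numerical stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    # KL(p || q) for two discrete distributions over the same vocabulary.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=0.9):
    # Assumed convention: beta=0.0 -> KL(teacher || student) (forward KL),
    # beta=1.0 -> KL(student || teacher) (reverse KL),
    # otherwise interpolate through the mixture distribution.
    p_s = softmax(student_logits, temperature)
    p_t = softmax(teacher_logits, temperature)
    if beta == 0.0:
        return kl(p_t, p_s)
    if beta == 1.0:
        return kl(p_s, p_t)
    mix = [beta * t + (1 - beta) * s for s, t in zip(p_s, p_t)]
    return beta * kl(p_t, mix) + (1 - beta) * kl(p_s, mix)
```

With `beta=0.5` the result is symmetric in the two models, and it is zero when the student and teacher logits are identical. In actual training the loss is computed over every completion token and averaged; this sketch omits that reduction.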
Feature Support
Megatron GKD currently supports the following features:
- Training Modes: Full-parameter training and LoRA fine-tuning
- Parallelism Strategies: Context Parallel (CP), Pipeline Parallel (PP), Tensor Parallel (TP), and Expert Parallel (EP)
- Model Support: Compatible with the LLMs and MLLMs supported in Megatron-SWIFT
- Teacher Offload: Supports offloading the teacher model to CPU to save GPU memory
- Online Generation: Supports on-policy generation with vLLM for the student model
Parameters
GKD-specific Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `--teacher_model` | str | - | Path or model ID of the teacher model. Can be omitted when using `--teacher_model_server` |
| `--teacher_model_server` | str | None | Teacher model service URL (vllm serve only), e.g. `http://localhost:8000` |
| `--gkd_logits_topk` | int | None | Number of Top-K logits; required when using an external API |
| `--beta` | float | 0.5 | JSD divergence interpolation coefficient. 0.0: Forward KL; 0.5: Symmetric JSD; 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | On-Policy learning probability. 0.0: Pure Off-Policy; 1.0: Pure On-Policy |
| `--temperature` | float | 0.9 | Temperature for sampling and loss computation |
| `--sft_alpha` | float | 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| `--max_completion_length` | int | 512 | Maximum number of tokens for generation |
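To make the roles of `--lmbda` and `--sft_alpha` concrete, the sketch below shows one way the per-batch decision and the loss mixing could work. It is an illustration under stated assumptions, not the Megatron-SWIFT code: the helper names `choose_data_source` and `combine_loss` are hypothetical, and the additive form `jsd_loss + sft_alpha * sft_loss` is an assumption about how the SFT proportion is mixed in.

```python
import random

def choose_data_source(lmbda, rng):
    # With probability lmbda, train on completions generated by the student
    # (on-policy); otherwise use the completions from the dataset (off-policy).
    return "on_policy" if rng.random() < lmbda else "off_policy"

def combine_loss(jsd_loss, sft_loss, sft_alpha, source):
    # Assumption: the SFT term is added only for completions that were not
    # generated by the student, scaled by sft_alpha.
    if source == "off_policy" and sft_alpha > 0:
        return jsd_loss + sft_alpha * sft_loss
    return jsd_loss

rng = random.Random(0)
sources = [choose_data_source(0.5, rng) for _ in range(1000)]
on_policy_share = sources.count("on_policy") / len(sources)
```

With `lmbda=0.0` every batch is off-policy and with `lmbda=1.0` every batch is on-policy; at the default of 0.5, `on_policy_share` comes out close to one half.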
Reference
For more parameters, please refer to Command-line Parameters
For training scripts, please refer to Megatron GKD Scripts
For a training script that uses a teacher model server, please refer to the corresponding example script