GKD

If you are new to GKD, please refer to the GKD Documentation first.

GKD (Generalized Knowledge Distillation) is a training method that transfers knowledge from a teacher model to a student model by minimizing a generalized Jensen-Shannon Divergence (JSD) loss between their output distributions.

Feature Support

Megatron GKD currently supports the following features:

  • Training Modes: Full parameter training and LoRA fine-tuning

  • Parallelism Strategies: Context Parallel (CP), Pipeline Parallel (PP), Tensor Parallel (TP), and Expert Parallel (EP)

  • Model Support: Compatible with LLMs and MLLMs in Megatron-SWIFT

  • Teacher Offload: Supports offloading the teacher model to CPU to save GPU memory

  • Online Generation: Supports on-policy generation with vLLM for the student model
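
As a rough illustration of how on-policy and off-policy data can be mixed, the sketch below chooses the source of each completion with probability `lmbda` (the `--lmbda` parameter of this document). The names `pick_completion` and `fake_student` are hypothetical and not part of Megatron-SWIFT, and whether the real trainer decides per batch or per sample is an assumption here.

```python
import random


def pick_completion(prompt, dataset_completion, student_generate, lmbda=0.5, rng=None):
    """Return (completion, source) for one training step.

    With probability `lmbda` the student generates the completion (on-policy);
    otherwise the completion stored in the dataset is used (off-policy).
    """
    rng = rng or random.Random()
    # random() is in [0, 1), so lmbda=0.0 never samples and lmbda=1.0 always does.
    if rng.random() < lmbda:
        return student_generate(prompt), "on_policy"
    return dataset_completion, "off_policy"


def fake_student(prompt):
    # Stand-in for student generation (e.g. through vLLM); purely illustrative.
    return prompt + " -> student answer"


rng = random.Random(0)
sources = [
    pick_completion("q", "reference answer", fake_student, lmbda=0.5, rng=rng)[1]
    for _ in range(1000)
]
on_policy_fraction = sources.count("on_policy") / len(sources)
```

With `lmbda=0.5`, roughly half of the steps train on student-generated completions; the two extremes reproduce the pure off-policy and pure on-policy settings.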

Parameters

GKD-specific Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| --teacher_model | str | - | Path or model ID of the teacher model. Can be omitted when using teacher_model_server |
| --teacher_model_server | str | None | Teacher model service URL (vllm serve only), e.g. http://localhost:8000 |
| --gkd_logits_topk | int | None | Number of Top-K logits; required when using external API |
| --beta | float | 0.5 | JSD divergence interpolation coefficient: 0.0 = Forward KL; 0.5 = Symmetric JSD; 1.0 = Reverse KL |
| --lmbda | float | 0.5 | On-Policy learning probability: 0.0 = Pure Off-Policy; 1.0 = Pure On-Policy |
| --temperature | float | 0.9 | Temperature for sampling and loss computation |
| --sft_alpha | float | 0 | Mix in a proportion of SFT loss; applied to non-student-generated completions |
| --max_completion_length | int | 512 | Maximum tokens for generation |
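
To make the roles of `--beta` and `--temperature` concrete, here is a minimal pure-Python sketch of a generalized JSD for a single token position. It assumes a common GKD convention: the mixture is `beta * teacher + (1 - beta) * student`, and the endpoints 0.0 and 1.0 are special-cased to forward KL and reverse KL as listed above. The helper names (`softmax`, `kl`, `generalized_jsd`) are hypothetical; the actual Megatron-SWIFT implementation works on batched tensors and may differ in masking and reduction.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=0.9):
    """Generalized JSD between teacher and student (assumed convention)."""
    s = softmax(student_logits, temperature)
    t = softmax(teacher_logits, temperature)
    if beta == 0.0:
        return kl(t, s)  # forward KL: KL(teacher || student)
    if beta == 1.0:
        return kl(s, t)  # reverse KL: KL(student || teacher)
    # Interpolated mixture of the two distributions.
    m = [beta * ti + (1 - beta) * si for si, ti in zip(s, t)]
    return beta * kl(t, m) + (1 - beta) * kl(s, m)


student = [2.0, 0.5, -1.0]
teacher = [0.1, 1.2, 0.3]
loss_forward = generalized_jsd(student, teacher, beta=0.0)
loss_symmetric = generalized_jsd(student, teacher, beta=0.5)
loss_reverse = generalized_jsd(student, teacher, beta=1.0)
```

At `beta=0.5` the value is unchanged when teacher and student are swapped, which is why that setting is described as symmetric.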

Reference

For more parameters, please refer to Command-line Parameters

For training scripts, please refer to Megatron GKD Scripts

For a training script that uses a teacher server, please refer to the example here