Add Quantization Aware Distillation (QAD) to Megatron-Bridge example#1600
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughExtends Megatron-Bridge distillation to support Quantization Aware Distillation (QAD) by allowing distill.py to initialize the student model from a quantized Megatron checkpoint. Adds CLI argument, monkeypatch-based restoration workflow, documentation, quantize/export adjustments, and an end-to-end test. ChangesQAD Feature Implementation
Sequence DiagramsequenceDiagram
participant Test as test_qad
participant Quantize as quantize.py
participant MegatronCKPT as Megatron checkpoint (dir)
participant Distill as distill.py
participant Loader as load_modelopt_megatron_checkpoint
participant Export as export.py
participant HF as HuggingFace unified checkpoint
Test->>Quantize: run quantize.py -> produce MegatronCKPT
Test->>Distill: run distill.py --student_megatron_path MegatronCKPT --teacher_hf_path ...
Distill->>Loader: restore ModelOpt state & weights into student
Distill->>MegatronCKPT: write distilled Megatron checkpoints (modelopt_state preserved)
Test->>Export: run export.py on distilled MegatronCKPT
Export->>HF: export_mcore_gpt_to_hf -> produce HF checkpoint artifacts
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
9ba6385 to
7787f98
Compare
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1600 +/- ##
==========================================
+ Coverage 76.31% 76.89% +0.58%
==========================================
Files 488 489 +1
Lines 54386 54415 +29
==========================================
+ Hits 41503 41844 +341
+ Misses 12883 12571 -312
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
7787f98 to
4e22a59
Compare
|
/claude review |
There was a problem hiding this comment.
Claude review passed — no blocking issues found. LGTM
Reviewed Part 2 of the QAD-on-Megatron-Bridge series (extends distill.py with --student_megatron_path to load a quantized Megatron checkpoint as the student before the KD conversion, plus an end-to-end QAD test).
Findings: 0 CRITICAL · 0 IMPORTANT · 3 SUGGESTION (all inline)
The class-level provide() monkey-patch is well-documented as a 26.04 workaround with a clear removal path once student_pre_conversion_hook lands in 26.06; the test exercises the full quantize → QAD-distill → unified-HF-export flow and confirms modelopt_state survives. Suggestions are about local clarity, not behavior:
_restore_megatron_student: thestrict=Falserationale referencing "in-memory teacher weights" doesn't match the only call site (teacher isn't built yet at that point).student_is_quantized/quantized: boolactually mean "checkpoint has any ModelOpt mode state"; safe today (prune_minitronstrips its state, onlyquantize.pyemits any), but the QAD-specific log message andgradient_accumulation_fusion = Falsewould fire incorrectly if any other mode starts shipping state.id(self)-keyed registry silently falls back to vanilla distillation if the framework ever wraps/copies the provider beforeprovide()is called — consider asserting the lookup hit when--student_megatron_pathwas set so the failure is loud rather than producing an uninitialized-student run.
70b1610 to
46056fa
Compare
46056fa to
2ed6c0c
Compare
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/megatron_bridge/distill.py`:
- Line 104: The parameter name has_modelopt_state in the function
_restore_megatron_student shadows the imported function has_modelopt_state;
rename the parameter (e.g., to modelopt_present or has_modelopt_flag) in
_restore_megatron_student, update all references inside that function to the new
parameter name, and update all call sites of _restore_megatron_student to pass
the renamed parameter variable so the imported has_modelopt_state function
remains callable.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 7cf45685-2015-4140-bfc0-7c93ab8c17e4
📒 Files selected for processing (4)
CHANGELOG.rstexamples/megatron_bridge/README.mdexamples/megatron_bridge/distill.pytests/examples/megatron_bridge/test_qad.py
2ed6c0c to
f0d1988
Compare
|
|
||
| # _load_model_weights_from_checkpoint is the (private) helper bridge.load_megatron_model uses to load | ||
| # a (quantized) Megatron checkpoint into an already-built model; reused here to initialize the student. | ||
| from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint |
There was a problem hiding this comment.
why can't you just use bridge.load_megatron_model?
There was a problem hiding this comment.
Because of some checks in it which doesnt cause load_modelopt_state to be invoked at the righ time in case of quantized ckpt, we need a temporary workaround. @AAnoosheh is looking into the correct fix in megatron bridge side. Alternatively as we discussed in Nemo-ModelOpt meeting, ideally we will have mbridge natively support quantized ckpt then we wont need this workaround
| if restore_modelopt_state: | ||
| load_modelopt_state([student_model], str(ckpt_root)) | ||
| print_rank_0(f"Loading student weights from Megatron checkpoint {ckpt_dir}") | ||
| # strict=False because the bridge loader strips Transformer-Engine extra-state from the loaded |
There was a problem hiding this comment.
if the mbridge loader strips extra-state from the sharded checkpoint, how are the amax values restored?
There was a problem hiding this comment.
Its restored from load_modelopt_state call above
9d94730 to
87c33b5
Compare
3147d66 to
82982cb
Compare
There was a problem hiding this comment.
Should we name this file train.py and support both QAT and QAD similar to HF llm_qat folder?
There was a problem hiding this comment.
My understanding is we dont have a good story for using QAT and instead recommend QAD. But if needed, I can change it in a follow-up PR to also support QAD
|
Can we add a link to this from https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_qat as performant QAT/QAD backend? |
72013f2 to
67bc3a5
Compare
| import vllm | ||
|
|
||
| DEFAULT_PROMPTS = [ | ||
| "Hello!", |
There was a problem hiding this comment.
this can be addressed in later PR but: these prompts are rather short, would be good to have some more long generation/agentic task
| ) | ||
| parser.add_argument("--calib_batch_size", type=int, default=1, help="Calibration batch size") | ||
| parser.add_argument("--seq_length", type=int, default=4096, help="Calibration sequence length") | ||
| parser.add_argument("--seq_length", type=int, default=512, help="Calibration sequence length") |
There was a problem hiding this comment.
yea we need longer seq length than 512, like 2048 or 4096. and also suggest a longer seq length for higher quality calibration
| --quant_cfg fp8 \ | ||
| --tp_size 2 \ | ||
| --calib_batch_size 16 \ | ||
| --seq_length 512 \ |
There was a problem hiding this comment.
are examples in nemotron-post-training really under 2k seq length? nemotron is always post trained with long seq length like 256k
Extend examples/megatron_bridge/distill.py with --student_megatron_path to initialize the student from a Megatron checkpoint (a quantized checkpoint from quantize.py, or a pruned one) instead of HuggingFace weights; --student_hf_path still builds the architecture. For a quantized checkpoint, the ModelOpt quantize mode + base weights are restored onto the plain student before the knowledge-distillation conversion (restore_sharded_modelopt_state is a no-op once a model is already converted), so the distilled checkpoint stays exportable as a quantized model with export.py. Until nemo:26.06 (which adds DistillationProvider.student_pre_conversion_hook upstream), this is done by patching DistillationProvider.provide at the class level via an id()-keyed registry, since the provider proxies instance attribute assignment to its teacher once the teacher is set. A removal note documents the upstream-hook replacement. Add tests/examples/megatron_bridge/test_qad.py covering quantize -> QAD -> export, asserting hf_quant_config.json is written so the distilled checkpoint stays exportable as a quantized model. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
67bc3a5 to
2b5ad6a
Compare
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
What does this PR do?
Type of change: new example
Note: This is part 2 of 4 (builds on #1589):
quantize.py+export.pysupport and tests.distill.pyfor quantization-aware distillation (QAD) — load a quantized Megatron checkpoint as the student.Extends
examples/megatron_bridge/distill.pyto initialize the student from a Megatron checkpoint (a quantized checkpoint fromquantize.py, or a pruned one) via--student_megatron_path, enabling Quantization Aware Distillation (QAD):--student_hf_pathstill builds the student architecture;--student_megatron_pathsupplies the (optionally quantized) weights.restore_sharded_modelopt_stateis a no-op once a model is already converted), so the distilled checkpoint stays exportable as a quantized model withexport.py.Upstream dependency / workaround:
DistillationProvider.provide()has no seam to transform the student before the KD conversion, so this patchesprovide()at the class level (via anid()-keyed registry, because the provider proxies instance-attribute assignment to its teacher once the teacher is set). A companion Megatron-Bridge PR adds a first-classDistillationProvider.student_pre_conversion_hook; from nemo:26.06 onwards the workaround should be removed and replaced with that hook (a removal note indistill.pydocuments exactly how).Usage
Testing
tests/examples/megatron_bridge/test_qad.py(validated on a 2-GPU NeMo26.04container): quantize a tiny Qwen3 at TP=2 → QAD distill from the quantized student →export.pyto a unified HF checkpoint, assertinghf_quant_config.jsonis written (proves the quantize mode survived QAD). Includes a commented-out vLLM deployment check, validated locally (full flow passes; vLLM loads the export asquantization=modelopt). Existing normal/Puzzletron distillation tests still pass.Before your PR is "Ready for review"
--student_megatron_pathis not set)CONTRIBUTING.md: N/A (no new dependencies)Additional Information
Depends on a companion Megatron-Bridge PR adding
DistillationProvider.student_pre_conversion_hook(the upstream replacement for the class-levelprovide()workaround). The Nemotron-3 tutorial NVFP4 + QAD experiments ship in part 3.Summary by CodeRabbit
New Features
Documentation
Tests
Chores / UX