From 9648a91b7cfd6feb898dc9ba473d115dc9b8346c Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 12 May 2026 12:34:48 +0200 Subject: [PATCH 1/8] Split bypass Puzzletron integration Signed-off-by: Sepehr Sameni --- .../Nemotron-3-Nano-30B-A3B-Base-BF16.md | 95 ++ examples/puzzletron/README.md | 2 +- .../bypass/defaults.yaml | 130 ++ ...DIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml | 110 ++ .../bypass/defaults.yaml | 120 ++ .../nemotron-3-nano-30b-a3b-with-bypass.yaml | 4 + .../nemotron-3-nano-30b-a3b.yaml | 30 + .../pruning/kv_heads_pruning.yaml | 24 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + examples/puzzletron/main.py | 26 +- modelopt/torch/puzzletron/mip/run_puzzle.py | 5 + .../torch/puzzletron/puzzletron_nas_plugin.py | 248 +++- .../build_replacement_library.py | 34 +- tests/gpu/torch/puzzletron/test_bypass.py | 1112 +++++++++++++++++ .../test_bypass_checkpoint_utils.py | 201 +++ .../torch/puzzletron/test_bypass_resume.py | 251 ++++ tests/gpu/torch/puzzletron/test_puzzletron.py | 14 +- .../test_bypass_replacement_library.py | 246 ++++ .../puzzletron/test_puzzletron_progress.py | 113 ++ 21 files changed, 2768 insertions(+), 57 deletions(-) create mode 100644 examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md create mode 100644 examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml create mode 100644 tests/gpu/torch/puzzletron/test_bypass.py create mode 100644 tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py create mode 100644 tests/gpu/torch/puzzletron/test_bypass_resume.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_replacement_library.py create mode 100644 tests/unit/torch/puzzletron/test_puzzletron_progress.py diff --git a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md new file mode 100644 index 00000000000..1f3ad6983df --- /dev/null +++ b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md @@ -0,0 +1,95 @@ +# Bypass Distillation Tutorial: Nemotron-3-Nano-30B-A3B (KV-heads-only) + +A minimal end-to-end demonstration that **bypass distillation improves quality** at the same compression budget. The setup is a **toy pruning task on a real production model** — we compress only KV heads (12 → 9, a modest 25% reduction) so a single comparison surfaces the bypass benefit cleanly without needing extensive downstream evaluation. The model itself ([Nemotron-3-Nano-30B-A3B-Base-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16)) is a real 30B-A3B MoE-Mamba hybrid, not a tiny stand-in. + +## What this tutorial does + +The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare: + +1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training). +2. **With bypass** — the bypass step runs ~50M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher. + +Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head}`. FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change. + +**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`). + +## Hardware & install + +- 8×H100 80GB (the teacher needs ≥60 GiB for activation scoring on a 4096 context). +- Container: `nvcr.io/nvidia/nemo:26.04` or later. +- `pip install -e ".[dev]"` from the modelopt repo root. +- Mamba kernels (required by Nemotron-3-Nano's hybrid backbone): + + ```bash + pip install mamba-ssm[causal-conv1d] --no-build-isolation + ``` + +- HF auth set up so the model is downloadable: `huggingface-cli login`. + +## Step A — pipeline without bypass + +Edit `examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml` to point `puzzle_dir` and `dataset_path` at writable locations, then: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml +``` + +This runs the 8-step puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). With `bypass:` added in Step B the pipeline grows to 9 steps; without it, the bypass step is skipped and progress prints `N/8`. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring). + +When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain: + +```text +validate_model_with_kl_div(model_name='teacher', ...) +Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...} +... +validate_model_with_kl_div(model_name='solution_0', ...) +Average losses = {..., 'token_accuracy_top_1': ..., ...} +``` + +Record the teacher's `token_accuracy_top_1` and `solution_0`'s `token_accuracy_top_1`. **Move or rename `${puzzle_dir}/single_sequence_replacement_solutions--validation/` and `${puzzle_dir}/mip/` aside** before Step B if you want to keep the no-bypass artifacts — Step B reuses the same `puzzle_dir` and the library/scoring/MIP outputs will be overwritten. + +## Step B — pipeline with bypass + +Use the bypass-enabled config, which overrides the base config's empty `- bypass:` entry with `bypass: defaults`: + +```yaml +defaults: + - nemotron-3-nano-30b-a3b + - override bypass: defaults + - _self_ +``` + +Run the bypass config: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml +``` + +Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~50M tokens) and the downstream library/scoring/MIP rerun. + +Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs//` and creates a symlink `${puzzle_dir}/ckpts/` that the replacement library builder picks up automatically. + +Capture `solution_0`'s `token_accuracy_top_1` from the new realize-model log section. + +## Results + +Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on Nemotron-3-Nano-30B-A3B-Base-BF16: + +| Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` | +|------------------------------|----------------------:|----------:|-----------------------:| +| Teacher | 12 | 0.5950 | 0.8468 | +| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 | +| Pruned, **with bypass** (50M-token BLD) | 9 | **0.6055**| **0.8441** | + +**Bypass closes ~74% of the regression gap** at this compression budget: + +- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**. +- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**. + +For 50M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher. + +## Going further: full accuracy recovery + +Bypass distillation is Stage 1 of the PUZZLE pipeline — local, per-block KD that tightens the replacement library. For larger compression targets (or more aggressive KV pruning) you'll want Stage 2: **global knowledge distillation** on the realized student. See [`examples/pruning/puzzletron/`](../pruning/puzzletron/) for the Megatron-Bridge recipe and concrete MMLU recovery numbers. diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 571b40ca499..93f8ced1cd5 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). -> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). +> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100, Nemotron-3-Nano-30B-A3B-Base-BF16 on 8x H100 — see the [bypass distillation tutorial](Nemotron-3-Nano-30B-A3B-Base-BF16.md)). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). ## Environment diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml new file mode 100644 index 00000000000..2a13a4c9742 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1,130 @@ +# @package bypass +# Bypass Distillation Configuration +# This config defines parameters for blockwise local distillation (BLD), +# which trains alternative transformer block configurations using per-block +# knowledge distillation from a teacher model. + +# Runtime Configuration +dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability +seed: 42 # Random seed for reproducibility + +# Experiment Tracking +experiment_id: # Unique identifier for this experiment. Will be dynamically set +experiment_dir: # Directory for this experiment. Will be dynamically set +iter_num: 1 # Current iteration number +step_num: 1 # Current step number within iteration +token_count: 0 # Token count tracker (auto-updated during training) + +# Data Configuration +data: + data_column: "messages" + block_size: 512 # Sequence length (tokens per sample) + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true # Load preprocessed data from disk or from stream + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: # Samples per GPU during distributed eval (auto if null) + shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data + +# Training Configuration +training: + learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001) + training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check) + micro_batch_size: 2 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}} # Auto-calculated warmup steps + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 # Use for debugging or to skip few batches which cause crashes or optimization issues. + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 5 + +# Model Loading Configuration +resume_checkpoint_path: # Path to resume training from checkpoint +find_last_ckpt_for_resume: true # Auto-resume by finding last checkpoint (bool) +parameter_count: +init_checkpoint_path: # Path to initialize weights from + +model: + student_weights_dtype: "bf16" # Student model weight precision + + model_overrides: + delete_old_checkpoints: true # Clean up old checkpoints to save disk space + save_interval_seconds: 12900 # Save checkpoint every ~3.5 hours + save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled) + save_checkpoint_when_done: true # Save final checkpoint when training completes + +# Architecture modifications for student model + model_config_overrides: + ffn: + - intermediate_size: + no_op: # Disable FFN entirely (true/false) + attention: + - num_key_value_heads: # Number of kv-heads (for GQA) + no_op: # Disable attention entirely (true/false) + +# Model Factory Configuration - Controls student model creation and initialization +model_factory: + factory: bypass_factory_fn # Unified factory supporting all layer types + block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss + gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode + mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode + mlp_init_config: # Configuration for MLP initialization (if needed) + activations_log_dir: # Directory with activation statistics (required for PruneByActivationsLog) + linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc. + submodule_for_loss_calculation: # Specific submodule for loss calc. + keys_to_learn: # Subblock(s) to train: entire_block, subblock_attention, subblock_ffn, subblock_mamba, or a list of those. + +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false # Enable validation to exercise all code paths +best_val_loss: 1e+9 # Track best validation loss achieved + +# Performance Optimization +compile: false # Use PyTorch compilation +disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available) +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false # Save initial checkpoint before training +disable_checkpoint_save: false # Disable all checkpoint saving +save_best_ckpt: true # Save checkpoint when validation improves +kill_after_first_save: false # Exit after first checkpoint save (for testing) +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Multiple bypass configurations to train sequentially. +# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn. +# If empty or absent, a single run uses the settings above. +configs: + - model_config_overrides: + ffn: + - intermediate_size: 3072 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn + - model_config_overrides: + ffn: + - intermediate_size: 5888 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml new file mode 100644 index 00000000000..dc1905ded75 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: kv_heads_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: nemotron_h +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library +# exposes {teacher 2-head, 1-head} per attention layer; FFN and Mamba +# blocks are copied verbatim from the teacher. +build_replacement_library: + add_ffn_no_ops: false + add_attention_no_ops: false + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_num_kv_heads: 9 # toy KV-heads-only target; see nemotron-3-nano-30b-a3b.yaml + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml new file mode 100644 index 00000000000..a1c63ac913a --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml @@ -0,0 +1,120 @@ +# @package bypass +# Bypass Distillation Configuration — Nemotron-3-Nano-30B-A3B (KV-heads-only toy task). +# +# Trains a single 1-KV-head variant per attention layer using per-block knowledge +# distillation against the teacher (`subblock_attention` keys only — FFN/MoE/Mamba +# blocks are frozen). The trained weights are saved into the replacement library +# and consumed by the MIP solver alongside the no_op variant. +# +# Tutorial budget: ~10M tokens (quick sanity, ~30 min on 4×H100). Increase +# `training_tokens` for a stronger bypass effect. + +# Runtime Configuration +dtype: "bf16" +seed: 42 + +# Experiment Tracking (auto-set from architecture, keys_to_learn, and a config fingerprint) +experiment_id: +experiment_dir: +iter_num: 1 +step_num: 1 +token_count: 0 + +# Data Configuration +data: + data_column: "messages" + block_size: 4096 + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: + shuffle_train_data_seed: ${random_int:0,9999} + +# Training Configuration +training: + learning_rate: 3e-4 + training_tokens: 5e+7 # 50M tokens (toy budget) + micro_batch_size: 2 + val_micro_batch_size: 2 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}} + min_lr_factor: 1e-5 + grad_accumulation_steps: 8 + skip_first_batches: 0 + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 100 + eval_interval: 100 + +# Model Loading Configuration +resume_checkpoint_path: +find_last_ckpt_for_resume: true +parameter_count: +init_checkpoint_path: + +model: + student_weights_dtype: "bf16" + + model_overrides: + delete_old_checkpoints: true + save_interval_seconds: 12900 + save_interval: 100 + save_checkpoint_when_done: true + + # Architecture override: only attention is touched. FFN/MoE/Mamba sub-blocks + # use teacher weights verbatim (the `ffn` key is omitted on purpose). + model_config_overrides: + attention: + - num_key_value_heads: 1 + no_op: + +# Model Factory Configuration +model_factory: + factory: bypass_factory_fn + block_loss_func: normalized_mse_loss + gqa_init_mode: AverageKV + mlp_init_mode: Truncate # FFN is frozen; this knob is dormant for KV-only tasks + mlp_init_config: + activations_log_dir: + linear_init_mode: FromTeacher + submodule_for_loss_calculation: + keys_to_learn: subblock_attention # train ONLY the attention sub-block + +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false +best_val_loss: 1e+9 + +# Performance Optimization +compile: false +disable_fa2: false +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false +disable_checkpoint_save: false +save_best_ckpt: true +kill_after_first_save: false +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Single architectural variant. `set_experiment_id` produces a readable, fingerprinted ID. +# Add more entries here to train multiple variants in one bypass run. +configs: [] diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml new file mode 100644 index 00000000000..917623c028e --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml @@ -0,0 +1,4 @@ +defaults: + - nemotron-3-nano-30b-a3b + - override bypass: defaults + - _self_ diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml new file mode 100644 index 00000000000..97d3ca60ffa --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml @@ -0,0 +1,30 @@ +defaults: + - NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 + - _self_ + +# Input Hugging Face model to compress. +# Auto-downloads from HuggingFace if the path is not a local directory. +input_hf_model_path: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for puzzletron outputs +puzzle_dir: /workspace/puzzle_dir + +# Toy KV-heads-only constraint. +# Teacher has 6 attention layers × num_key_value_heads=2 = 12 KV heads total. +# Target 9 leaves 75% of teacher KV heads — the MIP solver picks per-layer from +# {teacher 2-head, 1-head} so some layers stay full and the rest collapse to 1 +# head. +mip: + human_constraints: + target_num_kv_heads: 9 + +# KV-heads-only toy pruning task. +# teacher num_attention_heads = 32, num_key_value_heads = 2 (n_heads_in_group = 16) +# Bypass-trains a single 1-KV-head variant per attention layer +# (n_heads_in_group = 32). Combined with `add_attention_no_ops: false` in the +# base config, MIP picks per-layer from {teacher 2-head, 1-head}. +pruning: + n_heads_in_group_list: [32] # 32 / 32 = 1 KV head per attention layer diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml new file mode 100644 index 00000000000..df37b7403c0 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml @@ -0,0 +1,24 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +# Score per-KV-head importance and create the pruned-checkpoint variants used +# to build the replacement library. +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHKVHeadsLayerDescriptor + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory + target_layer: "mixer.o_proj" # Nemotron-H attention is under `mixer`, not `self_attn` + layer_input_descriptors_path: + +# Teacher: num_attention_heads = 32, num_key_value_heads = 2 (n_heads_in_group = 16) +# Single 1-KV-head variant: n_heads_in_group = 32 → num_kv_heads = 32 / 32 = 1 +n_heads_in_group_list: [32] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml new file mode 100644 index 00000000000..e05e775bee3 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml new file mode 100644 index 00000000000..ce1749d9698 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml new file mode 100644 index 00000000000..ec139023794 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 8ceed378318..f0ff54e8be7 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -67,7 +67,6 @@ def run_full_puzzletron(hydra_config_path: str): Args: config_path: Path to the YAML configuration file """ - mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") dist.setup(timeout=timedelta(minutes=10)) # Register Hydra custom resolvers (needed for config resolution) @@ -77,12 +76,17 @@ def run_full_puzzletron(hydra_config_path: str): hydra_config_dir = str(hydra_config_path.parent) hydra_config_name = hydra_config_path.stem - # Load hydra config + # Load hydra config to determine total step count (bypass adds one step) hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir( config_dir=hydra_config_dir, config_name=hydra_config_name, overrides=[], ) + start_step, total_steps = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "start") + + mtpz.tools.mprint( + f"Puzzletron Progress {start_step}/{total_steps}: starting puzzletron pipeline" + ) # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) @@ -113,7 +117,10 @@ def run_full_puzzletron(hydra_config_path: str): ) dist.cleanup() - mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + complete_step, _ = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "complete") + mtpz.tools.mprint( + f"Puzzletron Progress {complete_step}/{total_steps}: puzzletron pipeline completed (multi-gpu)" + ) def run_mip_only(hydra_config_path: str): @@ -140,21 +147,28 @@ def run_mip_only(hydra_config_path: str): config_name=hydra_config_name, overrides=[], ) + mip_step, total_steps = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "mip") # Check if sweep mode is enabled if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): mtpz.tools.mprint( - "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + f"Puzzletron Progress {mip_step}/{total_steps}:" + " running MIP sweep for multiple compression rates (multi-gpu)" ) mtpz.mip.run_mip_sweep(hydra_cfg) else: # mip_and_realize_models (distributed processing) # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API - mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mtpz.tools.mprint( + f"Puzzletron Progress {mip_step}/{total_steps}: running MIP and realizing models (multi-gpu)" + ) mtpz.mip.launch_mip_and_realize_model(hydra_cfg) dist.cleanup() - mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + complete_step, _ = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "complete") + mtpz.tools.mprint( + f"Puzzletron Progress {complete_step}/{total_steps}: puzzletron pipeline completed (multi-gpu)" + ) def main(): diff --git a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 761534f6df9..bade8cfb15b 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -81,6 +81,7 @@ class Type(enum.Enum): "target_throughput", "target_latency", "target_time_to_first_token", + "target_num_kv_heads", "num_params", "stats.has_attention", } @@ -167,6 +168,10 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: if "target_memory" in self.constraints: mip_constraints["stats.memory_mib"] = self.constraints["target_memory"] + # Total KV-heads constraint (sum across attention layers; used for KV-cache-only sweeps) + if "target_num_kv_heads" in self.constraints: + mip_constraints["stats.num_kv_heads"] = self.constraints["target_num_kv_heads"] + # Throughput constraints throughput_constraints = [] if "target_throughput" in self.constraints: diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py index 253674f97af..abe94b15b02 100644 --- a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py @@ -37,10 +37,12 @@ ) from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict +from . import bypass_distillation from .activation_scoring import launch_score_activations from .anymodel.converter import ConverterFactory from .anymodel.model_descriptor import ModelDescriptorFactory from .build_library_and_stats import launch_build_library_and_stats +from .bypass_distillation.bypass_utils import expected_bypass_runs, load_bypass_state from .mip import launch_mip_and_realize_model from .pruning import launch_prune_ckpt from .scoring import launch_scoring @@ -100,10 +102,69 @@ class PuzzletronConfig(ModeloptBaseConfig): ) +_StageName = str + +# Canonical stage order. Stages absent from a given run (e.g. "bypass" when +# bypass isn't configured) are skipped, but the rest keep their relative order. +_STAGE_ORDER: tuple[_StageName, ...] = ( + "start", + "convert", + "score_activations", + "prune", + "bypass", + "build_library", + "score_blocks", + "mip", + "complete", +) + + +def _total_steps(hydra_cfg) -> int: + """Return total pipeline step count: 9 with bypass, 8 without.""" + return 9 if hydra_cfg.get("bypass", None) is not None else 8 + + +def _progress_step(hydra_cfg, stage: _StageName) -> tuple[int, int]: + """Return ``(step_number, total_steps)`` for a given pipeline stage. + + Single source of truth for the user-facing ``Puzzletron Progress N/T`` strings — + keeps numbering coherent across ``main.py``, ``convert_puzzletron_model``, and + ``PuzzletronSearcher.run_search``, and shifts MIP/realize automatically when + bypass is added or removed. + """ + has_bypass = hydra_cfg.get("bypass", None) is not None + total = _total_steps(hydra_cfg) + step = 0 + for s in _STAGE_ORDER: + if s == "bypass" and not has_bypass: + continue + step += 1 + if s == stage: + return step, total + raise ValueError(f"Unknown pipeline stage: {stage!r}") + + +def _find_incomplete_bypass_runs(hydra_cfg, puzzle_dir: str | Path) -> list[str]: + expected_runs = expected_bypass_runs(hydra_cfg) + incomplete_runs = [] + for expected_run in expected_runs: + state = load_bypass_state(expected_run["experiment_dir"]) + symlink = Path(puzzle_dir) / "ckpts" / expected_run["experiment_id"] + if ( + state is None + or state.get("status") != "completed" + or state.get("config_fingerprint") != expected_run["config_fingerprint"] + or not symlink.exists() + ): + incomplete_runs.append(expected_run["experiment_id"]) + return incomplete_runs + + def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. - 3. Prune the model and save pruned checkpoints + 3. Prune the model and save pruned checkpoints. + 4. (Optional) Run bypass distillation. The output of this step will be used by mnt.search() to perform the NAS search. """ @@ -125,37 +186,105 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) - if dist.is_master(): - mprint( - "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" - ) - hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + has_bypass = hydra_cfg.get("bypass", None) is not None + convert_step, N = _progress_step(hydra_cfg, "convert") + score_step, _ = _progress_step(hydra_cfg, "score_activations") + prune_step, _ = _progress_step(hydra_cfg, "prune") - # Get descriptor and converter from the hydra config - descriptor_name = hydra_cfg.descriptor - descriptor = ModelDescriptorFactory.get(descriptor_name) - converter = ConverterFactory.get(descriptor_name) + # Step 2: Convert HuggingFace model to Puzzletron heterogeneous format + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + teacher_dir = Path(config.puzzle_dir) / hf_ckpt_teacher_dir + if dist.is_master(): + if (teacher_dir / "config.json").exists(): + mprint( + f"Puzzletron Progress {convert_step}/{N}: teacher checkpoint already exists, skipping conversion" + ) + else: + mprint( + f"Puzzletron Progress {convert_step}/{N}: converting model to Puzzletron heterogeneous format (single-gpu)" + ) - converter.convert( - descriptor=descriptor, - input_dir=Path(config.input_model_path), - output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, - ) + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + # Auto-download from HuggingFace if path doesn't exist locally + input_model_path = config.input_model_path + if not Path(input_model_path).exists(): + from huggingface_hub import snapshot_download + + if input_model_path.startswith("https://huggingface.co/"): + model_id = "/".join(input_model_path.rstrip("/").split("/")[-2:]) + else: + model_id = input_model_path # assume HF model ID like "org/model-name" + mprint( + f"Downloading HuggingFace model '{model_id}' — this may take several minutes " + f"for large models. Other ranks are waiting at a barrier." + ) + input_model_path = snapshot_download(repo_id=model_id) + mprint(f"Downloaded to: {input_model_path}") + + converter.convert( + descriptor=descriptor, + input_dir=Path(input_model_path), + output_dir=teacher_dir, + ) dist.barrier() - # Score_pruning_activations (distributed processing) - mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)") - launch_score_activations(hydra_cfg) - - # Prune the model and save pruned checkpoints - if dist.is_master(): + # Step 3: Score pruning activations (distributed processing) + activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir) + if activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth")): mprint( - "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + f"Puzzletron Progress {score_step}/{N}: pruning activation scores already " + f"exist at {activations_log_dir} — delete this directory to re-score with " + f"the current config." ) - launch_prune_ckpt(hydra_cfg) + dist.barrier() + else: + mprint(f"Puzzletron Progress {score_step}/{N}: scoring pruning activations (multi-gpu)") + launch_score_activations(hydra_cfg) + + # Step 4: Prune the model and save pruned checkpoints (single process) + pruned_ckpts_dir = Path(hydra_cfg.pruning.pruned_ckpts_output_dir) + if dist.is_master(): + if pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir()): + mprint( + f"Puzzletron Progress {prune_step}/{N}: pruned checkpoints already " + f"exist at {pruned_ckpts_dir} — delete this directory to re-prune with " + f"the current config." + ) + else: + mprint( + f"Puzzletron Progress {prune_step}/{N}: pruning the model and saving pruned checkpoints (single-gpu)" + ) + launch_prune_ckpt(hydra_cfg) dist.barrier() + # Step 5: Bypass distillation (optional, distributed processing) + if has_bypass: + bypass_step, _ = _progress_step(hydra_cfg, "bypass") + # Skip only when every expected bypass run has a matching manifest, a + # completed status, the same config fingerprint, and a realized ckpts/ + # symlink. Counting arbitrary `_DONE` files is not config-specific and + # can skip the current sweep because of stale unrelated runs. + incomplete_runs = ( + _find_incomplete_bypass_runs(hydra_cfg, config.puzzle_dir) if dist.is_master() else None + ) + incomplete_runs = dist.broadcast(incomplete_runs, src=0) + bypass_done = len(incomplete_runs) == 0 + if bypass_done: + mprint( + f"Puzzletron Progress {bypass_step}/{N}: bypass distillation already completed, skipping" + ) + else: + mprint( + f"Puzzletron Progress {bypass_step}/{N}: running bypass distillation " + f"(multi-gpu); incomplete runs: {incomplete_runs}" + ) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + return model, {} @@ -226,18 +355,71 @@ def run_search(self) -> None: # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Build_library_and_stats (single process) + library_step, N = _progress_step(hydra_cfg, "build_library") + scoring_step, _ = _progress_step(hydra_cfg, "score_blocks") + mip_step, _ = _progress_step(hydra_cfg, "mip") + + # Build replacement library and subblock statistics (single process) + puzzle_dir = Path(self.model.puzzle_dir) + replacement_library_path = puzzle_dir / "replacement_library.json" + subblock_stats_path = puzzle_dir / hydra_cfg.calc_subblock_stats.subblock_stats_filename + # Detect a stale library: any ckpts/* entry whose finalisation marker + # is newer than the library file means a new replacement (e.g. bypass- + # trained subblocks) appeared after the last build and must be picked + # up. Without this check, our skip-if-done would happily reuse a + # no-bypass library even after bypass completes. + # + # We probe ``config.json`` rather than the directory mtime because: + # 1. directory mtime tracks "an entry was added/removed", not "a file + # inside was modified" — adding new shards to an existing checkpoint + # dir wouldn't bump the dir mtime; + # 2. ``entry.resolve()`` on a dangling symlink raises (or returns a + # non-existent path), which the previous code's ``resolved.exists()`` + # silently treated as "not stale"; + # 3. ``config.json`` is written last when a checkpoint is finalised — + # its mtime is the real "checkpoint ready" timestamp. + # The check is conservative: false positives just trigger a rebuild, + # which is safe. + ckpts_dir = puzzle_dir / "ckpts" + library_is_stale = False + if replacement_library_path.exists() and ckpts_dir.exists(): + library_mtime = replacement_library_path.stat().st_mtime + for entry in ckpts_dir.iterdir(): + # `Path.stat()` follows symlinks by default, so this works + # whether `entry` is a real dir or a symlink to one (the + # bypass and pruning pipelines both land here as symlinks). + # `try` guards against dangling symlinks (FileNotFoundError). + config_path = entry / "config.json" + try: + config_mtime = config_path.stat().st_mtime + except (FileNotFoundError, OSError): + continue + if config_mtime > library_mtime: + library_is_stale = True + mprint( + f"Replacement library is stale: '{entry.name}/config.json' is newer than the existing library, will rebuild." + ) + break if dist.is_master(): - mprint( - "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)" - ) - launch_build_library_and_stats(hydra_cfg) + if ( + replacement_library_path.exists() + and subblock_stats_path.exists() + and not library_is_stale + ): + mprint( + f"Puzzletron Progress {library_step}/{N}: replacement library and subblock stats already exist, skipping" + ) + else: + mprint( + f"Puzzletron Progress {library_step}/{N}: building replacement library and subblock statistics (single-gpu)" + ) + launch_build_library_and_stats(hydra_cfg) dist.barrier() - # Calc_one_block_scores (distributed processing) - mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)") + # Calculate one block scores (distributed processing) + mprint(f"Puzzletron Progress {scoring_step}/{N}: calculating one block scores (multi-gpu)") launch_scoring(hydra_cfg) - # mip_and_realize_models (distributed processing) - mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + # MIP search and realize models (distributed processing) + mprint(f"Puzzletron Progress {mip_step}/{N}: running MIP and realizing models (multi-gpu)") launch_mip_and_realize_model(hydra_cfg) diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index b4edbdd385c..57a1de039b5 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -39,7 +39,7 @@ import pandas as pd from omegaconf import DictConfig -from modelopt.torch.utils import json_dump +from modelopt.torch.utils import json_dump, json_load from ..anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory from ..block_config import AttentionConfig, BlockConfig, FFNConfig @@ -212,7 +212,21 @@ def _build_subblocks_df( checkpoint_dirs = _get_last_checkpoint_from_each_experiment( master_puzzle_dir, trust_remote_code=trust_remote_code ) - checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) + + # Order the non-teacher checkpoints so that downstream `drop_duplicates(keep="first")` + # deterministically prefers bypass-trained subblocks over Truncate-init pruned ones + # when both produce a row with the same architectural identifier. Without this, + # `set` iteration order makes the choice random (hash-of-path) and we'd sometimes + # discard the BLD-trained weights we just paid 30+ min to compute. + # + # Priority (lowest sort key wins): 0 = bypass-trained, 1 = everything else. + # Bypass checkpoints land under `/bypass/bypass_runs//`. + def _checkpoint_priority(p: Path) -> tuple[int, str]: + is_bypass = "bypass" in p.parts and "bypass_runs" in p.parts + return (0 if is_bypass else 1, str(p)) + + non_teacher_dirs = sorted(checkpoint_dirs - {teacher_checkpoint_dir}, key=_checkpoint_priority) + checkpoint_dirs = [teacher_checkpoint_dir] + non_teacher_dirs checkpoints_to_split = [teacher_checkpoint_dir] subblock_rows = [] @@ -455,11 +469,21 @@ def _infer_subblocks_to_extract( if (checkpoint_dir / "replacement_library.json").exists(): return [] bypass_config_path = checkpoint_dir / "bypass_config.json" - if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()): + bypass_args_path = checkpoint_dir / "args.json" + if (checkpoint_dir in checkpoints_to_split) or ( + not bypass_config_path.exists() and not bypass_args_path.exists() + ): subblocks_to_extract = ["block", "attention", "ffn"] else: - bypass_config = json.loads(bypass_config_path.read_text()) - keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + if bypass_args_path.exists(): + bypass_config = json_load(bypass_args_path) + keys_to_learn = bypass_config.get("model_factory", {}).get( + "keys_to_learn", "entire_block" + ) + else: + bypass_config = json.loads(bypass_config_path.read_text()) + keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + subblocks_to_extract = learned_subblocks_from_keys_to_learn(keys_to_learn) return subblocks_to_extract diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py new file mode 100644 index 00000000000..59881a242b1 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -0,0 +1,1112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU integration tests for bypass distillation (blockwise local distillation). + +Each test is parametrized over the same model families covered by ``test_puzzletron.py`` +(see ``PUZZLETRON_FAMILIES`` in ``tests/_test_utils/torch/puzzletron/utils.py``). + +Tiny model dimensions used throughout (set by ``setup_test_model_and_data``): + - hidden_size: 256, intermediate_size: 512, num_layers: max(2, world_size) + - num_attention_heads: 32, num_key_value_heads: 8 + - num_local_experts: 16 (MoE families only, e.g. Qwen3-VL) + - training_tokens: 128, block_size: 64, micro_batch_size: 1 -> max_steps = 2 + +Pruning targets (used by all four tests): + - pruned intermediate_size: 256 (dense) — half of teacher + - pruned num_local_experts: 8 (MoE) — half of teacher + - pruned num_key_value_heads: 4 — half of teacher + +mlp_init_mode is family-aware: + - Dense families use ``Truncate`` (FFN intermediate slicing in the generic path). + - MoE families use ``ExpertRemoval`` and delegate per-expert weight slicing to the + ``experts_removal`` mixin registered on the descriptor. ``mlp_init_config`` is + sourced from the family's pruning YAML (``mlp_init_config_yaml``) — no + per-family branching needed in this test file. + +To add a new model family: + 1. Append one row to PUZZLETRON_FAMILIES in tests/_test_utils/torch/puzzletron/utils.py. + 2. Ensure tests/gpu/torch/puzzletron/resources/configs//.yaml exists + and that setup_test_model_and_data() can build a tiny stand-in for it. + 3. For MoE families, ensure the family's descriptor registers ``"kv_heads"`` and + ``"experts_removal"`` in ``pruning_mixins()`` (see e.g. NemotronH, GPT-OSS, + Qwen3-VL descriptors). + 4. The four bypass tests below pick up the new row automatically. +""" + +import copy +import json +from datetime import timedelta +from functools import partial +from pathlib import Path + +import hydra +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import PUZZLETRON_FAMILIES, setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.puzzletron.replacement_library.build_replacement_library as build_lib +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.bypass_distillation.bypass_checkpoint_utils import ( + find_latest_run_dir, +) +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# --------------------------------------------------------------------------- +# Constants — shared tiny-model dimensions and pruning targets +# --------------------------------------------------------------------------- + +SEED = 1234 + +# Teacher tiny-model dimensions (set uniformly by setup_test_model_and_data) +TEACHER_INTERMEDIATE_SIZE = 512 +TEACHER_NUM_KV_HEADS = 8 +TEACHER_NUM_LOCAL_EXPERTS = 16 + +# Pruned targets (half of teacher) +PRUNED_INTERMEDIATE_SIZE = 256 +PRUNED_NUM_KV_HEADS = 4 +PRUNED_NUM_LOCAL_EXPERTS = 8 + +# Training budget: 128 tokens / (64 block * 1 mbs) = 2 steps — completes fast +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _block_override(has_moe_layers: bool, pruned: bool = True) -> dict: + """Return a single FFN-block override entry, family-aware. + + When ``pruned=True`` the override compresses the block (halves intermediate size for + dense or halves num_local_experts for MoE). When ``pruned=False`` it pins the block + to teacher size — used by tests that exercise attention pruning while keeping the FFN + side fixed. + """ + if has_moe_layers: + n_experts = PRUNED_NUM_LOCAL_EXPERTS if pruned else TEACHER_NUM_LOCAL_EXPERTS + return {"moe": {"num_local_experts": n_experts}, "no_op": None} + intermediate = PRUNED_INTERMEDIATE_SIZE if pruned else TEACHER_INTERMEDIATE_SIZE + return {"intermediate_size": intermediate, "no_op": None} + + +def _mlp_init_settings(has_moe_layers: bool, hydra_cfg) -> tuple[str, dict]: + """Return ``(mlp_init_mode, mlp_init_config)`` for the family. + + Dense families use ``Truncate`` (FFN intermediate slicing). MoE families use + ``ExpertRemoval``, which delegates per-expert weight slicing to the + ``experts_removal`` mixin registered on the descriptor. The expert-scores + metadata (``expert_scores_key``, ``layer_prefix_template``) is read directly + from the family's pruning YAML — no per-family branching here. + """ + if not has_moe_layers: + return "Truncate", {"activations_log_dir": None} + + mlp_init_config = ( + OmegaConf.to_container( + hydra_cfg.pruning.get("mlp_init_config_yaml", OmegaConf.create({})), + resolve=True, + ) + or {} + ) + mlp_init_config["activations_log_dir"] = str(hydra_cfg.pruning.activations_log_dir) + return "ExpertRemoval", mlp_init_config + + +def _make_bypass_cfg_dict( + has_moe_layers: bool, + hydra_cfg, + *, + include_block_override: bool = True, + block_pruned: bool = True, + include_attention_override: bool = True, + attention_pruned: bool = True, + configs_list: list | None = None, +) -> dict: + """Return a plain-dict bypass config suitable for OmegaConf.update injection. + + Args: + has_moe_layers: Whether the model family is MoE (dispatches FFN override shape + and the mlp_init_mode). + hydra_cfg: The post-pruning hydra config — used to source the family's + ``mlp_init_config_yaml`` and ``activations_log_dir`` for MoE expert removal. + include_block_override / block_pruned: Whether to override the per-block FFN + sub-component, and whether to prune (vs. pin to teacher). + include_attention_override / attention_pruned: Same for the attention sub-component. + configs_list: If provided, populates bypass.configs for a multi-config sweep. + """ + overrides: dict = {} + if include_block_override: + overrides["ffn"] = [_block_override(has_moe_layers, pruned=block_pruned)] + if include_attention_override: + kv = PRUNED_NUM_KV_HEADS if attention_pruned else TEACHER_NUM_KV_HEADS + overrides["attention"] = [{"num_key_value_heads": kv, "no_op": None}] + + mlp_init_mode, mlp_init_config = _mlp_init_settings(has_moe_layers, hydra_cfg) + + cfg = { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + # The dummy test dataset stores conversations under the "conversation" column. + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + # Large eval_interval so validation is skipped during this short run. + # Validation is fully disabled anyway (disable_validation=True below). + "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": False, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + # Effectively disable step-interval saving; rely on save_checkpoint_when_done. + "save_interval": 1_000_000_000, + "save_checkpoint_when_done": True, + }, + "model_config_overrides": overrides, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": mlp_init_mode, + "mlp_init_config": mlp_init_config, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": "entire_block", + }, + # Disable all validation to keep tests fast. + "disable_initial_validate": True, + "validate_teacher_model": False, + "validate_student_model": False, + "disable_validation": True, + "best_val_loss": 1e9, + "compile": False, + "disable_fa2": False, + "teacher_model_load_on_cpu": False, + "save_checkpoint_before_training": False, + "disable_checkpoint_save": False, + "save_best_ckpt": True, + # Do NOT use kill_after_first_save — it raises RuntimeError which becomes sys.exit(1). + # Instead let the short training run (2 steps) complete naturally. + "kill_after_first_save": False, + "realize_best_or_latest": "best", + "wandb_log": False, + "wandb": {"project": None, "entity": None}, + } + + if configs_list is not None: + cfg["configs"] = configs_list + + return cfg + + +def _expected_experiment_id(bypass_cfg_dict: dict) -> str: + """Compute the experiment_id that ``set_experiment_id`` will assign. + + Avoids duplicating the formula in tests — uses the same function the runtime uses. + """ + cfg = OmegaConf.create({"bypass": copy.deepcopy(bypass_cfg_dict)}) + set_experiment_id(cfg) + return cfg.bypass.experiment_id + + +def _setup_hydra_cfg_and_pruning( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, +) -> tuple: + """Set up the tiny model, convert it, score activations, and create pruning ckpts. + + Returns ``(puzzle_dir, dataset_path, hydra_cfg)``. + + Steps performed: + 1. Create a small HF model and dummy dataset via ``setup_test_model_and_data``. + 2. Convert the HF checkpoint to AnyModel/DeciLM format (rank 0 only). + 3. Load the per-family Hydra config with ``puzzle_dir`` and ``dataset_path`` overrides. + 4. Run ``score_pruning_activations`` (distributed). + 5. Run ``pruning_ckpts`` (rank 0 only) then barrier. + """ + set_seed(SEED) + dist.setup(timeout=timedelta(minutes=10)) + + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, hf_model_name, hybrid_override_pattern + ) + + hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") + # Per-family hydra config name follows the layout configs///. + hydra_config_name = f"{hf_model_name}/{Path(hf_model_name).name}" + + # Step 0: Convert HF checkpoint to AnyModel/DeciLM format. + if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=converter, + ) + dist.barrier() + + # Step 1: Load Hydra config. + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Step 2: Score pruning activations (distributed). + score_pruning_activations.launch_score_activations(hydra_cfg) + + # Step 3: Create pruning checkpoints (rank 0 only). + if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + return puzzle_dir, dataset_path, hydra_cfg + + +# --------------------------------------------------------------------------- +# Tests — each parametrized over PUZZLETRON_FAMILIES +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_block_pruning( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation with the per-block sub-component pruned. + + For dense families, prunes FFN intermediate (512 -> 256). For MoE families, + prunes num_local_experts (16 -> 8). KV heads are also halved (8 -> 4). + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_block_pruning_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_block_pruning_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_block_pruning[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_kv_head_compression( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation with KV heads halved (8 -> 4) and FFN block pinned to teacher. + + For dense, the experiment_id will be ``bypass_ffn_512_heads_4`` (FFN at teacher size, + attention halved). For MoE, ``bypass_experts_16_heads_4``. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_kv_head_compression_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_kv_head_compression_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict( + has_moe_layers, + hydra_cfg, + block_pruned=False, # keep FFN/experts at teacher + attention_pruned=True, # halve KV heads + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_kv_head_compression[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_multi_config_sequential( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation sweep: two configs run sequentially via bypass.configs list. + + Config 0: block pruned + attention pruned + Config 1: block at teacher + attention pruned + Both checkpoint symlinks must exist after the sweep completes. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_multi_config_sequential_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_multi_config_sequential_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + configs_list = [ + { + "model_config_overrides": { + "ffn": [_block_override(has_moe_layers, pruned=True)], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + { + "model_config_overrides": { + "ffn": [_block_override(has_moe_layers, pruned=False)], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + ] + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg, configs_list=configs_list) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + # Compute expected IDs by running set_experiment_id against each sub-config. + expected_ids = [] + for sub in configs_list: + sub_cfg = copy.deepcopy(bypass_cfg_dict) + sub_cfg["model"]["model_config_overrides"] = sub["model_config_overrides"] + sub_cfg["experiment_id"] = None + expected_ids.append(_expected_experiment_id(sub_cfg)) + + for experiment_id in expected_ids: + experiment_dir = puzzle_dir / "bypass/bypass_runs" / experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_multi_config_sequential[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_checkpoint_contents( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Verify that a bypass checkpoint contains expected HuggingFace model files.""" + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_checkpoint_contents_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_checkpoint_contents_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink: {ckpt_symlink}" + ) + + # The symlink resolves to the latest checkpoint dir; verify HF config exists. + resolved = ckpt_symlink.resolve() + config_json = resolved / "config.json" + assert config_json.exists(), ( + f"Expected HuggingFace config.json inside checkpoint: {config_json}" + ) + + # The saving_completed marker must be present (set by save_bypass_checkpoint). + saving_completed = resolved / "saving_completed" + assert saving_completed.exists(), ( + f"Expected saving_completed marker inside checkpoint: {saving_completed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_checkpoint_contents[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# --------------------------------------------------------------------------- +# Tests below this line target a single (or two) family deliberately — they +# exercise paths where parametrizing over all 9 families is overkill or +# requires extras (e.g. NemotronH's mamba-ssm dep). +# --------------------------------------------------------------------------- + +# Llama-3.2-3B is the smallest dense family and the canonical "FFN bypass" path. +LLAMA_FAMILY = pytest.param( + "meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B" +) +# GPT-OSS adds MoE expert pruning (mlp_init_mode="ExpertRemoval") and windowed +# attention with sinks — different code paths than dense Llama. +GPT_OSS_FAMILY = pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b") + + +# --------------------------------------------------------------------------- +# Resume from checkpoint +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY], +) +def test_bypass_resume_from_checkpoint( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Two-phase test: train + save, then re-launch with resume and verify continuity. + + Phase 1: short bypass run (2 steps), checkpoint saved under + ``puzzle_dir/bypass/bypass_runs//step-NNNNNN-ckpt/``. + Phase 2: same hydra_cfg + ``find_last_ckpt_for_resume=True`` + double the + training_tokens budget. The resume path in + ``training_loop.run_bypassed_training:805-840`` must restore + ``iter_num`` / ``step_num`` / ``token_count`` from the saved + ``args.json`` and load stitched-module + optimizer state from disk. + + The GradScaler save/load mechanism added in the recent CodeRabbit-driven + fix is tested separately in + ``tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py`` because + GradScaler is fp16-only and the bypass test infrastructure ships bf16, + which makes ``GradScaler.step()`` raise on the unscale path. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_resume_from_checkpoint_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_resume_from_checkpoint_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + # ---- Phase 1: train + save --------------------------------------------- + phase1_cfg = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + phase1_cfg["find_last_ckpt_for_resume"] = False + OmegaConf.update(hydra_cfg, "bypass", phase1_cfg, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + expected_experiment_id = _expected_experiment_id(phase1_cfg) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + + if rank == 0: + resume_checkpoint = find_latest_run_dir(experiment_dir) + assert resume_checkpoint is not None, f"Phase 1 missing resume checkpoint: {experiment_dir}" + args_json_path = Path(resume_checkpoint) / "args.json" + stitched_dir = Path(resume_checkpoint) / "stitched" + # Phase 1 must have produced the canonical artifacts. + assert args_json_path.exists(), f"Phase 1 missing args.json: {args_json_path}" + with open(args_json_path) as f: + phase1_state = json.load(f) + phase1_iter_num = phase1_state["iter_num"] + assert phase1_iter_num > 1, ( + f"Phase 1 should have advanced past iter 1, got {phase1_iter_num}" + ) + + # Optimizer state must be present (covers the resume path's load). + assert (stitched_dir / "block_0.optimizer_state.pth").exists(), stitched_dir + + dist.barrier() + + # ---- Phase 2: resume and continue -------------------------------------- + phase2_cfg = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + phase2_cfg["find_last_ckpt_for_resume"] = True + # Double the budget so the resumed run takes additional steps. + phase2_cfg["training"]["training_tokens"] = TRAINING_TOKENS * 2 + OmegaConf.update(hydra_cfg, "bypass", phase2_cfg, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + phase2_resume_checkpoint = find_latest_run_dir(experiment_dir) + assert phase2_resume_checkpoint is not None, f"Phase 2 missing checkpoint: {experiment_dir}" + phase2_args_json_path = Path(phase2_resume_checkpoint) / "args.json" + assert phase2_args_json_path.exists(), "Phase 2 should have args.json" + with open(phase2_args_json_path) as f: + phase2_state = json.load(f) + phase2_iter_num = phase2_state["iter_num"] + # The resumed run must have moved past phase 1's last iter — proves + # both that resume happened (didn't restart at 1) and that further + # training executed. + assert phase2_iter_num > phase1_iter_num, ( + f"Resume did not advance: phase1={phase1_iter_num}, phase2={phase2_iter_num}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_resume_from_checkpoint[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# --------------------------------------------------------------------------- +# Per-subblock training modes (Llama dense + GPT-OSS MoE/windowed-attn-sinks) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("keys_to_learn", ["subblock_ffn", "subblock_attention", "entire_block"]) +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY, GPT_OSS_FAMILY], +) +def test_bypass_subblock_modes( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + keys_to_learn: str, +): + """Verify that ``keys_to_learn`` correctly freezes the right param groups. + + For each (family, keys_to_learn) cell: + - Run bypass for 2 steps with that keys_to_learn. + - After training, load the saved stitched_module state dict. + - Compare against the teacher-derived initialization (``copied_dir`` of + the bypass experiment, which holds the post-init pre-train weights): + * subblock_ffn → only FFN keys differ from init; attention identical. + * subblock_attention → only attention keys differ; FFN identical. + * entire_block → both differ. + + GPT-OSS coverage matters because the MoE expert path uses + ``mlp_init_mode="ExpertRemoval"`` instead of ``"Truncate"``, and GPT-OSS's + windowed attention adds attention-sink parameters that the freeze must + correctly include in the "attention" group. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_subblock_modes_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + keys_to_learn, + ), + backend="nccl", + ) + + +def _test_bypass_subblock_modes_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + keys_to_learn: str, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + bypass_cfg_dict["model_factory"]["keys_to_learn"] = keys_to_learn + # Save start-of-training checkpoint so we can diff trained-vs-init. + bypass_cfg_dict["save_checkpoint_before_training"] = True + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + # `start-step-*` is the pre-training snapshot (saved when + # save_checkpoint_before_training=True). The post-training snapshot + # under this short-budget config lives at `final-step-*` (saved by the + # early-exit branch in training_loop.py); the periodic `step-*` save + # never fires because the budget is only 2 steps. `latest` is now a + # resume pointer only, so use the final checkpoint directly. + start_dirs = sorted(experiment_dir.glob("start-step-*-ckpt")) + assert start_dirs, f"Expected a start-step-* checkpoint under {experiment_dir}" + start_dir = start_dirs[0] + final_dirs = sorted(experiment_dir.glob("final-step-*-ckpt")) + assert final_dirs, f"Expected a final-step-* checkpoint under {experiment_dir}" + end_dir = final_dirs[-1].resolve() + assert end_dir != start_dir.resolve(), ( + f"Final checkpoint still points at the pre-training snapshot {end_dir} - " + "no post-training checkpoint was written." + ) + + # Diff every saved stitched module's state dict between start (pre-train) + # and end (post-train). Block names look like ``block_0``, ``block_1``… + ffn_token_set = {".mlp.", ".experts."} # Llama vs GPT-OSS naming + attn_token = ".self_attn." + + def _key_kind(key: str) -> str: + if attn_token in key: + return "attn" + if any(t in key for t in ffn_token_set): + return "ffn" + return "other" + + ffn_changed = False + attn_changed = False + for state_dict_path in (start_dir / "stitched").glob("block_*.state_dict.pth"): + end_path = end_dir / "stitched" / state_dict_path.name + if not end_path.exists(): + continue + start_state = torch.load(state_dict_path, map_location="cpu", weights_only=True) + end_state = torch.load(end_path, map_location="cpu", weights_only=True) + for key in start_state.keys() & end_state.keys(): + kind = _key_kind(key) + if kind == "other": + continue + changed = not torch.equal(start_state[key], end_state[key]) + if kind == "ffn" and changed: + ffn_changed = True + if kind == "attn" and changed: + attn_changed = True + + if keys_to_learn == "subblock_ffn": + assert ffn_changed, f"subblock_ffn should change FFN weights ({hf_model_name})" + assert not attn_changed, ( + f"subblock_ffn should leave attention weights bit-identical ({hf_model_name})" + ) + elif keys_to_learn == "subblock_attention": + assert attn_changed, ( + f"subblock_attention should change attention weights ({hf_model_name})" + ) + assert not ffn_changed, ( + f"subblock_attention should leave FFN weights bit-identical ({hf_model_name})" + ) + else: # entire_block + assert ffn_changed and attn_changed, ( + f"entire_block should change both groups ({hf_model_name}); " + f"got ffn={ffn_changed}, attn={attn_changed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_subblock_modes" + f"[{hf_model_name}, keys_to_learn={keys_to_learn}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# --------------------------------------------------------------------------- +# End-to-end: bypass then build replacement library +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY], +) +def test_bypass_then_build_library( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Run bypass, then build the replacement library; assert bypass entries appear. + + Verifies the wiring between the bypass step and the downstream NAS step: + - ``realize_bypass_checkpoints`` creates a symlink at ``ckpts/``. + - ``_get_last_checkpoint_from_each_experiment`` resolves it back to the + bypass run dir. + - ``_build_subblocks_df``'s priority sort puts the bypass-rooted path + before non-bypass ones in the resulting DataFrame. + - The final ``replacement_library.json`` includes entries pointing at + the bypass experiment. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_then_build_library_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_then_build_library_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + ckpts_dir = puzzle_dir / "ckpts" + + # 1. The realize step must have created a symlink for this bypass run. + bypass_symlink = ckpts_dir / expected_experiment_id + assert bypass_symlink.is_symlink() or bypass_symlink.exists(), ( + f"Expected bypass symlink at {bypass_symlink}" + ) + + # 2. Discovery must find the bypass entry alongside the teacher (and any + # pruning-pipeline outputs from the setup helper). + discovered = build_lib._get_last_checkpoint_from_each_experiment(puzzle_dir) + bypass_resolved = bypass_symlink.resolve() + assert bypass_resolved in discovered, ( + f"Bypass run not discovered. Resolved={bypass_resolved}, discovered={discovered}" + ) + # The resolved bypass path must contain "bypass" + "bypass_runs" in its + # parts so the priority sort picks it up. + assert "bypass" in bypass_resolved.parts and "bypass_runs" in bypass_resolved.parts + + # 3. Build the replacement library and verify the bypass entry appears. + teacher_dir = ckpts_dir / "teacher" + subblocks_df = build_lib._build_subblocks_df( + master_puzzle_dir=puzzle_dir, + teacher_checkpoint_dir=teacher_dir, + add_ffn_no_ops=False, + add_attention_no_ops=False, + trust_remote_code=False, + ) + # Some subblock row's checkpoint_dir column must reference the bypass path. + # FFN-only rows leave attention_checkpoint_dir as NaN (and vice versa); we + # drop those before string-casting because pandas' .astype(str) doesn't + # reliably stringify NaN on object-dtype columns, and 'X' in float('nan') + # raises TypeError. + bypass_str = str(bypass_resolved) + attn_sources = subblocks_df["attention_checkpoint_dir"].dropna().astype(str).tolist() + ffn_sources = subblocks_df["ffn_checkpoint_dir"].dropna().astype(str).tolist() + assert any(bypass_str in s for s in attn_sources + ffn_sources), ( + f"replacement_library subblocks_df has no bypass-sourced rows. " + f"attn_sources={set(attn_sources)}, ffn_sources={set(ffn_sources)}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_then_build_library[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) diff --git a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py new file mode 100644 index 00000000000..dc0df1b4f6b --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Load round-trip tests for ``bypass_checkpoint_utils``. + +These pin that ``load_local_state`` correctly restores the optimizer and +grad-scaler state from disk into a fresh descriptor — the resume path's +main job after the recent dedupe (weights are now loaded from the HF +checkpoint via ``load_and_shard_model``, not from ``stitched/*.pth``). + +Lives under ``tests/gpu/`` because the production ``load_local_state`` +builds ``torch.device(f"cuda:{rank}")`` for ``map_location``, so a real CUDA +device is required to round-trip ``torch.load`` without monkeypatching the +device machinery. The full bypass GPU integration test cannot cover this +path because the test infrastructure ships bf16 and ``GradScaler.step()`` +is fp16-only (raises ``NotImplementedError: +_amp_foreach_non_finite_check_and_unscale_cuda not implemented for 'BFloat16'``). +These tests sidestep that by hitting the load functions directly, without +ever invoking ``.step()``. + +The corresponding save tests live in tests/unit/torch/puzzletron/ +test_bypass_checkpoint_utils.py — ``_save_local_state`` no longer touches +CUDA, so it doesn't need a GPU lane. +""" + +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.amp.grad_scaler import GradScaler + +from modelopt.torch.puzzletron.bypass_distillation import bypass_checkpoint_utils as bcu +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + StitchedModuleDescriptor, +) + +# --------------------------------------------------------------------------- +# Fixture: silence the dist helpers so the save/load functions run on a +# single GPU process without `torchrun` / NCCL setup. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcu_no_dist(monkeypatch): + """Mock the dist helpers so ``bypass_checkpoint_utils`` runs without distributed init.""" + monkeypatch.setattr(bcu.dist, "local_rank", lambda: 0) + monkeypatch.setattr(bcu.dist, "is_master", lambda: True) + monkeypatch.setattr(bcu.dist, "barrier", lambda: None) + return bcu + + +def _make_descriptor( + *, + with_optimizer: bool = True, + with_scaler: bool = True, + grad_scaler_init_scale: float = 2.0**16, +): + """Build a minimal StitchedModuleDescriptor on CPU. + + ``stitched_module`` is a real ``nn.Linear`` so ``state_dict()`` / + ``load_state_dict()`` work without needing the actual ``StitchedModule`` + machinery (which depends on the sewing-kit graph, distributed init, etc.). + + The GradScaler is created with ``enabled=True`` so that ``state_dict()`` + actually contains content (a disabled scaler returns ``{}``, making + round-trip tests vacuous). We never call ``.scale()`` / ``.step()`` so + none of the fp16-only kernels run — only the bookkeeping fields + (``scale``, ``growth_factor``, ``backoff_factor``, ``growth_interval``, + ``_growth_tracker``) go through save/load. + """ + module = nn.Linear(4, 4, bias=False) + owned_parameters = dict(module.named_parameters()) + owned_buffers: dict[str, torch.Tensor] = {} + optimizer = torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None + scaler = ( + GradScaler(device="cpu", enabled=True, init_scale=grad_scaler_init_scale) + if with_scaler + else None + ) + return StitchedModuleDescriptor( + stitched_module=module, + owned_parameters=owned_parameters, + owned_buffers=owned_buffers, + optimizer=optimizer, + grad_scaler=scaler, + ) + + +# --------------------------------------------------------------------------- +# Load: state survives the round-trip and lands back on the live scaler +# --------------------------------------------------------------------------- + + +def test_load_local_state_restores_grad_scaler_state(tmp_path: Path, bcu_no_dist): + """Round-trip: scaler with non-default init_scale → save → load into fresh scaler → state matches. + + This is the regression test for the CodeRabbit-flagged bug: prior to the + fix, ``load_local_state`` skipped the scaler entirely, so a resumed run + would silently start with a default scale (typically 65536.0) regardless + of where the previous run had grown the scale to. + + We compare via ``state_dict()`` rather than poking at private attributes + because the canonical save/load contract is ``state_dict()`` <-> + ``load_state_dict()``; ``state_dict()['scale']`` is the field a real + bypass run would have grown over time. + """ + bcu = bcu_no_dist + + # 1. Save phase: scaler with a non-default init scale. + save_descriptor = _make_descriptor(grad_scaler_init_scale=12345.0) + saved_state = save_descriptor.grad_scaler.state_dict() + assert saved_state["scale"] == 12345.0 # sanity: state actually carries the value + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + + # 2. Load phase: a fresh descriptor with a different init scale; the load + # must overwrite it with the saved value. + load_descriptor = _make_descriptor(grad_scaler_init_scale=999.0) + pre_load_state = load_descriptor.grad_scaler.state_dict() + assert pre_load_state != saved_state # sanity: starts in a distinct state + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + assert load_descriptor.grad_scaler.state_dict() == saved_state + + +def test_load_local_state_handles_legacy_checkpoint_without_grad_scaler( + tmp_path: Path, bcu_no_dist +): + """Backward compat: a checkpoint saved before the GradScaler-fix must still load. + + Older bypass runs predating the GradScaler save did not write + ``block_0.grad_scaler.pth``. The current ``load_local_state`` must skip + silently in that case rather than raising — our deployed users have + legacy checkpoints they want to resume from. + """ + bcu = bcu_no_dist + + # First save with a scaler so we have a normal "complete" save… + save_descriptor = _make_descriptor() + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + # …then delete the grad_scaler artifact to mimic a legacy checkpoint. + (tmp_path / "stitched" / "block_0.grad_scaler.pth").unlink() + + # Loading must not raise. + load_descriptor = _make_descriptor() + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + +def test_load_local_state_restores_optimizer_state(tmp_path: Path, bcu_no_dist): + """End-to-end optimizer round-trip — covers the resume path's main job.""" + bcu = bcu_no_dist + + save_descriptor = _make_descriptor() + # Take an optimizer step so AdamW has non-default ``state`` (exp_avg etc). + for p in save_descriptor.stitched_module.parameters(): + p.grad = torch.ones_like(p) + save_descriptor.optimizer.step() + saved_state = save_descriptor.optimizer.state_dict() + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + + load_descriptor = _make_descriptor() + # Fresh optimizer's state dict should differ from `saved_state` until load. + assert load_descriptor.optimizer.state_dict() != saved_state + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + # After load, AdamW step counter and exp_avg buffers must match. + # Production runs co-locate model + state on cuda:0, but this fixture has the + # model on CPU so the loaded state ends up split: exp_avg / exp_avg_sq follow + # the param device (CPU), while AdamW's `step` tensor is loaded via + # ``map_location='cuda:0'`` and stays there. Move both to CPU for the + # comparison — we're verifying value equality, not device placement. + loaded_state = load_descriptor.optimizer.state_dict() + assert loaded_state["state"].keys() == saved_state["state"].keys() + for param_id in loaded_state["state"]: + for key, val in saved_state["state"][param_id].items(): + loaded_val = loaded_state["state"][param_id][key] + if torch.is_tensor(val): + assert torch.equal(loaded_val.to("cpu"), val.to("cpu")), ( + f"optimizer.state[{param_id}][{key}] not restored" + ) + else: + assert loaded_val == val diff --git a/tests/gpu/torch/puzzletron/test_bypass_resume.py b/tests/gpu/torch/puzzletron/test_bypass_resume.py new file mode 100644 index 00000000000..4cf90e9b685 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass_resume.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU integration test for the bypass-distillation resume path. + +The existing ``test_bypass.py`` covers the save side: a fresh bypass run +produces a checkpoint and a ``ckpts/`` symlink. What it doesn't cover is +the *resume* side: a re-launched job calling ``find_latest_run_dir`` against +a real experiment directory and loading optimizer / state via ``load_local_state``. + +That contract — between what training writes (``saving_completed`` marker, +``args.json``, ``stitched/*.pth``) and what the resume helpers read — is +exactly the kind of thing that quietly diverges as the save format evolves. +A unit test can pin the regex; only an integration test pins the byte-level +agreement between writer and reader. + +Single dense family (Llama-3.2-3B-Instruct) is enough — the resume code path +is family-agnostic. +""" + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.bypass_distillation.bypass_checkpoint_utils import ( + find_latest_run_dir, +) +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# Match the constants in test_bypass.py so the run completes in two steps. +SEED = 1234 +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 +PRUNED_INTERMEDIATE_SIZE = 256 +PRUNED_NUM_KV_HEADS = 4 + +# One dense family — resume path is family-agnostic, so a second parametrize +# row would only add runtime, not coverage. +HF_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" +CONVERTER = "llama" + + +def _bypass_cfg_dict(*, find_last_ckpt_for_resume: bool) -> dict: + """Minimal bypass config — derived from test_bypass.py's _make_bypass_cfg_dict + for a dense family with FFN+KV pruning.""" + return { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": find_last_ckpt_for_resume, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + "save_interval": 1_000_000_000, + "save_checkpoint_when_done": True, + }, + "model_config_overrides": { + "ffn": [{"intermediate_size": PRUNED_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": "entire_block", + }, + "disable_initial_validate": True, + "validate_teacher_model": False, + "validate_student_model": False, + "disable_validation": True, + "best_val_loss": 1e9, + "compile": False, + "disable_fa2": False, + "teacher_model_load_on_cpu": False, + "save_checkpoint_before_training": False, + "disable_checkpoint_save": False, + "save_best_ckpt": True, + "kill_after_first_save": False, + "realize_best_or_latest": "best", + "wandb_log": False, + "wandb": {"project": None, "entity": None}, + } + + +def _expected_experiment_dir(puzzle_dir: Path, bypass_cfg_dict: dict) -> Path: + """Compute the experiment directory the runtime will choose.""" + cfg = OmegaConf.create({"bypass": dict(bypass_cfg_dict)}) + set_experiment_id(cfg) + return puzzle_dir / "bypass/bypass_runs" / cfg.bypass.experiment_id + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") +def test_bypass_resume_finds_latest_checkpoint(project_root_path: Path, tmp_path: Path): + """Run bypass once, verify ``find_latest_run_dir`` locates the saved + checkpoint, then re-launch with ``find_last_ckpt_for_resume=True`` and + verify the second run resumes from the saved iter_num. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_resume_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _resume_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): + set_seed(SEED) + dist.setup(timeout=timedelta(minutes=10)) + + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, HF_MODEL_NAME, hybrid_override_pattern=None + ) + + hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") + hydra_config_name = f"{HF_MODEL_NAME}/{Path(HF_MODEL_NAME).name}" + + if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=CONVERTER, + ) + dist.barrier() + + import hydra + + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[f"puzzle_dir={puzzle_dir}", f"dataset_path={dataset_path}"], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + score_pruning_activations.launch_score_activations(hydra_cfg) + if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + # First bypass run — produces a real checkpoint. + cfg_dict = _bypass_cfg_dict(find_last_ckpt_for_resume=False) + OmegaConf.update(hydra_cfg, "bypass", cfg_dict, merge=True) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + experiment_dir = _expected_experiment_dir(puzzle_dir, cfg_dict) + if rank == 0: + # The save side wrote what the resume side expects. + assert experiment_dir.exists(), f"Expected experiment dir at {experiment_dir}" + latest = find_latest_run_dir(experiment_dir) + assert latest is not None, f"find_latest_run_dir returned None for {experiment_dir}" + assert (Path(latest) / "saving_completed").exists(), ( + f"Resume target {latest} missing saving_completed marker" + ) + assert (Path(latest) / "args.json").exists(), ( + f"Resume target {latest} missing args.json — load path would crash" + ) + dist.barrier() + + # Second bypass run — re-uses the same experiment_dir, finds the latest + # checkpoint via ``find_last_ckpt_for_resume=True``, and resumes. + # Reset cfg.bypass to a fresh dict (experiment_id back to None so + # set_experiment_id recomputes the same id from model_config_overrides). + cfg_dict_resume = _bypass_cfg_dict(find_last_ckpt_for_resume=True) + cfg_dict_resume["training"]["training_tokens"] = TRAINING_TOKENS * 2 + OmegaConf.update(hydra_cfg, "bypass", cfg_dict_resume, merge=True) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + # After the second run, iter_num must have advanced past 1 — proving + # the run picked up state from the first run rather than starting fresh. + # (The resume code path overwrites iter_num from args.json on line 826.) + assert hydra_cfg.bypass.iter_num > 1, ( + f"Resume failed: iter_num={hydra_cfg.bypass.iter_num} suggests fresh start, " + f"not a resume from the saved checkpoint" + ) + + dist.cleanup() diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index d5ce4289abb..15653c7f6fc 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -21,7 +21,7 @@ import transformers from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.misc import set_seed -from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import PUZZLETRON_FAMILIES, setup_test_model_and_data from packaging.version import Version import modelopt.torch.puzzletron as mtpz @@ -38,17 +38,7 @@ @pytest.mark.parametrize( ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - [ - ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False), - ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False), - ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), - ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), - ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), - ("openai/gpt-oss-20b", "gpt_oss", None, True), - ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), - ("Qwen/Qwen3-8B", "qwen3", None, False), - ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), - ], + PUZZLETRON_FAMILIES, ) def test_puzzletron( project_root_path: Path, diff --git a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py new file mode 100644 index 00000000000..fa38d17ba15 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for replacement-library checkpoint discovery + bypass priority. + +The ``build_replacement_library`` module is responsible for two correctness-critical +behaviors after a bypass run: + +1. ``_get_last_checkpoint_from_each_experiment`` must surface every valid + checkpoint under ``puzzle_dir/ckpts/``, including those that live there only + as symlinks (which is exactly how bypass writes its results). +2. When a bypass-trained subblock and a Truncate-init subblock would produce + the same architectural identifier, the bypass-trained one must be preferred + by the downstream ``drop_duplicates(keep="first")``. This is enforced by a + tuple-sort closure inside ``_build_subblocks_df`` that gives bypass paths + priority 0 and everything else priority 1. + +A regression in either path silently discards bypass-trained weights — exactly +the kind of bug that's invisible in normal CI runs. +""" + +import json +from pathlib import Path + +import pytest + +from modelopt.torch.puzzletron.replacement_library import build_replacement_library as brl + +# --------------------------------------------------------------------------- +# Filesystem fixture: tiny puzzle_dir with three checkpoints +# --------------------------------------------------------------------------- + + +def _write_minimal_config(checkpoint_dir: Path) -> None: + """Write a placeholder config.json so the discovery rglob finds the dir. + + The actual config contents don't matter — these tests monkeypatch + ``is_valid_decilm_checkpoint`` so no real config parsing happens. + """ + checkpoint_dir.mkdir(parents=True, exist_ok=True) + (checkpoint_dir / "config.json").write_text("{}") + + +@pytest.fixture +def puzzle_dir_with_three_ckpts(tmp_path: Path, monkeypatch) -> Path: + """Build a puzzle_dir tree mirroring a real post-bypass post-prune layout. + + Layout:: + + puzzle_dir/ + ckpts/ + teacher/ # real dir + config.json + bypass_ffn_256_heads_4 -> ../bypass/bypass_runs/.../step-000010-ckpt + pruned_intermediate_256 -> ../pruning/pruned_intermediate_256 + bypass/bypass_runs/bypass_ffn_256_heads_4/step-000010-ckpt/ + config.json + pruning/pruned_intermediate_256/ + config.json + + The two non-teacher entries under ``ckpts/`` are symlinks — that is how + ``puzzletron_nas_plugin.realize_bypass_checkpoints`` and the pruning + pipeline actually write them. ``_get_last_checkpoint_from_each_experiment`` + must `.resolve()` these to see the real path under ``bypass/bypass_runs/`` + or ``pruning/`` — that resolution is what the priority sort later keys on. + """ + puzzle_dir = tmp_path / "puzzle_dir" + ckpts = puzzle_dir / "ckpts" + ckpts.mkdir(parents=True) + + # Teacher: real directory directly under ckpts. + _write_minimal_config(ckpts / "teacher") + + # Bypass: real dir under bypass/bypass_runs/, symlinked from ckpts/. + bypass_real = ( + puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn_256_heads_4" / "step-000010-ckpt" + ) + _write_minimal_config(bypass_real) + (ckpts / "bypass_ffn_256_heads_4").symlink_to(bypass_real, target_is_directory=True) + + # Truncate-pruned: real dir under pruning/, symlinked from ckpts/. + pruning_real = puzzle_dir / "pruning" / "pruned_intermediate_256" + _write_minimal_config(pruning_real) + (ckpts / "pruned_intermediate_256").symlink_to(pruning_real, target_is_directory=True) + + # Make every config.json look "valid" without parsing — load_model_config + # would otherwise try to load these as real HF configs. + monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", lambda *a, **kw: True) + + return puzzle_dir + + +# --------------------------------------------------------------------------- +# Discovery +# --------------------------------------------------------------------------- + + +def test_get_last_checkpoint_from_each_experiment_finds_all_three( + puzzle_dir_with_three_ckpts: Path, +): + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + discovered_names = {p.name for p in discovered} + assert discovered_names == {"teacher", "step-000010-ckpt", "pruned_intermediate_256"} + + +def test_get_last_checkpoint_from_each_experiment_resolves_symlinks( + puzzle_dir_with_three_ckpts: Path, +): + """The resolved paths must reflect the real filesystem location. + + This is what makes the bypass-priority sort work — the closure inside + ``_build_subblocks_df`` checks ``"bypass" in p.parts and "bypass_runs" + in p.parts``, which only succeeds on the resolved path. + """ + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + bypass_path = next(p for p in discovered if p.name == "step-000010-ckpt") + assert "bypass" in bypass_path.parts + assert "bypass_runs" in bypass_path.parts + # And the pruning entry must NOT pick up "bypass" anywhere in its parts. + pruning_path = next(p for p in discovered if p.name == "pruned_intermediate_256") + assert "bypass" not in pruning_path.parts + + +def test_get_last_checkpoint_skips_invalid_checkpoints( + puzzle_dir_with_three_ckpts: Path, monkeypatch +): + """Only checkpoints that pass ``is_valid_decilm_checkpoint`` should appear. + + A regression where a malformed config.json silently slips through would + later raise inside ``_construct_subblock_rows_from_current_checkpoint`` + with a much less helpful traceback. + """ + + def _only_teacher_is_valid(checkpoint_dir, trust_remote_code=False): + return Path(checkpoint_dir).name == "teacher" + + monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", _only_teacher_is_valid) + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + assert {p.name for p in discovered} == {"teacher"} + + +# --------------------------------------------------------------------------- +# Bypass-priority sort +# --------------------------------------------------------------------------- + + +def _bypass_priority(p: Path) -> tuple[int, str]: + """Re-implementation of the closure inside ``_build_subblocks_df``. + + Kept identical to ``modelopt/torch/puzzletron/replacement_library/ + build_replacement_library.py:222-225``. If that closure is changed, + update this test mirror; this is intentional duplication so the unit + test stays cheap (no need to build an end-to-end DataFrame just to + verify a 3-line priority function). + """ + is_bypass = "bypass" in p.parts and "bypass_runs" in p.parts + return (0 if is_bypass else 1, str(p)) + + +def test_bypass_priority_orders_bypass_before_pruning(puzzle_dir_with_three_ckpts: Path): + """The same input set the real code receives must sort bypass first.""" + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + teacher = next(p for p in discovered if p.name == "teacher") + non_teacher_sorted = sorted(discovered - {teacher}, key=_bypass_priority) + + # Bypass must come first; pruning must come second. + assert non_teacher_sorted[0].name == "step-000010-ckpt" + assert non_teacher_sorted[1].name == "pruned_intermediate_256" + + +def test_bypass_priority_is_stable_for_two_bypass_checkpoints(tmp_path: Path): + """Multiple bypass checkpoints must sort deterministically by string. + + Without this, ``set`` iteration order changes the picked-first checkpoint + across Python invocations, defeating the whole point of the priority sort. + """ + p1 = tmp_path / "puzzle/bypass/bypass_runs/bypass_a/step-000010-ckpt" + p2 = tmp_path / "puzzle/bypass/bypass_runs/bypass_b/step-000020-ckpt" + paths = {p2, p1} # insert in non-sorted order + out = sorted(paths, key=_bypass_priority) + assert [p.name for p in out] == ["step-000010-ckpt", "step-000020-ckpt"] + # Repeated runs hit the same order. + assert sorted({p1, p2}, key=_bypass_priority) == out + + +def test_infer_subblocks_to_extract_reads_args_json_attention(tmp_path: Path): + checkpoint_dir = tmp_path / "bypass_ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "args.json").write_text( + json.dumps({"model_factory": {"keys_to_learn": "subblock_attention"}}) + ) + + assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["attention"] + + +def test_infer_subblocks_to_extract_reads_args_json_ffn(tmp_path: Path): + checkpoint_dir = tmp_path / "bypass_ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "args.json").write_text( + json.dumps({"model_factory": {"keys_to_learn": "subblock_ffn"}}) + ) + + assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["ffn"] + + +def test_infer_subblocks_to_extract_reads_subblock_list(tmp_path: Path): + checkpoint_dir = tmp_path / "bypass_ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "args.json").write_text( + json.dumps({"model_factory": {"keys_to_learn": ["subblock_attention", "subblock_ffn"]}}) + ) + + assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["attention", "ffn"] + + +def test_infer_subblocks_to_extract_reads_args_json_entire_block(tmp_path: Path): + checkpoint_dir = tmp_path / "bypass_ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "args.json").write_text( + json.dumps({"model_factory": {"keys_to_learn": "entire_block"}}) + ) + + assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["block"] + + +def test_infer_subblocks_to_extract_rejects_non_subblock_keys_to_learn(tmp_path: Path): + checkpoint_dir = tmp_path / "bypass_ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "args.json").write_text( + json.dumps({"model_factory": {"keys_to_learn": "q_proj"}}) + ) + + with pytest.raises(ValueError, match="keys_to_learn must be one of"): + brl._infer_subblocks_to_extract(checkpoint_dir, []) diff --git a/tests/unit/torch/puzzletron/test_puzzletron_progress.py b/tests/unit/torch/puzzletron/test_puzzletron_progress.py new file mode 100644 index 00000000000..f83b07cdbe2 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_puzzletron_progress.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_total_steps`` / ``_progress_step`` in ``puzzletron_nas_plugin``. + +These two helpers are the single source of truth for the user-facing +``Puzzletron Progress N/T`` log lines emitted by ``convert_puzzletron_model`` +and ``PuzzletronSearcher.run_search``. A regression that drops or reorders a +stage silently misnumbers every progress message; worse, an off-by-one would +hide which stage the pipeline crashed in. +""" + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.puzzletron_nas_plugin import ( + _STAGE_ORDER, + _progress_step, + _total_steps, +) + + +def _cfg_with_bypass(): + return OmegaConf.create({"bypass": {"experiment_dir": "/tmp/x"}}) + + +def _cfg_without_bypass(): + return OmegaConf.create({"some_other_key": True}) + + +def _cfg_with_null_bypass(): + return OmegaConf.create({"bypass": None}) + + +def test_total_steps_with_bypass_is_nine(): + assert _total_steps(_cfg_with_bypass()) == 9 + + +def test_total_steps_without_bypass_key_is_eight(): + assert _total_steps(_cfg_without_bypass()) == 8 + + +def test_total_steps_with_null_bypass_is_eight(): + """``bypass: null`` (typical override-to-disable) must read as 'no bypass'.""" + assert _total_steps(_cfg_with_null_bypass()) == 8 + + +def test_progress_step_walks_eight_stages_without_bypass(): + cfg = _cfg_without_bypass() + expected_no_bypass = [s for s in _STAGE_ORDER if s != "bypass"] + seen = [] + for stage in expected_no_bypass: + step, total = _progress_step(cfg, stage) + seen.append((stage, step, total)) + assert seen == [ + ("start", 1, 8), + ("convert", 2, 8), + ("score_activations", 3, 8), + ("prune", 4, 8), + ("build_library", 5, 8), + ("score_blocks", 6, 8), + ("mip", 7, 8), + ("complete", 8, 8), + ] + + +def test_progress_step_walks_nine_stages_with_bypass(): + cfg = _cfg_with_bypass() + seen = [(stage, *_progress_step(cfg, stage)) for stage in _STAGE_ORDER] + assert seen == [ + ("start", 1, 9), + ("convert", 2, 9), + ("score_activations", 3, 9), + ("prune", 4, 9), + ("bypass", 5, 9), + ("build_library", 6, 9), + ("score_blocks", 7, 9), + ("mip", 8, 9), + ("complete", 9, 9), + ] + + +def test_progress_step_bypass_stage_unknown_when_absent(): + """Asking for the bypass stage when bypass isn't configured is a programming + error — must raise, not silently return 0/8.""" + cfg = _cfg_without_bypass() + with pytest.raises(ValueError, match="Unknown pipeline stage"): + _progress_step(cfg, "bypass") + + +def test_progress_step_unknown_stage_raises(): + cfg = _cfg_with_bypass() + with pytest.raises(ValueError, match="Unknown pipeline stage"): + _progress_step(cfg, "definitely_not_a_real_stage") + + +def test_mip_step_shifts_when_bypass_added_or_removed(): + """Removing bypass must shift MIP from 8/9 to 7/8 — pinned by the docstring + on _progress_step which calls this out explicitly.""" + assert _progress_step(_cfg_with_bypass(), "mip") == (8, 9) + assert _progress_step(_cfg_without_bypass(), "mip") == (7, 8) From 27b3f1d43eb22945e195deb90ff8568ed2cb20db Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 8 Jun 2026 10:19:41 +0200 Subject: [PATCH 2/8] Trim Puzzletron bypass tests Signed-off-by: Sepehr Sameni --- tests/gpu/torch/puzzletron/test_bypass.py | 167 ++++-------- .../torch/puzzletron/test_bypass_resume.py | 251 ------------------ .../test_bypass_replacement_library.py | 246 ----------------- .../puzzletron/test_puzzletron_progress.py | 105 ++------ .../test_replacement_library_bypass_config.py | 96 ++++++- 5 files changed, 162 insertions(+), 703 deletions(-) delete mode 100644 tests/gpu/torch/puzzletron/test_bypass_resume.py delete mode 100644 tests/unit/torch/puzzletron/test_bypass_replacement_library.py diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index 59881a242b1..ddfef355f40 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -15,8 +15,13 @@ """GPU integration tests for bypass distillation (blockwise local distillation). -Each test is parametrized over the same model families covered by ``test_puzzletron.py`` -(see ``PUZZLETRON_FAMILIES`` in ``tests/_test_utils/torch/puzzletron/utils.py``). +The tests use representative model families instead of parametrizing every scenario over +the full Puzzletron family matrix: + + - Llama-3.2-3B: dense FFN pruning with ``mlp_init_mode="Truncate"``. + - GPT-OSS-20B: MoE expert pruning, windowed attention, and attention sinks. + +The broader no-bypass family matrix remains covered by ``test_puzzletron.py``. Tiny model dimensions used throughout (set by ``setup_test_model_and_data``): - hidden_size: 256, intermediate_size: 512, num_layers: max(2, world_size) @@ -36,14 +41,8 @@ sourced from the family's pruning YAML (``mlp_init_config_yaml``) — no per-family branching needed in this test file. -To add a new model family: - 1. Append one row to PUZZLETRON_FAMILIES in tests/_test_utils/torch/puzzletron/utils.py. - 2. Ensure tests/gpu/torch/puzzletron/resources/configs//.yaml exists - and that setup_test_model_and_data() can build a tiny stand-in for it. - 3. For MoE families, ensure the family's descriptor registers ``"kv_heads"`` and - ``"experts_removal"`` in ``pruning_mixins()`` (see e.g. NemotronH, GPT-OSS, - Qwen3-VL descriptors). - 4. The four bypass tests below pick up the new row automatically. +To add a new bypass-specific model family, add it deliberately to the targeted +case lists below instead of expanding every test by default. """ import copy @@ -57,7 +56,7 @@ import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.misc import set_seed -from _test_utils.torch.puzzletron.utils import PUZZLETRON_FAMILIES, setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data from omegaconf import OmegaConf import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations @@ -92,6 +91,34 @@ TRAINING_TOKENS = 128 BLOCK_SIZE = 64 +# Llama-3.2-3B is the smallest dense family and the canonical "FFN bypass" path. +LLAMA_FAMILY = pytest.param( + "meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B" +) +# GPT-OSS adds MoE expert pruning (mlp_init_mode="ExpertRemoval") and windowed +# attention with sinks — different code paths than dense Llama. +GPT_OSS_FAMILY = pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b") +BYPASS_SMOKE_FAMILIES = [LLAMA_FAMILY, GPT_OSS_FAMILY] + +BYPASS_SUBBLOCK_MODE_CASES = [ + pytest.param( + "meta-llama/Llama-3.2-3B-Instruct", + "llama", + None, + False, + "subblock_ffn", + id="llama-subblock-ffn", + ), + pytest.param( + "openai/gpt-oss-20b", + "gpt_oss", + None, + True, + "subblock_attention", + id="gpt-oss-subblock-attention", + ), +] + # --------------------------------------------------------------------------- # Helpers @@ -337,13 +364,13 @@ def _setup_hydra_cfg_and_pruning( # --------------------------------------------------------------------------- -# Tests — each parametrized over PUZZLETRON_FAMILIES +# Tests # --------------------------------------------------------------------------- @pytest.mark.parametrize( ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - PUZZLETRON_FAMILIES, + BYPASS_SMOKE_FAMILIES, ) def test_bypass_block_pruning( project_root_path: Path, @@ -410,6 +437,13 @@ def _test_bypass_block_pruning_job( assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" ) + resolved = ckpt_symlink.resolve() + assert (resolved / "config.json").exists(), ( + f"Expected HuggingFace config.json inside checkpoint: {resolved}" + ) + assert (resolved / "saving_completed").exists(), ( + f"Expected saving_completed marker inside checkpoint: {resolved}" + ) dist.cleanup() @@ -421,7 +455,7 @@ def _test_bypass_block_pruning_job( @pytest.mark.parametrize( ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - PUZZLETRON_FAMILIES, + [LLAMA_FAMILY], ) def test_bypass_kv_head_compression( project_root_path: Path, @@ -504,7 +538,7 @@ def _test_bypass_kv_head_compression_job( @pytest.mark.parametrize( ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - PUZZLETRON_FAMILIES, + [LLAMA_FAMILY], ) def test_bypass_multi_config_sequential( project_root_path: Path, @@ -605,104 +639,6 @@ def _test_bypass_multi_config_sequential_job( ) -@pytest.mark.parametrize( - ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - PUZZLETRON_FAMILIES, -) -def test_bypass_checkpoint_contents( - project_root_path: Path, - tmp_path: Path, - hf_model_name: str, - converter: str, - hybrid_override_pattern: str | None, - has_moe_layers: bool, -): - """Verify that a bypass checkpoint contains expected HuggingFace model files.""" - spawn_multiprocess_job( - size=torch.cuda.device_count(), - job=partial( - _test_bypass_checkpoint_contents_job, - project_root_path, - tmp_path, - hf_model_name, - converter, - hybrid_override_pattern, - has_moe_layers, - ), - backend="nccl", - ) - - -def _test_bypass_checkpoint_contents_job( - project_root_path: Path, - tmp_path: Path, - hf_model_name: str, - converter: str, - hybrid_override_pattern: str | None, - has_moe_layers: bool, - rank: int, - size: int, -): - puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, - tmp_path, - rank, - size, - hf_model_name, - converter, - hybrid_override_pattern, - ) - - bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) - OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) - - bypass_distillation.launch_bypass_distillation(hydra_cfg) - dist.barrier() - - if rank == 0: - expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) - ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id - - assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( - f"Expected bypass checkpoint symlink: {ckpt_symlink}" - ) - - # The symlink resolves to the latest checkpoint dir; verify HF config exists. - resolved = ckpt_symlink.resolve() - config_json = resolved / "config.json" - assert config_json.exists(), ( - f"Expected HuggingFace config.json inside checkpoint: {config_json}" - ) - - # The saving_completed marker must be present (set by save_bypass_checkpoint). - saving_completed = resolved / "saving_completed" - assert saving_completed.exists(), ( - f"Expected saving_completed marker inside checkpoint: {saving_completed}" - ) - - dist.cleanup() - - print( - f"PYTEST SUMMARY: test_bypass_checkpoint_contents[{hf_model_name}] completed. " - f"Puzzle directory: {puzzle_dir}" - ) - - -# --------------------------------------------------------------------------- -# Tests below this line target a single (or two) family deliberately — they -# exercise paths where parametrizing over all 9 families is overkill or -# requires extras (e.g. NemotronH's mamba-ssm dep). -# --------------------------------------------------------------------------- - -# Llama-3.2-3B is the smallest dense family and the canonical "FFN bypass" path. -LLAMA_FAMILY = pytest.param( - "meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B" -) -# GPT-OSS adds MoE expert pruning (mlp_init_mode="ExpertRemoval") and windowed -# attention with sinks — different code paths than dense Llama. -GPT_OSS_FAMILY = pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b") - - # --------------------------------------------------------------------------- # Resume from checkpoint # --------------------------------------------------------------------------- @@ -839,10 +775,9 @@ def _test_bypass_resume_from_checkpoint_job( # --------------------------------------------------------------------------- -@pytest.mark.parametrize("keys_to_learn", ["subblock_ffn", "subblock_attention", "entire_block"]) @pytest.mark.parametrize( - ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - [LLAMA_FAMILY, GPT_OSS_FAMILY], + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers", "keys_to_learn"), + BYPASS_SUBBLOCK_MODE_CASES, ) def test_bypass_subblock_modes( project_root_path: Path, diff --git a/tests/gpu/torch/puzzletron/test_bypass_resume.py b/tests/gpu/torch/puzzletron/test_bypass_resume.py deleted file mode 100644 index 4cf90e9b685..00000000000 --- a/tests/gpu/torch/puzzletron/test_bypass_resume.py +++ /dev/null @@ -1,251 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""GPU integration test for the bypass-distillation resume path. - -The existing ``test_bypass.py`` covers the save side: a fresh bypass run -produces a checkpoint and a ``ckpts/`` symlink. What it doesn't cover is -the *resume* side: a re-launched job calling ``find_latest_run_dir`` against -a real experiment directory and loading optimizer / state via ``load_local_state``. - -That contract — between what training writes (``saving_completed`` marker, -``args.json``, ``stitched/*.pth``) and what the resume helpers read — is -exactly the kind of thing that quietly diverges as the save format evolves. -A unit test can pin the regex; only an integration test pins the byte-level -agreement between writer and reader. - -Single dense family (Llama-3.2-3B-Instruct) is enough — the resume code path -is family-agnostic. -""" - -from datetime import timedelta -from functools import partial -from pathlib import Path - -import pytest -import torch -from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from _test_utils.torch.misc import set_seed -from _test_utils.torch.puzzletron.utils import setup_test_model_and_data -from omegaconf import OmegaConf - -import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations -import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation -import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts -import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.anymodel import convert_model -from modelopt.torch.puzzletron.bypass_distillation.bypass_checkpoint_utils import ( - find_latest_run_dir, -) -from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id -from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir - -# Match the constants in test_bypass.py so the run completes in two steps. -SEED = 1234 -TRAINING_TOKENS = 128 -BLOCK_SIZE = 64 -PRUNED_INTERMEDIATE_SIZE = 256 -PRUNED_NUM_KV_HEADS = 4 - -# One dense family — resume path is family-agnostic, so a second parametrize -# row would only add runtime, not coverage. -HF_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" -CONVERTER = "llama" - - -def _bypass_cfg_dict(*, find_last_ckpt_for_resume: bool) -> dict: - """Minimal bypass config — derived from test_bypass.py's _make_bypass_cfg_dict - for a dense family with FFN+KV pruning.""" - return { - "dtype": "bf16", - "seed": 42, - "experiment_id": None, - "experiment_dir": None, - "iter_num": 1, - "step_num": 1, - "token_count": 0, - "data": { - "data_column": "conversation", - "block_size": BLOCK_SIZE, - "bos_rate": 0.5, - "fim_rate": 0, - "fim_spm_rate": 0, - "source_datasets_to_discard": [], - "load_from_disk": True, - "keep_in_memory": False, - "val_dataset_name": "valid", - "max_eval_samples": 1, - "eval_samples_per_process": None, - "shuffle_train_data_seed": 42, - }, - "training": { - "learning_rate": 1e-4, - "training_tokens": TRAINING_TOKENS, - "micro_batch_size": 1, - "val_micro_batch_size": 1, - "warmup_ratio": 0.05, - "warmup_steps": None, - "min_lr_factor": 1e-5, - "grad_accumulation_steps": 1, - "skip_first_batches": 0, - "weight_decay": 0.1, - "decay_lr": True, - "beta1": 0.9, - "beta2": 0.95, - "use_grad_scaling": False, - "grad_clip": 1.0, - "grad_clip_type": "norm", - "clipping_count": 0, - "log_interval": 5, - "eval_interval": 100, - }, - "resume_checkpoint_path": None, - "find_last_ckpt_for_resume": find_last_ckpt_for_resume, - "parameter_count": None, - "init_checkpoint_path": None, - "model": { - "student_weights_dtype": "bf16", - "model_overrides": { - "delete_old_checkpoints": True, - "save_interval_seconds": None, - "save_interval": 1_000_000_000, - "save_checkpoint_when_done": True, - }, - "model_config_overrides": { - "ffn": [{"intermediate_size": PRUNED_INTERMEDIATE_SIZE, "no_op": None}], - "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], - }, - }, - "model_factory": { - "factory": "bypass_factory_fn", - "block_loss_func": "normalized_mse_loss", - "gqa_init_mode": "AverageKV", - "mlp_init_mode": "Truncate", - "mlp_init_config": {"activations_log_dir": None}, - "linear_init_mode": "FromTeacher", - "submodule_for_loss_calculation": None, - "keys_to_learn": "entire_block", - }, - "disable_initial_validate": True, - "validate_teacher_model": False, - "validate_student_model": False, - "disable_validation": True, - "best_val_loss": 1e9, - "compile": False, - "disable_fa2": False, - "teacher_model_load_on_cpu": False, - "save_checkpoint_before_training": False, - "disable_checkpoint_save": False, - "save_best_ckpt": True, - "kill_after_first_save": False, - "realize_best_or_latest": "best", - "wandb_log": False, - "wandb": {"project": None, "entity": None}, - } - - -def _expected_experiment_dir(puzzle_dir: Path, bypass_cfg_dict: dict) -> Path: - """Compute the experiment directory the runtime will choose.""" - cfg = OmegaConf.create({"bypass": dict(bypass_cfg_dict)}) - set_experiment_id(cfg) - return puzzle_dir / "bypass/bypass_runs" / cfg.bypass.experiment_id - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") -def test_bypass_resume_finds_latest_checkpoint(project_root_path: Path, tmp_path: Path): - """Run bypass once, verify ``find_latest_run_dir`` locates the saved - checkpoint, then re-launch with ``find_last_ckpt_for_resume=True`` and - verify the second run resumes from the saved iter_num. - """ - spawn_multiprocess_job( - size=torch.cuda.device_count(), - job=partial(_resume_job, project_root_path, tmp_path), - backend="nccl", - ) - - -def _resume_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): - set_seed(SEED) - dist.setup(timeout=timedelta(minutes=10)) - - puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( - tmp_path, rank, HF_MODEL_NAME, hybrid_override_pattern=None - ) - - hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") - hydra_config_name = f"{HF_MODEL_NAME}/{Path(HF_MODEL_NAME).name}" - - if rank == 0: - convert_model( - input_dir=str(hf_checkpoint_path), - output_dir=str(puzzle_dir / "ckpts/teacher"), - converter=CONVERTER, - ) - dist.barrier() - - import hydra - - hydra_cfg = initialize_hydra_config_for_dir( - config_dir=hydra_config_dir, - config_name=hydra_config_name, - overrides=[f"puzzle_dir={puzzle_dir}", f"dataset_path={dataset_path}"], - ) - hydra_cfg = hydra.utils.instantiate(hydra_cfg) - - score_pruning_activations.launch_score_activations(hydra_cfg) - if rank == 0: - pruning_ckpts.launch_prune_ckpt(hydra_cfg) - dist.barrier() - - # First bypass run — produces a real checkpoint. - cfg_dict = _bypass_cfg_dict(find_last_ckpt_for_resume=False) - OmegaConf.update(hydra_cfg, "bypass", cfg_dict, merge=True) - bypass_distillation.launch_bypass_distillation(hydra_cfg) - dist.barrier() - - experiment_dir = _expected_experiment_dir(puzzle_dir, cfg_dict) - if rank == 0: - # The save side wrote what the resume side expects. - assert experiment_dir.exists(), f"Expected experiment dir at {experiment_dir}" - latest = find_latest_run_dir(experiment_dir) - assert latest is not None, f"find_latest_run_dir returned None for {experiment_dir}" - assert (Path(latest) / "saving_completed").exists(), ( - f"Resume target {latest} missing saving_completed marker" - ) - assert (Path(latest) / "args.json").exists(), ( - f"Resume target {latest} missing args.json — load path would crash" - ) - dist.barrier() - - # Second bypass run — re-uses the same experiment_dir, finds the latest - # checkpoint via ``find_last_ckpt_for_resume=True``, and resumes. - # Reset cfg.bypass to a fresh dict (experiment_id back to None so - # set_experiment_id recomputes the same id from model_config_overrides). - cfg_dict_resume = _bypass_cfg_dict(find_last_ckpt_for_resume=True) - cfg_dict_resume["training"]["training_tokens"] = TRAINING_TOKENS * 2 - OmegaConf.update(hydra_cfg, "bypass", cfg_dict_resume, merge=True) - bypass_distillation.launch_bypass_distillation(hydra_cfg) - dist.barrier() - - if rank == 0: - # After the second run, iter_num must have advanced past 1 — proving - # the run picked up state from the first run rather than starting fresh. - # (The resume code path overwrites iter_num from args.json on line 826.) - assert hydra_cfg.bypass.iter_num > 1, ( - f"Resume failed: iter_num={hydra_cfg.bypass.iter_num} suggests fresh start, " - f"not a resume from the saved checkpoint" - ) - - dist.cleanup() diff --git a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py deleted file mode 100644 index fa38d17ba15..00000000000 --- a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py +++ /dev/null @@ -1,246 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for replacement-library checkpoint discovery + bypass priority. - -The ``build_replacement_library`` module is responsible for two correctness-critical -behaviors after a bypass run: - -1. ``_get_last_checkpoint_from_each_experiment`` must surface every valid - checkpoint under ``puzzle_dir/ckpts/``, including those that live there only - as symlinks (which is exactly how bypass writes its results). -2. When a bypass-trained subblock and a Truncate-init subblock would produce - the same architectural identifier, the bypass-trained one must be preferred - by the downstream ``drop_duplicates(keep="first")``. This is enforced by a - tuple-sort closure inside ``_build_subblocks_df`` that gives bypass paths - priority 0 and everything else priority 1. - -A regression in either path silently discards bypass-trained weights — exactly -the kind of bug that's invisible in normal CI runs. -""" - -import json -from pathlib import Path - -import pytest - -from modelopt.torch.puzzletron.replacement_library import build_replacement_library as brl - -# --------------------------------------------------------------------------- -# Filesystem fixture: tiny puzzle_dir with three checkpoints -# --------------------------------------------------------------------------- - - -def _write_minimal_config(checkpoint_dir: Path) -> None: - """Write a placeholder config.json so the discovery rglob finds the dir. - - The actual config contents don't matter — these tests monkeypatch - ``is_valid_decilm_checkpoint`` so no real config parsing happens. - """ - checkpoint_dir.mkdir(parents=True, exist_ok=True) - (checkpoint_dir / "config.json").write_text("{}") - - -@pytest.fixture -def puzzle_dir_with_three_ckpts(tmp_path: Path, monkeypatch) -> Path: - """Build a puzzle_dir tree mirroring a real post-bypass post-prune layout. - - Layout:: - - puzzle_dir/ - ckpts/ - teacher/ # real dir - config.json - bypass_ffn_256_heads_4 -> ../bypass/bypass_runs/.../step-000010-ckpt - pruned_intermediate_256 -> ../pruning/pruned_intermediate_256 - bypass/bypass_runs/bypass_ffn_256_heads_4/step-000010-ckpt/ - config.json - pruning/pruned_intermediate_256/ - config.json - - The two non-teacher entries under ``ckpts/`` are symlinks — that is how - ``puzzletron_nas_plugin.realize_bypass_checkpoints`` and the pruning - pipeline actually write them. ``_get_last_checkpoint_from_each_experiment`` - must `.resolve()` these to see the real path under ``bypass/bypass_runs/`` - or ``pruning/`` — that resolution is what the priority sort later keys on. - """ - puzzle_dir = tmp_path / "puzzle_dir" - ckpts = puzzle_dir / "ckpts" - ckpts.mkdir(parents=True) - - # Teacher: real directory directly under ckpts. - _write_minimal_config(ckpts / "teacher") - - # Bypass: real dir under bypass/bypass_runs/, symlinked from ckpts/. - bypass_real = ( - puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn_256_heads_4" / "step-000010-ckpt" - ) - _write_minimal_config(bypass_real) - (ckpts / "bypass_ffn_256_heads_4").symlink_to(bypass_real, target_is_directory=True) - - # Truncate-pruned: real dir under pruning/, symlinked from ckpts/. - pruning_real = puzzle_dir / "pruning" / "pruned_intermediate_256" - _write_minimal_config(pruning_real) - (ckpts / "pruned_intermediate_256").symlink_to(pruning_real, target_is_directory=True) - - # Make every config.json look "valid" without parsing — load_model_config - # would otherwise try to load these as real HF configs. - monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", lambda *a, **kw: True) - - return puzzle_dir - - -# --------------------------------------------------------------------------- -# Discovery -# --------------------------------------------------------------------------- - - -def test_get_last_checkpoint_from_each_experiment_finds_all_three( - puzzle_dir_with_three_ckpts: Path, -): - discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) - discovered_names = {p.name for p in discovered} - assert discovered_names == {"teacher", "step-000010-ckpt", "pruned_intermediate_256"} - - -def test_get_last_checkpoint_from_each_experiment_resolves_symlinks( - puzzle_dir_with_three_ckpts: Path, -): - """The resolved paths must reflect the real filesystem location. - - This is what makes the bypass-priority sort work — the closure inside - ``_build_subblocks_df`` checks ``"bypass" in p.parts and "bypass_runs" - in p.parts``, which only succeeds on the resolved path. - """ - discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) - bypass_path = next(p for p in discovered if p.name == "step-000010-ckpt") - assert "bypass" in bypass_path.parts - assert "bypass_runs" in bypass_path.parts - # And the pruning entry must NOT pick up "bypass" anywhere in its parts. - pruning_path = next(p for p in discovered if p.name == "pruned_intermediate_256") - assert "bypass" not in pruning_path.parts - - -def test_get_last_checkpoint_skips_invalid_checkpoints( - puzzle_dir_with_three_ckpts: Path, monkeypatch -): - """Only checkpoints that pass ``is_valid_decilm_checkpoint`` should appear. - - A regression where a malformed config.json silently slips through would - later raise inside ``_construct_subblock_rows_from_current_checkpoint`` - with a much less helpful traceback. - """ - - def _only_teacher_is_valid(checkpoint_dir, trust_remote_code=False): - return Path(checkpoint_dir).name == "teacher" - - monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", _only_teacher_is_valid) - discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) - assert {p.name for p in discovered} == {"teacher"} - - -# --------------------------------------------------------------------------- -# Bypass-priority sort -# --------------------------------------------------------------------------- - - -def _bypass_priority(p: Path) -> tuple[int, str]: - """Re-implementation of the closure inside ``_build_subblocks_df``. - - Kept identical to ``modelopt/torch/puzzletron/replacement_library/ - build_replacement_library.py:222-225``. If that closure is changed, - update this test mirror; this is intentional duplication so the unit - test stays cheap (no need to build an end-to-end DataFrame just to - verify a 3-line priority function). - """ - is_bypass = "bypass" in p.parts and "bypass_runs" in p.parts - return (0 if is_bypass else 1, str(p)) - - -def test_bypass_priority_orders_bypass_before_pruning(puzzle_dir_with_three_ckpts: Path): - """The same input set the real code receives must sort bypass first.""" - discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) - teacher = next(p for p in discovered if p.name == "teacher") - non_teacher_sorted = sorted(discovered - {teacher}, key=_bypass_priority) - - # Bypass must come first; pruning must come second. - assert non_teacher_sorted[0].name == "step-000010-ckpt" - assert non_teacher_sorted[1].name == "pruned_intermediate_256" - - -def test_bypass_priority_is_stable_for_two_bypass_checkpoints(tmp_path: Path): - """Multiple bypass checkpoints must sort deterministically by string. - - Without this, ``set`` iteration order changes the picked-first checkpoint - across Python invocations, defeating the whole point of the priority sort. - """ - p1 = tmp_path / "puzzle/bypass/bypass_runs/bypass_a/step-000010-ckpt" - p2 = tmp_path / "puzzle/bypass/bypass_runs/bypass_b/step-000020-ckpt" - paths = {p2, p1} # insert in non-sorted order - out = sorted(paths, key=_bypass_priority) - assert [p.name for p in out] == ["step-000010-ckpt", "step-000020-ckpt"] - # Repeated runs hit the same order. - assert sorted({p1, p2}, key=_bypass_priority) == out - - -def test_infer_subblocks_to_extract_reads_args_json_attention(tmp_path: Path): - checkpoint_dir = tmp_path / "bypass_ckpt" - checkpoint_dir.mkdir() - (checkpoint_dir / "args.json").write_text( - json.dumps({"model_factory": {"keys_to_learn": "subblock_attention"}}) - ) - - assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["attention"] - - -def test_infer_subblocks_to_extract_reads_args_json_ffn(tmp_path: Path): - checkpoint_dir = tmp_path / "bypass_ckpt" - checkpoint_dir.mkdir() - (checkpoint_dir / "args.json").write_text( - json.dumps({"model_factory": {"keys_to_learn": "subblock_ffn"}}) - ) - - assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["ffn"] - - -def test_infer_subblocks_to_extract_reads_subblock_list(tmp_path: Path): - checkpoint_dir = tmp_path / "bypass_ckpt" - checkpoint_dir.mkdir() - (checkpoint_dir / "args.json").write_text( - json.dumps({"model_factory": {"keys_to_learn": ["subblock_attention", "subblock_ffn"]}}) - ) - - assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["attention", "ffn"] - - -def test_infer_subblocks_to_extract_reads_args_json_entire_block(tmp_path: Path): - checkpoint_dir = tmp_path / "bypass_ckpt" - checkpoint_dir.mkdir() - (checkpoint_dir / "args.json").write_text( - json.dumps({"model_factory": {"keys_to_learn": "entire_block"}}) - ) - - assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == ["block"] - - -def test_infer_subblocks_to_extract_rejects_non_subblock_keys_to_learn(tmp_path: Path): - checkpoint_dir = tmp_path / "bypass_ckpt" - checkpoint_dir.mkdir() - (checkpoint_dir / "args.json").write_text( - json.dumps({"model_factory": {"keys_to_learn": "q_proj"}}) - ) - - with pytest.raises(ValueError, match="keys_to_learn must be one of"): - brl._infer_subblocks_to_extract(checkpoint_dir, []) diff --git a/tests/unit/torch/puzzletron/test_puzzletron_progress.py b/tests/unit/torch/puzzletron/test_puzzletron_progress.py index f83b07cdbe2..052939a4792 100644 --- a/tests/unit/torch/puzzletron/test_puzzletron_progress.py +++ b/tests/unit/torch/puzzletron/test_puzzletron_progress.py @@ -13,23 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for ``_total_steps`` / ``_progress_step`` in ``puzzletron_nas_plugin``. - -These two helpers are the single source of truth for the user-facing -``Puzzletron Progress N/T`` log lines emitted by ``convert_puzzletron_model`` -and ``PuzzletronSearcher.run_search``. A regression that drops or reorders a -stage silently misnumbers every progress message; worse, an off-by-one would -hide which stage the pipeline crashed in. -""" +"""Tests for progress numbering with the optional bypass stage.""" import pytest from omegaconf import OmegaConf -from modelopt.torch.puzzletron.puzzletron_nas_plugin import ( - _STAGE_ORDER, - _progress_step, - _total_steps, -) +from modelopt.torch.puzzletron.puzzletron_nas_plugin import _progress_step def _cfg_with_bypass(): @@ -44,70 +33,28 @@ def _cfg_with_null_bypass(): return OmegaConf.create({"bypass": None}) -def test_total_steps_with_bypass_is_nine(): - assert _total_steps(_cfg_with_bypass()) == 9 - - -def test_total_steps_without_bypass_key_is_eight(): - assert _total_steps(_cfg_without_bypass()) == 8 - - -def test_total_steps_with_null_bypass_is_eight(): - """``bypass: null`` (typical override-to-disable) must read as 'no bypass'.""" - assert _total_steps(_cfg_with_null_bypass()) == 8 - - -def test_progress_step_walks_eight_stages_without_bypass(): - cfg = _cfg_without_bypass() - expected_no_bypass = [s for s in _STAGE_ORDER if s != "bypass"] - seen = [] - for stage in expected_no_bypass: - step, total = _progress_step(cfg, stage) - seen.append((stage, step, total)) - assert seen == [ - ("start", 1, 8), - ("convert", 2, 8), - ("score_activations", 3, 8), - ("prune", 4, 8), - ("build_library", 5, 8), - ("score_blocks", 6, 8), - ("mip", 7, 8), - ("complete", 8, 8), - ] - - -def test_progress_step_walks_nine_stages_with_bypass(): - cfg = _cfg_with_bypass() - seen = [(stage, *_progress_step(cfg, stage)) for stage in _STAGE_ORDER] - assert seen == [ - ("start", 1, 9), - ("convert", 2, 9), - ("score_activations", 3, 9), - ("prune", 4, 9), - ("bypass", 5, 9), - ("build_library", 6, 9), - ("score_blocks", 7, 9), - ("mip", 8, 9), - ("complete", 9, 9), - ] - - -def test_progress_step_bypass_stage_unknown_when_absent(): - """Asking for the bypass stage when bypass isn't configured is a programming - error — must raise, not silently return 0/8.""" - cfg = _cfg_without_bypass() - with pytest.raises(ValueError, match="Unknown pipeline stage"): - _progress_step(cfg, "bypass") - - -def test_progress_step_unknown_stage_raises(): - cfg = _cfg_with_bypass() +@pytest.mark.parametrize( + ("cfg", "stage", "expected_step"), + [ + (_cfg_without_bypass(), "mip", (7, 8)), + (_cfg_with_null_bypass(), "mip", (7, 8)), + (_cfg_with_bypass(), "bypass", (5, 9)), + (_cfg_with_bypass(), "mip", (8, 9)), + ], +) +def test_progress_step_accounts_for_optional_bypass( + cfg, stage: str, expected_step: tuple[int, int] +): + assert _progress_step(cfg, stage) == expected_step + + +@pytest.mark.parametrize( + ("cfg", "stage"), + [ + (_cfg_without_bypass(), "bypass"), + (_cfg_with_bypass(), "definitely_not_a_real_stage"), + ], +) +def test_progress_step_rejects_unreachable_stages(cfg, stage: str): with pytest.raises(ValueError, match="Unknown pipeline stage"): - _progress_step(cfg, "definitely_not_a_real_stage") - - -def test_mip_step_shifts_when_bypass_added_or_removed(): - """Removing bypass must shift MIP from 8/9 to 7/8 — pinned by the docstring - on _progress_step which calls this out explicitly.""" - assert _progress_step(_cfg_with_bypass(), "mip") == (8, 9) - assert _progress_step(_cfg_without_bypass(), "mip") == (7, 8) + _progress_step(cfg, stage) diff --git a/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py index 018807ee97b..fa775a656e9 100644 --- a/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py +++ b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py @@ -18,14 +18,15 @@ import json from pathlib import Path +import pandas as pd import pytest -from modelopt.torch.puzzletron.replacement_library.build_replacement_library import ( - _infer_subblocks_to_extract, -) +from modelopt.torch.puzzletron.block_config import FFNConfig +from modelopt.torch.puzzletron.replacement_library import build_replacement_library as brl -def test_infer_subblocks_to_extract_accepts_bypass_keys(tmp_path: Path): +@pytest.mark.parametrize("metadata_file", ["bypass_config.json", "args.json"]) +def test_infer_subblocks_to_extract_accepts_bypass_keys(tmp_path: Path, metadata_file: str): for i, (keys_to_learn, expected_subblocks) in enumerate( [ ("entire_block", ["block"]), @@ -37,20 +38,93 @@ def test_infer_subblocks_to_extract_accepts_bypass_keys(tmp_path: Path): ): checkpoint_dir = tmp_path / f"checkpoint_{i}" checkpoint_dir.mkdir() - (checkpoint_dir / "bypass_config.json").write_text( - json.dumps({"keys_to_learn": keys_to_learn}) + metadata = ( + {"keys_to_learn": keys_to_learn} + if metadata_file == "bypass_config.json" + else {"model_factory": {"keys_to_learn": keys_to_learn}} ) + (checkpoint_dir / metadata_file).write_text(json.dumps(metadata)) - assert _infer_subblocks_to_extract(checkpoint_dir, []) == expected_subblocks + assert brl._infer_subblocks_to_extract(checkpoint_dir, []) == expected_subblocks -def test_infer_subblocks_to_extract_rejects_legacy_keys(tmp_path: Path): +@pytest.mark.parametrize("metadata_file", ["bypass_config.json", "args.json"]) +def test_infer_subblocks_to_extract_rejects_legacy_keys(tmp_path: Path, metadata_file: str): for i, keys_to_learn in enumerate(["mlp", "attn", ["mlp", "attn"]]): checkpoint_dir = tmp_path / f"legacy_checkpoint_{i}" checkpoint_dir.mkdir() - (checkpoint_dir / "bypass_config.json").write_text( - json.dumps({"keys_to_learn": keys_to_learn}) + metadata = ( + {"keys_to_learn": keys_to_learn} + if metadata_file == "bypass_config.json" + else {"model_factory": {"keys_to_learn": keys_to_learn}} ) + (checkpoint_dir / metadata_file).write_text(json.dumps(metadata)) with pytest.raises(ValueError, match="keys_to_learn"): - _infer_subblocks_to_extract(checkpoint_dir, []) + brl._infer_subblocks_to_extract(checkpoint_dir, []) + + +def test_get_last_checkpoint_from_each_experiment_resolves_ckpts_symlinks( + tmp_path: Path, monkeypatch +): + puzzle_dir = tmp_path / "puzzle_dir" + ckpts_dir = puzzle_dir / "ckpts" + ckpts_dir.mkdir(parents=True) + + teacher_dir = ckpts_dir / "teacher" + bypass_dir = puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn" / "step-000010-ckpt" + pruned_dir = puzzle_dir / "pruning" / "pruned_ffn" + for checkpoint_dir in (teacher_dir, bypass_dir, pruned_dir): + checkpoint_dir.mkdir(parents=True) + (checkpoint_dir / "config.json").write_text("{}") + + (ckpts_dir / "bypass_ffn").symlink_to(bypass_dir, target_is_directory=True) + (ckpts_dir / "pruned_ffn").symlink_to(pruned_dir, target_is_directory=True) + monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", lambda *args, **kwargs: True) + + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir) + + assert discovered == {teacher_dir.resolve(), bypass_dir.resolve(), pruned_dir.resolve()} + + +def test_build_subblocks_df_prefers_bypass_rows_over_pruned_duplicates(tmp_path: Path, monkeypatch): + puzzle_dir = tmp_path / "puzzle_dir" + teacher_dir = puzzle_dir / "ckpts" / "teacher" + bypass_dir = puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn" / "step-000010-ckpt" + pruned_dir = puzzle_dir / "pruning" / "pruned_ffn" + + monkeypatch.setattr( + brl, + "_get_last_checkpoint_from_each_experiment", + lambda *args, **kwargs: {teacher_dir, pruned_dir, bypass_dir}, + ) + + def _construct_rows(checkpoint_dir: Path, *args, **kwargs): + if checkpoint_dir == teacher_dir: + return [] + return [ + { + "attention_checkpoint_dir": None, + "ffn_checkpoint_dir": str(checkpoint_dir), + "block_config": None, + "attention_config": None, + "ffn_config": FFNConfig(intermediate_size=256), + "block_idx": 0, + "block_repr": None, + "attention_repr": None, + "ffn_repr": None, + } + ] + + monkeypatch.setattr(brl, "_construct_subblock_rows_from_current_checkpoint", _construct_rows) + + subblocks_df = brl._build_subblocks_df( + master_puzzle_dir=puzzle_dir, + teacher_checkpoint_dir=teacher_dir, + add_ffn_no_ops=False, + add_attention_no_ops=False, + ) + + assert len(subblocks_df) == 1 + assert subblocks_df["ffn_checkpoint_dir"].item() == str(bypass_dir) + assert not pd.isna(subblocks_df["ffn_repr"].item()) From 43ed82ed540b56c77f6f6ff73f6d88440eb2be33 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 8 Jun 2026 10:37:43 +0200 Subject: [PATCH 3/8] Clarify Puzzletron bypass tutorial Signed-off-by: Sepehr Sameni --- examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md | 6 +++--- .../configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml | 7 ++++--- .../replacement_library/build_replacement_library.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md index 1f3ad6983df..70ae8a6395e 100644 --- a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md +++ b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md @@ -6,7 +6,7 @@ A minimal end-to-end demonstration that **bypass distillation improves quality** The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare: -1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training). +1. **Without bypass** — replacement library uses the `PruneKVHeads` initialization from activation scoring (selected KV heads copied from the teacher; no further training). 2. **With bypass** — the bypass step runs ~50M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher. Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head}`. FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change. @@ -35,7 +35,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/main.py \ --config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml ``` -This runs the 8-step puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). With `bypass:` added in Step B the pipeline grows to 9 steps; without it, the bypass step is skipped and progress prints `N/8`. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring). +This runs the no-bypass puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). The progress counter includes start/complete messages and prints `N/8`. With `bypass:` added in Step B the pipeline grows to 9 steps. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring). When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain: @@ -80,7 +80,7 @@ Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on | Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` | |------------------------------|----------------------:|----------:|-----------------------:| | Teacher | 12 | 0.5950 | 0.8468 | -| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 | +| Pruned, **no bypass** (`PruneKVHeads` init) | 9 | 0.6347 | 0.8373 | | Pruned, **with bypass** (50M-token BLD) | 9 | **0.6055**| **0.8441** | **Bypass closes ~74% of the regression gap** at this compression budget: diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml index a1c63ac913a..412c8de2fdb 100644 --- a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml @@ -4,10 +4,11 @@ # Trains a single 1-KV-head variant per attention layer using per-block knowledge # distillation against the teacher (`subblock_attention` keys only — FFN/MoE/Mamba # blocks are frozen). The trained weights are saved into the replacement library -# and consumed by the MIP solver alongside the no_op variant. +# and consumed by the MIP solver alongside the teacher/no-op variant. # -# Tutorial budget: ~10M tokens (quick sanity, ~30 min on 4×H100). Increase -# `training_tokens` for a stronger bypass effect. +# Tutorial budget: 50M tokens for the bypass run used in the tutorial. Lower +# `training_tokens` for a quick sanity check, or increase it for a stronger +# bypass effect. # Runtime Configuration dtype: "bf16" diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 57a1de039b5..6c5fe860c64 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -214,7 +214,7 @@ def _build_subblocks_df( ) # Order the non-teacher checkpoints so that downstream `drop_duplicates(keep="first")` - # deterministically prefers bypass-trained subblocks over Truncate-init pruned ones + # deterministically prefers bypass-trained subblocks over untrained pruned ones # when both produce a row with the same architectural identifier. Without this, # `set` iteration order makes the choice random (hash-of-path) and we'd sometimes # discard the BLD-trained weights we just paid 30+ min to compute. From 3cb4f61d3ca686b784b0fd2fb460cd15f7653256 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 8 Jun 2026 11:02:46 +0200 Subject: [PATCH 4/8] Fix Puzzletron cache resume handling Signed-off-by: Sepehr Sameni --- .../torch/puzzletron/puzzletron_nas_plugin.py | 104 ++++++++++++++---- tests/gpu/torch/puzzletron/test_bypass.py | 20 ++-- .../puzzletron/test_puzzletron_nas_plugin.py | 91 +++++++++++++++ 3 files changed, 182 insertions(+), 33 deletions(-) create mode 100644 tests/unit/torch/puzzletron/test_puzzletron_nas_plugin.py diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py index abe94b15b02..ab3ef1ace17 100644 --- a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py @@ -20,9 +20,11 @@ and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ +import json from pathlib import Path import hydra +from safetensors import safe_open from torch import nn import modelopt.torch.utils.distributed as dist @@ -160,6 +162,68 @@ def _find_incomplete_bypass_runs(hydra_cfg, puzzle_dir: str | Path) -> list[str] return incomplete_runs +def _is_readable_safetensors_file(path: Path) -> bool: + if not path.is_file(): + return False + try: + with safe_open(str(path), framework="pt", device="cpu") as tensors: + list(tensors.keys()) + except Exception: + return False + return True + + +def _is_complete_anymodel_checkpoint(checkpoint_dir: Path) -> bool: + config_path = checkpoint_dir / "config.json" + index_path = checkpoint_dir / "model.safetensors.index.json" + if not config_path.is_file() or not index_path.is_file(): + return False + + try: + index = json.loads(index_path.read_text()) + except (json.JSONDecodeError, OSError): + return False + + weight_map = index.get("weight_map") + if not isinstance(weight_map, dict) or not weight_map: + return False + + for relative_weight_path in set(weight_map.values()): + if not isinstance(relative_weight_path, str): + return False + if not _is_readable_safetensors_file(checkpoint_dir / relative_weight_path): + return False + return True + + +def _scoring_output_dir(hydra_cfg) -> Path: + output_dir = getattr(hydra_cfg.scoring, "output_dir", None) + if output_dir is not None: + return Path(output_dir) + solutions_path = Path(hydra_cfg.scoring.solutions_path) + return solutions_path.with_name(f"{solutions_path.stem}--validation") + + +def _invalidate_scoring_cache(hydra_cfg) -> None: + scoring_output_dir = _scoring_output_dir(hydra_cfg) + if not scoring_output_dir.exists(): + return + + stale_paths = [scoring_output_dir / "teacher.json"] + stale_paths.extend(scoring_output_dir.glob("solution_*.json")) + stale_paths = [path for path in stale_paths if path.exists()] + for path in stale_paths: + path.unlink() + + if stale_paths: + mprint(f"Invalidated {len(stale_paths)} cached scoring result(s) in {scoring_output_dir}") + + +def _force_scoring_revalidation(hydra_cfg) -> None: + hydra_cfg.scoring.skip_existing_solutions = False + hydra_cfg.scoring.solutions_to_validate = None + + def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. @@ -195,7 +259,7 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable teacher_dir = Path(config.puzzle_dir) / hf_ckpt_teacher_dir if dist.is_master(): - if (teacher_dir / "config.json").exists(): + if _is_complete_anymodel_checkpoint(teacher_dir): mprint( f"Puzzletron Progress {convert_step}/{N}: teacher checkpoint already exists, skipping conversion" ) @@ -233,32 +297,15 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv dist.barrier() # Step 3: Score pruning activations (distributed processing) - activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir) - if activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth")): - mprint( - f"Puzzletron Progress {score_step}/{N}: pruning activation scores already " - f"exist at {activations_log_dir} — delete this directory to re-score with " - f"the current config." - ) - dist.barrier() - else: - mprint(f"Puzzletron Progress {score_step}/{N}: scoring pruning activations (multi-gpu)") - launch_score_activations(hydra_cfg) + mprint(f"Puzzletron Progress {score_step}/{N}: scoring pruning activations (multi-gpu)") + launch_score_activations(hydra_cfg) # Step 4: Prune the model and save pruned checkpoints (single process) - pruned_ckpts_dir = Path(hydra_cfg.pruning.pruned_ckpts_output_dir) if dist.is_master(): - if pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir()): - mprint( - f"Puzzletron Progress {prune_step}/{N}: pruned checkpoints already " - f"exist at {pruned_ckpts_dir} — delete this directory to re-prune with " - f"the current config." - ) - else: - mprint( - f"Puzzletron Progress {prune_step}/{N}: pruning the model and saving pruned checkpoints (single-gpu)" - ) - launch_prune_ckpt(hydra_cfg) + mprint( + f"Puzzletron Progress {prune_step}/{N}: pruning the model and saving pruned checkpoints (single-gpu)" + ) + launch_prune_ckpt(hydra_cfg) dist.barrier() # Step 5: Bypass distillation (optional, distributed processing) @@ -400,6 +447,7 @@ def run_search(self) -> None: f"Replacement library is stale: '{entry.name}/config.json' is newer than the existing library, will rebuild." ) break + library_was_built = False if dist.is_master(): if ( replacement_library_path.exists() @@ -414,7 +462,15 @@ def run_search(self) -> None: f"Puzzletron Progress {library_step}/{N}: building replacement library and subblock statistics (single-gpu)" ) launch_build_library_and_stats(hydra_cfg) + library_was_built = True dist.barrier() + library_was_built = dist.broadcast(library_was_built if dist.is_master() else None, src=0) + + if library_was_built: + if dist.is_master(): + _invalidate_scoring_cache(hydra_cfg) + dist.barrier() + _force_scoring_revalidation(hydra_cfg) # Calculate one block scores (distributed processing) mprint(f"Puzzletron Progress {scoring_step}/{N}: calculating one block scores (multi-gpu)") diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index ddfef355f40..9879a0ffe44 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -291,12 +291,14 @@ def _make_bypass_cfg_dict( return cfg -def _expected_experiment_id(bypass_cfg_dict: dict) -> str: +def _expected_experiment_id(hydra_cfg, bypass_cfg_dict: dict) -> str: """Compute the experiment_id that ``set_experiment_id`` will assign. - Avoids duplicating the formula in tests — uses the same function the runtime uses. + Avoids duplicating the formula in tests while preserving the top-level + teacher identity that the runtime includes in the hash. """ - cfg = OmegaConf.create({"bypass": copy.deepcopy(bypass_cfg_dict)}) + cfg = copy.deepcopy(hydra_cfg) + OmegaConf.update(cfg, "bypass", copy.deepcopy(bypass_cfg_dict), merge=False) set_experiment_id(cfg) return cfg.bypass.experiment_id @@ -427,7 +429,7 @@ def _test_bypass_block_pruning_job( dist.barrier() if rank == 0: - expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + expected_experiment_id = _expected_experiment_id(hydra_cfg, bypass_cfg_dict) experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id @@ -517,7 +519,7 @@ def _test_bypass_kv_head_compression_job( dist.barrier() if rank == 0: - expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + expected_experiment_id = _expected_experiment_id(hydra_cfg, bypass_cfg_dict) experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id @@ -618,7 +620,7 @@ def _test_bypass_multi_config_sequential_job( sub_cfg = copy.deepcopy(bypass_cfg_dict) sub_cfg["model"]["model_config_overrides"] = sub["model_config_overrides"] sub_cfg["experiment_id"] = None - expected_ids.append(_expected_experiment_id(sub_cfg)) + expected_ids.append(_expected_experiment_id(hydra_cfg, sub_cfg)) for experiment_id in expected_ids: experiment_dir = puzzle_dir / "bypass/bypass_runs" / experiment_id @@ -715,7 +717,7 @@ def _test_bypass_resume_from_checkpoint_job( bypass_distillation.launch_bypass_distillation(hydra_cfg) dist.barrier() - expected_experiment_id = _expected_experiment_id(phase1_cfg) + expected_experiment_id = _expected_experiment_id(hydra_cfg, phase1_cfg) experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id if rank == 0: @@ -851,7 +853,7 @@ def _test_bypass_subblock_modes_job( dist.barrier() if rank == 0: - expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + expected_experiment_id = _expected_experiment_id(hydra_cfg, bypass_cfg_dict) experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id # `start-step-*` is the pre-training snapshot (saved when # save_checkpoint_before_training=True). The post-training snapshot @@ -997,7 +999,7 @@ def _test_bypass_then_build_library_job( dist.barrier() if rank == 0: - expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + expected_experiment_id = _expected_experiment_id(hydra_cfg, bypass_cfg_dict) ckpts_dir = puzzle_dir / "ckpts" # 1. The realize step must have created a symlink for this bypass run. diff --git a/tests/unit/torch/puzzletron/test_puzzletron_nas_plugin.py b/tests/unit/torch/puzzletron/test_puzzletron_nas_plugin.py new file mode 100644 index 00000000000..defdcf6dc8f --- /dev/null +++ b/tests/unit/torch/puzzletron/test_puzzletron_nas_plugin.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Puzzletron NAS orchestration helpers.""" + +import json +from pathlib import Path + +from omegaconf import OmegaConf +from safetensors.torch import save_file +from torch import tensor + +from modelopt.torch.puzzletron.puzzletron_nas_plugin import ( + _force_scoring_revalidation, + _invalidate_scoring_cache, + _is_complete_anymodel_checkpoint, +) + + +def test_complete_anymodel_checkpoint_rejects_config_only(tmp_path: Path): + checkpoint_dir = tmp_path / "ckpt" + checkpoint_dir.mkdir() + (checkpoint_dir / "config.json").write_text("{}") + + assert not _is_complete_anymodel_checkpoint(checkpoint_dir) + + +def test_complete_anymodel_checkpoint_requires_indexed_weights(tmp_path: Path): + checkpoint_dir = tmp_path / "ckpt" + subblocks_dir = checkpoint_dir / "subblocks_safetensors" + subblocks_dir.mkdir(parents=True) + weight_path = subblocks_dir / "embeddings.safetensors" + (checkpoint_dir / "config.json").write_text("{}") + save_file({"model.embed_tokens.weight": tensor([1.0])}, weight_path) + (checkpoint_dir / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": {"format": "pt"}, + "weight_map": { + "model.embed_tokens.weight": "subblocks_safetensors/embeddings.safetensors" + }, + } + ) + ) + + assert _is_complete_anymodel_checkpoint(checkpoint_dir) + + weight_path.unlink() + + assert not _is_complete_anymodel_checkpoint(checkpoint_dir) + + +def test_invalidate_scoring_cache_removes_validation_jsons(tmp_path: Path): + output_dir = tmp_path / "single_sequence_replacement_solutions--validation" + output_dir.mkdir() + teacher_path = output_dir / "teacher.json" + solution_path = output_dir / "solution_0.json" + unrelated_path = output_dir / "notes.json" + for path in (teacher_path, solution_path, unrelated_path): + path.write_text("{}") + + cfg = OmegaConf.create({"scoring": {"output_dir": str(output_dir)}}) + + _invalidate_scoring_cache(cfg) + + assert not teacher_path.exists() + assert not solution_path.exists() + assert unrelated_path.exists() + + +def test_force_scoring_revalidation_ignores_existing_solutions(): + cfg = OmegaConf.create( + {"scoring": {"skip_existing_solutions": True, "solutions_to_validate": [0]}} + ) + + _force_scoring_revalidation(cfg) + + assert not cfg.scoring.skip_existing_solutions + assert cfg.scoring.solutions_to_validate is None From a94730385cb282db6290e22fcc7dea7c0535d676 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 8 Jun 2026 11:32:52 +0200 Subject: [PATCH 5/8] Trigger CI Signed-off-by: Sepehr Sameni From e8b73f67e3f42d6e82bcf378f8fac29148de5be5 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 8 Jun 2026 11:48:20 +0200 Subject: [PATCH 6/8] Document HF Hub optional import Signed-off-by: Sepehr Sameni --- modelopt/torch/puzzletron/puzzletron_nas_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py index ab3ef1ace17..8eccaa896e7 100644 --- a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py @@ -276,6 +276,7 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv # Auto-download from HuggingFace if path doesn't exist locally input_model_path = config.input_model_path if not Path(input_model_path).exists(): + # Guard optional dependency: only require huggingface_hub for HF auto-downloads. from huggingface_hub import snapshot_download if input_model_path.startswith("https://huggingface.co/"): From 7b128ec92808f076bbea709498002ac62a5a1df6 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 9 Jun 2026 10:02:40 +0200 Subject: [PATCH 7/8] Fix bypass experiment ID path canonicalization Signed-off-by: Sepehr Sameni --- .../puzzletron/bypass_distillation/bypass_utils.py | 6 +++--- tests/unit/torch/puzzletron/test_bypass_utils.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py index 6baf42c4c7a..da4a875d618 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -124,9 +124,9 @@ def _teacher_dir_identity(cfg: DictConfig) -> str | None: if teacher_dir is None: return None teacher_dir = str(teacher_dir) - if teacher_dir.startswith("~"): - return str(Path(teacher_dir).expanduser()) - return teacher_dir + if "://" in teacher_dir: + return teacher_dir + return str(Path(teacher_dir).expanduser()) def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py index 4cca9ff5499..1100a27acbe 100644 --- a/tests/unit/torch/puzzletron/test_bypass_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -192,6 +192,18 @@ def test_config_fingerprint_and_experiment_id_canonicalize_keys_to_learn(): assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id +def test_config_fingerprint_and_experiment_id_canonicalize_teacher_path(): + cfg_a = _experiment_cfg("subblock_attention") + cfg_b = _experiment_cfg("subblock_attention") + cfg_b.teacher_dir = f"{cfg_a.teacher_dir}/" + + assert get_bypass_config_fingerprint(cfg_a) == get_bypass_config_fingerprint(cfg_b) + + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id + + def test_experiment_id_uses_teacher_source_not_dataset_path(): cfg_a = _experiment_cfg("subblock_attention") cfg_b = _experiment_cfg("subblock_attention") From fbfa1663139d27d625c7ab092d96bf0776eae04a Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 9 Jun 2026 11:20:58 +0200 Subject: [PATCH 8/8] Fix bypass subblock mode checkpoint diff Signed-off-by: Sepehr Sameni --- tests/gpu/torch/puzzletron/test_bypass.py | 59 ++++++++++++----------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index 9879a0ffe44..445332f60e6 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -64,11 +64,13 @@ import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts import modelopt.torch.puzzletron.replacement_library.build_replacement_library as build_lib import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.anymodel import ModelDescriptorFactory, convert_model from modelopt.torch.puzzletron.bypass_distillation.bypass_checkpoint_utils import ( find_latest_run_dir, ) from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id +from modelopt.torch.puzzletron.tools.checkpoint_utils import load_state_dict +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir # --------------------------------------------------------------------------- @@ -794,9 +796,8 @@ def test_bypass_subblock_modes( For each (family, keys_to_learn) cell: - Run bypass for 2 steps with that keys_to_learn. - - After training, load the saved stitched_module state dict. - - Compare against the teacher-derived initialization (``copied_dir`` of - the bypass experiment, which holds the post-init pre-train weights): + - After training, load the saved HF-format checkpoints. + - Compare against the start checkpoint, which holds the post-init pre-train weights: * subblock_ffn → only FFN keys differ from init; attention identical. * subblock_attention → only attention keys differ; FFN identical. * entire_block → both differ. @@ -872,35 +873,35 @@ def _test_bypass_subblock_modes_job( "no post-training checkpoint was written." ) - # Diff every saved stitched module's state dict between start (pre-train) - # and end (post-train). Block names look like ``block_0``, ``block_1``… - ffn_token_set = {".mlp.", ".experts."} # Llama vs GPT-OSS naming - attn_token = ".self_attn." - - def _key_kind(key: str) -> str: - if attn_token in key: - return "attn" - if any(t in key for t in ffn_token_set): - return "ffn" - return "other" + # Diff the HF-format checkpoint tensors. ``stitched/`` stores only + # optimizer/scaler state; model weights live in the checkpoint root. + start_state = load_state_dict(start_dir) + end_state = load_state_dict(end_dir) + descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor) + model_config = load_model_config(start_dir) + lm_config = descriptor.get_language_model_config(model_config) + weight_groups = descriptor.get_weight_groups( + start_state.keys() & end_state.keys(), + lm_config.num_hidden_layers, + ) + key_kinds = { + key: "attn" if group_name.endswith("_attention") else "ffn" + for group_name, keys in weight_groups.items() + if group_name.endswith(("_attention", "_ffn")) + for key in keys + } ffn_changed = False attn_changed = False - for state_dict_path in (start_dir / "stitched").glob("block_*.state_dict.pth"): - end_path = end_dir / "stitched" / state_dict_path.name - if not end_path.exists(): + for key in start_state.keys() & end_state.keys(): + kind = key_kinds.get(key) + if kind is None: continue - start_state = torch.load(state_dict_path, map_location="cpu", weights_only=True) - end_state = torch.load(end_path, map_location="cpu", weights_only=True) - for key in start_state.keys() & end_state.keys(): - kind = _key_kind(key) - if kind == "other": - continue - changed = not torch.equal(start_state[key], end_state[key]) - if kind == "ffn" and changed: - ffn_changed = True - if kind == "attn" and changed: - attn_changed = True + changed = not torch.equal(start_state[key], end_state[key]) + if kind == "ffn" and changed: + ffn_changed = True + if kind == "attn" and changed: + attn_changed = True if keys_to_learn == "subblock_ffn": assert ffn_changed, f"subblock_ffn should change FFN weights ({hf_model_name})"