Skip to content
95 changes: 95 additions & 0 deletions examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Bypass Distillation Tutorial: Nemotron-3-Nano-30B-A3B (KV-heads-only)

A minimal end-to-end demonstration that **bypass distillation improves quality** at the same compression budget. The setup is a **toy pruning task on a real production model** — we compress only KV heads (12 → 9, a modest 25% reduction) so a single comparison surfaces the bypass benefit cleanly without needing extensive downstream evaluation. The model itself ([Nemotron-3-Nano-30B-A3B-Base-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16)) is a real 30B-A3B MoE-Mamba hybrid, not a tiny stand-in.

## What this tutorial does

The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare:

1. **Without bypass** — replacement library uses the `PruneKVHeads` initialization from activation scoring (selected KV heads copied from the teacher; no further training).
2. **With bypass** — the bypass step runs ~50M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher.

Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head}`. FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change.

**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`).

## Hardware & install

- 8×H100 80GB (the teacher needs ≥60 GiB for activation scoring on a 4096 context).
- Container: `nvcr.io/nvidia/nemo:26.04` or later.
- `pip install -e ".[dev]"` from the modelopt repo root.
- Mamba kernels (required by Nemotron-3-Nano's hybrid backbone):

```bash
pip install mamba-ssm[causal-conv1d] --no-build-isolation
```

- HF auth set up so the model is downloadable: `huggingface-cli login`.

## Step A — pipeline without bypass

Edit `examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml` to point `puzzle_dir` and `dataset_path` at writable locations, then:

```bash
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
```

This runs the no-bypass puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). The progress counter includes start/complete messages and prints `N/8`. With `bypass:` added in Step B the pipeline grows to 9 steps. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring).

When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain:

```text
validate_model_with_kl_div(model_name='teacher', ...)
Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...}
...
validate_model_with_kl_div(model_name='solution_0', ...)
Average losses = {..., 'token_accuracy_top_1': ..., ...}
```

Record the teacher's `token_accuracy_top_1` and `solution_0`'s `token_accuracy_top_1`. **Move or rename `${puzzle_dir}/single_sequence_replacement_solutions--validation/` and `${puzzle_dir}/mip/` aside** before Step B if you want to keep the no-bypass artifacts — Step B reuses the same `puzzle_dir` and the library/scoring/MIP outputs will be overwritten.

## Step B — pipeline with bypass

Use the bypass-enabled config, which overrides the base config's empty `- bypass:` entry with `bypass: defaults`:

```yaml
defaults:
- nemotron-3-nano-30b-a3b
- override bypass: defaults
- _self_
```

Run the bypass config:

```bash
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b-with-bypass.yaml
```

Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~50M tokens) and the downstream library/scoring/MIP rerun.

Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs/<bypass_experiment_id>/` and creates a symlink `${puzzle_dir}/ckpts/<bypass_experiment_id>` that the replacement library builder picks up automatically.

Capture `solution_0`'s `token_accuracy_top_1` from the new realize-model log section.

## Results

Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on Nemotron-3-Nano-30B-A3B-Base-BF16:

| Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` |
|------------------------------|----------------------:|----------:|-----------------------:|
| Teacher | 12 | 0.5950 | 0.8468 |
| Pruned, **no bypass** (`PruneKVHeads` init) | 9 | 0.6347 | 0.8373 |
| Pruned, **with bypass** (50M-token BLD) | 9 | **0.6055**| **0.8441** |

**Bypass closes ~74% of the regression gap** at this compression budget:

- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**.
- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**.

For 50M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher.

## Going further: full accuracy recovery

Bypass distillation is Stage 1 of the PUZZLE pipeline — local, per-block KD that tightens the replacement library. For larger compression targets (or more aggressive KV pruning) you'll want Stage 2: **global knowledge distillation** on the realized student. See [`examples/pruning/puzzletron/`](../pruning/puzzletron/) for the Megatron-Bridge recipe and concrete MMLU recovery numbers.
2 changes: 1 addition & 1 deletion examples/puzzletron/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of

In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md).

> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100, Nemotron-3-Nano-30B-A3B-Base-BF16 on 8x H100 — see the [bypass distillation tutorial](Nemotron-3-Nano-30B-A3B-Base-BF16.md)). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).

## Environment

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# @package bypass
# Bypass Distillation Configuration
# This config defines parameters for blockwise local distillation (BLD),
# which trains alternative transformer block configurations using per-block
# knowledge distillation from a teacher model.

# Runtime Configuration
dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability
seed: 42 # Random seed for reproducibility

# Experiment Tracking
experiment_id: # Unique identifier for this experiment. Will be dynamically set
experiment_dir: # Directory for this experiment. Will be dynamically set
iter_num: 1 # Current iteration number
step_num: 1 # Current step number within iteration
token_count: 0 # Token count tracker (auto-updated during training)

# Data Configuration
data:
data_column: "messages"
block_size: 512 # Sequence length (tokens per sample)
bos_rate: 0.5
fim_rate: 0
fim_spm_rate: 0
source_datasets_to_discard: []
load_from_disk: true # Load preprocessed data from disk or from stream
keep_in_memory: false
val_dataset_name: valid
max_eval_samples: 4
eval_samples_per_process: # Samples per GPU during distributed eval (auto if null)
shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data

# Training Configuration
training:
learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001)
training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check)
micro_batch_size: 2
val_micro_batch_size: 1
warmup_ratio: 0.05
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.grad_accumulation_steps},${.warmup_ratio}} # Auto-calculated warmup steps
min_lr_factor: 1e-5
grad_accumulation_steps: 1
skip_first_batches: 0 # Use for debugging or to skip few batches which cause crashes or optimization issues.
weight_decay: 0.1
decay_lr: true
beta1: 0.9
beta2: 0.95
use_grad_scaling: false
grad_clip: 1.0
grad_clip_type: norm
clipping_count: 0
log_interval: 5
eval_interval: 5

# Model Loading Configuration
resume_checkpoint_path: # Path to resume training from checkpoint
find_last_ckpt_for_resume: true # Auto-resume by finding last checkpoint (bool)
parameter_count:
init_checkpoint_path: # Path to initialize weights from

model:
student_weights_dtype: "bf16" # Student model weight precision

model_overrides:
delete_old_checkpoints: true # Clean up old checkpoints to save disk space
save_interval_seconds: 12900 # Save checkpoint every ~3.5 hours
save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled)
save_checkpoint_when_done: true # Save final checkpoint when training completes

# Architecture modifications for student model
model_config_overrides:
ffn:
- intermediate_size:
no_op: # Disable FFN entirely (true/false)
attention:
- num_key_value_heads: # Number of kv-heads (for GQA)
no_op: # Disable attention entirely (true/false)

# Model Factory Configuration - Controls student model creation and initialization
model_factory:
factory: bypass_factory_fn # Unified factory supporting all layer types
block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss
gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode
mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode
mlp_init_config: # Configuration for MLP initialization (if needed)
activations_log_dir: # Directory with activation statistics (required for PruneByActivationsLog)
linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc.
submodule_for_loss_calculation: # Specific submodule for loss calc.
keys_to_learn: # Subblock(s) to train: entire_block, subblock_attention, subblock_ffn, subblock_mamba, or a list of those.

# Validation Configuration
disable_initial_validate: false
validate_teacher_model: true
validate_student_model: true
disable_validation: false # Enable validation to exercise all code paths
best_val_loss: 1e+9 # Track best validation loss achieved

# Performance Optimization
compile: false # Use PyTorch compilation
disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available)
teacher_model_load_on_cpu: false

# Checkpoint Management
save_checkpoint_before_training: false # Save initial checkpoint before training
disable_checkpoint_save: false # Disable all checkpoint saving
save_best_ckpt: true # Save checkpoint when validation improves
kill_after_first_save: false # Exit after first checkpoint save (for testing)
realize_best_or_latest: "best"

wandb_log: false
wandb:
project:
entity:

# Multiple bypass configurations to train sequentially.
# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn.
# If empty or absent, a single run uses the settings above.
configs:
- model_config_overrides:
ffn:
- intermediate_size: 3072
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
- model_config_overrides:
ffn:
- intermediate_size: 5888
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
defaults:
- pruning: kv_heads_pruning
- scoring: ../validate_solutions_defaults
- realize_model: ../validate_solutions_defaults
- bypass:
- override hydra/hydra_logging: disabled
- _self_

puzzle_dir: ???
descriptor: nemotron_h
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2

skip_realize_model: false

# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library
# exposes {teacher 2-head, 1-head} per attention layer; FFN and Mamba
# blocks are copied verbatim from the teacher.
build_replacement_library:
add_ffn_no_ops: false
add_attention_no_ops: false

calc_subblock_stats:
batch_sizes: [64, 96, 128]
prefill_seq_len: 4096
generation_seq_len: 4096
num_active_tokens_override: # Optional override for sequence lengths
prefill_queue_size: 0
allocate_prefill_query: false
benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
merge_with_existing_stats: false
subblock_stats_filename: "subblock_stats.json"
moe_stats_filename: "moe_stats.json"
runtime_stats:
backend: trt_torch

scoring:
descriptor: ${descriptor}
solutions_to_validate:
skip_existing_solutions: true

replacement_library_path: ${replacement_library_path}
solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
teacher_dir: ${to_path:${teacher_dir}}
output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

eval_samples: 128
micro_batch_size: 1
seed: 42
shuffle_seed: 444
dataset_path: ${dataset_path}

mip:
single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
gathered_metrics_path:
puzzle_profile:

objective: metrics.cosine_embedding_loss_hidden_states
bigger_is_better: false

subblock_stats_args:
- batch_size: 96
weights_dtype: torch.bfloat16
activations_dtype: torch.bfloat16
kv_cache_dtype: torch.bfloat16

report_additional_costs:
- stats.memory_mib
- stats.num_params
- stats.num_kv_heads
- stats.has_attention
- stats.has_ffn
- stats.kv_cache_memory_mib
- stats.attention_memory_mib
- stats.ffn_memory_mib
- stats.ffn_num_params
- stats.attention_num_params

human_constraints:
target_num_kv_heads: 9 # toy KV-heads-only target; see nemotron-3-nano-30b-a3b.yaml

mip_constraints:
metric_overrides:
max_seconds_per_solution: 60

realize_model:
descriptor: ${descriptor}
teacher_dir: ${to_path:${teacher_dir}}
tokenizer_name: ${to_path:${teacher_dir}}
replacement_library_path: ${replacement_library_path}
save_models: true
solutions_path: # Filled dynamically

# Validate params
skip_validation: false
eval_samples: 128
micro_batch_size: 1
seed: 42
shuffle_seed: 444
dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:10}

# This section redirects Hydra outputs
hydra:
run:
dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
Loading
Loading