
add example script for depth importance estimation#1016

Open
chochowski wants to merge 4 commits into main from depth_importance_example

Conversation


@chochowski chochowski commented Mar 10, 2026

What does this PR do?

Type of change: New feature

The script iterates over the GPTModel transformer, patching each block as a no-op one at a time and estimating how much removing that block affects the final output representation.
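The per-block procedure can be sketched with a toy module list (a hypothetical stand-in; the real script patches Megatron `GPTModel` blocks and compares final hidden states across pipeline ranks):

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer stack; the real script patches GPTModel blocks.
torch.manual_seed(0)
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

def forward(x):
    for block in blocks:
        x = block(x)
    return x

x = torch.randn(2, 8)
reference = forward(x).detach()  # unmodified model output

scores = {}
for i, block in enumerate(blocks):
    original_forward = block.forward
    block.forward = lambda inp: inp          # patch: block becomes a no-op
    with torch.no_grad():
        out = forward(x)
    block.forward = original_forward         # unpatch before scoring the next block
    # higher MSE vs. the reference output => removing the block hurts more
    scores[i] = torch.nn.functional.mse_loss(out, reference).item()

ranking = sorted(scores, key=scores.get)     # least important block first
print(ranking)
```

The actual script uses a hook on the final layer and a normalized per-sample MSE rather than this plain `mse_loss`, but the patch/score/unpatch loop is the same idea.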

Usage

```bash
PYTHONPATH=`pwd`:$PYTHONPATH python -m torch.distributed.run --nproc_per_node=2 examples/pruning/rank_layer_importance.py --hf_model_name_or_path /.../nvidia/NVIDIA-Nemotron-Nano-12B-v2 --trust_remote_code --calib_dataset_name wikitext --num_layers_in_first_pipeline_stage 31 --num_layers_in_last_pipeline_stage 31 --num_layers 62 --drop_layers 3 4 5 6 7 8 9 10 11
```

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features
    • Added a layer importance estimation utility for model pruning that measures per-layer significance by comparing outputs with and without specific layers.
    • Supports distributed execution and calibration-based evaluation.
    • Enables iterative layer ranking and selective layer dropping for pruning analysis.

@chochowski chochowski requested a review from a team as a code owner March 10, 2026 15:17
@chochowski chochowski requested a review from jenchen13 March 10, 2026 15:17
@copy-pr-bot

copy-pr-bot bot commented Mar 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough


A new layer importance estimation utility is introduced for model pruning. The script rank_layer_importance.py measures per-layer importance by comparing MSE of final hidden representations with and without each layer during a calibration run. Documentation is updated with usage examples.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Documentation**<br>`examples/pruning/README.md` | Adds a subsection describing the `rank_layer_importance.py` utility, including a two-node distributed run example and a variant for dropping specific layers, with example commands. |
| **Layer Importance Estimation**<br>`examples/pruning/rank_layer_importance.py` | New script implementing layer-wise importance scoring for pruning. Includes hook-based importance tracking (`LastHiddenImportanceHook`), forward-pass patching functions for multiple architectures (MLP, attention, Mamba, transformer blocks), normalized MSE loss computation, argument parsing, score collection, and a main orchestration function (`estimate_layer_importance`) that manages calibration loops and per-layer scoring with distributed aggregation. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Script as rank_layer_importance.py
    participant Model
    participant CalibData as Calibration Dataset
    participant Hook as LastHiddenImportanceHook
    participant Scorer as Score Collector

    User->>Script: parse args (model, dataset, drop_layers)
    Script->>Model: load HF model + prepare
    Script->>Model: setup_gates (attach importance hooks)
    Script->>Model: initialize LM head reference

    rect rgba(100, 150, 255, 0.5)
        Note over Script,Hook: Calibration Phase
        Script->>CalibData: load calibration samples
        loop for each batch
            CalibData->>Model: feed forward
            Model->>Hook: collect reference hidden states
            Hook->>Hook: store reference outputs
        end
    end

    rect rgba(100, 255, 150, 0.5)
        Note over Script,Scorer: Per-Layer Scoring Phase
        loop for each layer (excluding drop_layers)
            Script->>Model: patch layer to noop_forward
            CalibData->>Model: run forward pass
            Model->>Hook: collect layer output
            Hook->>Hook: compute MSE vs reference
            Hook->>Scorer: accumulate per-layer metrics
            Script->>Model: unpatch layer
        end
    end

    rect rgba(255, 200, 100, 0.5)
        Note over Script,Scorer: Finalization
        Script->>Scorer: collect_scores (aggregate DP results)
        Scorer->>Scorer: compute final per-layer importance
        Script->>Script: serialize scores file
    end

    Script->>User: output per-layer importance rankings
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 8.47%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main change: a new example script for estimating layer/depth importance. It directly aligns with the PR's primary objective of adding depth importance estimation functionality. |
| Security Anti-Patterns | ✅ Passed | The PR contains no security anti-patterns: no unsafe `torch.load`, `numpy.load`, eval/exec on external input, or misconfigurations; `trust_remote_code` defaults to False. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/pruning/depth_ranking.py`:
- Around line 303-320: In collect_scores(), avoid hardcoding a length-10 CUDA
tensor and metric assumptions: use the provided use_metric argument when reading
stats, infer device and per-sample vector lengths dynamically, and replace
torch.zeros((10,)).cuda() with a padding strategy (e.g., pad_sequence or
creating zeros using the same device and shape as stat) so torch.stack(res)
won't fail for different DP/batch sizes; ensure aggregation (median/mean)
computes over the correct dim and that drop_group/drop_blocks are applied to the
sorted_indices from the selected metric; apply the same fixes in the related
block around the other occurrence (lines ~510-514) so both places respect
variable sample counts, device, and the caller's metric/drop settings.
- Line 1: Add the standard NVIDIA Apache 2.0 license header to the top of
depth_ranking.py (replace the solitary copyright line), using the full Apache
2.0 block required by the repository policy so the file begins with the complete
license header comment instead of just the copyright notice.
- Around line 16-17: The example imports test-only utilities
(get_mcore_gpt_model, set_seed and the inference helpers) from _test_utils which
couples the example to the test tree; replace those imports with supported
runtime equivalents or relocate the helper implementations into the examples
package. Concretely, remove references to
_test_utils.torch.megatron.models.get_mcore_gpt_model and
_test_utils.torch.misc.set_seed and either import model-building and seed
utilities from the public API (or a stable examples.helpers module), and copy
any inference helper functions used on lines ~93-96 into
examples/pruning/helpers.py and import them from there so the example no longer
depends on tests.
- Around line 78-80: Wrap the transformer_engine imports in a try/except and set
a flag (e.g., HAS_TE) so importing the example won't fail when
transformer_engine is not installed, and import/alias torch.nn.LayerNorm as the
fallback norm; then update setup_gates() to detect and handle both TE norm
classes (RMSNorm/LayerNorm from transformer_engine when HAS_TE) and standard
PyTorch norms (torch.nn.LayerNorm and any other supported norm types) by
creating the same logits_gate_list entries for either case; additionally, add a
defensive check before using model.logits_gate_list[0] to raise a clear error or
construct a default gate when logits_gate_list is empty, and optionally ensure
main() pins transformer_impl to the TE backend when intended (or document that
transformer_engine must be installed).
- Around line 353-384: The layer-to-rank mapping is wrong because
layer_id_in_this_rank() assumes equal split via num_layers_per_pp and offset;
update it to use Megatron's true pipeline offsets (e.g., call the utility used
in the distill plugin like TransformerLayer._get_layer_offset() or compute
start/end using config.num_layers_in_first_pipeline_stage and
config.num_layers_in_last_pipeline_stage together with
get_pipeline_model_parallel_rank()/get_pipeline_model_parallel_world_size()) so
that global layer IDs map to local indices correctly; adjust
layer_id_in_this_rank() to return the local index or -1 based on that computed
start/end range, and ensure patch_model/unpatch_model still use the returned
local index consistently.
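The optional-dependency guard suggested in the transformer_engine comment above can be sketched as follows (names like `is_supported_norm` are hypothetical; `setup_gates()` would use a check of this kind):

```python
import torch

# Guard the transformer_engine import and fall back to torch.nn.LayerNorm
# when TE is unavailable, as the review suggests.
try:
    from transformer_engine.pytorch import LayerNorm, RMSNorm
    HAS_TE = True
except ImportError:
    from torch.nn import LayerNorm
    RMSNorm = None
    HAS_TE = False

# The norm classes setup_gates() should recognize under either backend.
NORM_CLASSES = tuple(cls for cls in (LayerNorm, RMSNorm) if cls is not None)

def is_supported_norm(module):
    return isinstance(module, NORM_CLASSES)

print(is_supported_norm(LayerNorm(8)))  # True whichever backend was imported
```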

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f0060447-cc9c-4f31-b6c7-4a5eebc9255f

📥 Commits

Reviewing files that changed from the base of the PR and between cbab377 and 7a29122.

📒 Files selected for processing (1)
  • examples/pruning/depth_ranking.py

Signed-off-by: mchochowski <mchochowski@nvidia.com>
@chochowski chochowski changed the title [DRAFT] add example script for depth importance estimation add example script for depth importance estimation Mar 13, 2026
Signed-off-by: mchochowski <mchochowski@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/pruning/rank_layer_importance.py`:
- Around line 245-248: The topology flags (--num_layers,
--num_layers_in_first_pipeline_stage, --num_layers_in_last_pipeline_stage) are
treated as required but are optional and the current layer_id_in_this_rank()
math assumes uniform per-rank spans causing incorrect global IDs; make these
args required via parser.add_argument(..., required=True) or add validation
after parsing to compute per-rank layer spans from the loaded model (e.g.,
derive total num_layers from model.num_layers or model.config and compute
per-rank start/end counts using world_rank/world_size and the provided
first/last stage sizes), then replace the existing ternary/offset logic with
computed per-rank start_index/end_index used by layer_id_in_this_rank() and the
loops over range(args.num_layers) so that patches target the correct global
layer IDs.
- Around line 166-183: The hook_fn currently appends full GPU tensors to
self.reference_hidden in the reference-loading phase (self.reference_load),
which causes OOM; instead either (a) avoid caching full tensors by performing
the comparison immediately per calibration batch (compute
normalized_mse_loss_per_sample between hidden_out and the corresponding
reference and append results to self.hidden_distance/self.logits_distance), or
(b) if caching is required, move/slice the activation off device before storing
(e.g., hidden_out = hidden_out.detach().cpu() or store a reduced
summary/selected slices) and for logits use teacher_logits =
self.lm_head(...).detach().cpu() so memory stays on CPU/disk; update references
to self.reference_hidden, hook_fn, normalized_mse_loss_per_sample,
self.hidden_distance, self.logits_distance, and self.reference_load accordingly.
- Around line 81-82: Remove the unused module-level CUDA allocations kl_loss and
mse_loss (the torch.nn.KLDivLoss(...) and torch.nn.MSELoss(...) calls) which run
at import time and call .cuda(); delete these lines entirely since they are
never referenced and cause unwanted GPU initialization before dist.setup(); if
loss objects are needed later, instantiate them locally after distributed setup
(e.g., inside main or training functions) and avoid calling .cuda() at import
time.

In `@examples/pruning/README.md`:
- Around line 233-240: Update the two command fenced blocks around the usage of
rank_layer_importance.py to include a blank line before and after each fence and
add the language tag "bash" after the opening ```; also correct the typo
"interative" to "iterative" in the explanatory sentence that references
iterative estimation (search for rank_layer_importance.py and the paragraph
mentioning interative/iterative).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a5d03e5d-3a39-4448-a24c-be180264fe5e

📥 Commits

Reviewing files that changed from the base of the PR and between 7a29122 and c92db19.

📒 Files selected for processing (2)
  • examples/pruning/README.md
  • examples/pruning/rank_layer_importance.py

Comment on lines +81 to +82

```python
kl_loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True).cuda()
mse_loss = torch.nn.MSELoss(reduce=True, reduction='mean').cuda()
```

Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "Module-scope CUDA allocations:"
rg -n '^\s*[A-Za-z_][A-Za-z0-9_]*\s*=\s*torch\.nn\..*\.cuda\(\)' examples/pruning/rank_layer_importance.py

echo
echo "Distributed/device setup points in the same file:"
rg -n -C2 'dist\.setup\(|torch\.cuda\.set_device|LOCAL_RANK|\.cuda\(' examples/pruning/rank_layer_importance.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 914


🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "=== Usage of kl_loss and mse_loss variables ==="
rg -n 'kl_loss|mse_loss' examples/pruning/rank_layer_importance.py

echo
echo "=== File structure around lines 81-82 to confirm module scope ==="
sed -n '75,95p' examples/pruning/rank_layer_importance.py | cat -n

echo
echo "=== Verify __main__ block and entry point ==="
sed -n '410,419p' examples/pruning/rank_layer_importance.py | cat -n
```

Repository: NVIDIA/Model-Optimizer

Length of output: 2114


🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "=== Total lines in file ==="
wc -l examples/pruning/rank_layer_importance.py

echo
echo "=== Exact match for 'kl_loss' references (case-sensitive, whole word) ==="
rg -w 'kl_loss' examples/pruning/rank_layer_importance.py

echo
echo "=== Exact match for 'mse_loss' references (case-sensitive, whole word) ==="
rg -w 'mse_loss' examples/pruning/rank_layer_importance.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 581


Remove unused module-scope CUDA allocations.

Lines 81–82 allocate kl_loss and mse_loss at module import time, before dist.setup() runs. This causes every worker to touch the default GPU during import, making the script fragile under torch.distributed.run. Additionally, these variables are never referenced anywhere in the file—the code uses F.mse_loss() instead—so they should be removed entirely rather than moved.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/pruning/rank_layer_importance.py` around lines 81 - 82, Remove the
unused module-level CUDA allocations kl_loss and mse_loss (the
torch.nn.KLDivLoss(...) and torch.nn.MSELoss(...) calls) which run at import
time and call .cuda(); delete these lines entirely since they are never
referenced and cause unwanted GPU initialization before dist.setup(); if loss
objects are needed later, instantiate them locally after distributed setup
(e.g., inside main or training functions) and avoid calling .cuda() at import
time.

Comment on lines +166 to +183

```python
def hook_fn(self, module, input, output):
    # seq x batch x dim
    hidden_out = output.detach().permute(1, 0, 2)  # batch x seq x dim

    # if loading the reference from teacher
    if self.reference_load:
        self.reference_hidden.append(hidden_out)
        return

    # if computing the distance to the reference
    sample_id = len(self.hidden_distance)
    # MSE
    self.hidden_distance.append(normalized_mse_loss_per_sample(hidden_out, self.reference_hidden[sample_id]).mean())
    # if computing the distance to the teacher's logits
    if self.lm_head:
        teacher_logits = self.lm_head(self.reference_hidden[sample_id].permute(1, 0, 2))[0].detach()
        logits = self.lm_head(hidden_out.permute(1, 0, 2))[0].detach()
        self.logits_distance.append(normalized_mse_loss_per_sample(logits, teacher_logits).mean())
```

Contributor

⚠️ Potential issue | 🔴 Critical

Caching every reference hidden state on-device will OOM quickly.

self.reference_hidden.append(hidden_out) keeps the full final hidden tensor for every calibration batch. With the defaults in this file (calib_num_samples=1024, seq_length=4096) and the 12B example documented in the README, that is on the order of tens of GB on the last PP rank before the ranking loop even starts. Please stream the comparison batch-by-batch, or at least slice/offload the reference activations before caching them.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/pruning/rank_layer_importance.py` around lines 166 - 183, The
hook_fn currently appends full GPU tensors to self.reference_hidden in the
reference-loading phase (self.reference_load), which causes OOM; instead either
(a) avoid caching full tensors by performing the comparison immediately per
calibration batch (compute normalized_mse_loss_per_sample between hidden_out and
the corresponding reference and append results to
self.hidden_distance/self.logits_distance), or (b) if caching is required,
move/slice the activation off device before storing (e.g., hidden_out =
hidden_out.detach().cpu() or store a reduced summary/selected slices) and for
logits use teacher_logits = self.lm_head(...).detach().cpu() so memory stays on
CPU/disk; update references to self.reference_hidden, hook_fn,
normalized_mse_loss_per_sample, self.hidden_distance, self.logits_distance, and
self.reference_load accordingly.
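Option (b) from the comment above, caching the reference activations off-device, can be sketched with a minimal hook (class and attribute names follow the diff, but this is a simplified illustration, not the PR's code):

```python
import torch

class ReferenceCache:
    """Sketch of off-device caching: keep reference activations on CPU so
    the GPU only ever holds the current batch."""
    def __init__(self):
        self.reference_hidden = []
        self.hidden_distance = []
        self.reference_load = True

    def hook_fn(self, module, inputs, output):
        hidden_out = output.detach()
        if self.reference_load:
            # move to CPU before caching to avoid accumulating GPU memory
            self.reference_hidden.append(hidden_out.cpu())
            return
        sample_id = len(self.hidden_distance)
        # bring only the matching reference back to the active device
        ref = self.reference_hidden[sample_id].to(hidden_out.device)
        self.hidden_distance.append(torch.nn.functional.mse_loss(hidden_out, ref).item())

cache = ReferenceCache()
layer = torch.nn.Identity()
layer.register_forward_hook(cache.hook_fn)
x = torch.randn(2, 4, 8)
layer(x)                      # reference pass: activation cached on CPU
cache.reference_load = False
layer(x)                      # scoring pass: distance to the cached reference
print(cache.hidden_distance)  # identical passes give a distance of 0.0
```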

Comment on lines +245 to +248
# Uneven Pipeline Parallelism parameters
parser.add_argument("--num_layers", type=int, default=None)
parser.add_argument("--num_layers_in_first_pipeline_stage", type=int, default=None)
parser.add_argument("--num_layers_in_last_pipeline_stage", type=int, default=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

These topology flags are effectively required, but the parser leaves them optional.

range(args.num_layers) will raise when --num_layers is omitted, and when --num_layers_in_first_pipeline_stage is omitted the ternary in Line 354 evaluates to 0, so layer_id_in_this_rank() returns -1 for every layer and no block is ever patched. Even when values are provided, the offset math still assumes every rank owns num_layers_in_first_pipeline_stage layers, so uneven last-stage sizes patch the wrong global IDs. Please either infer the per-rank spans from the loaded model or make these arguments required and validate them before scoring.

Also applies to: 323-357, 394-405

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/pruning/rank_layer_importance.py` around lines 245 - 248, The
topology flags (--num_layers, --num_layers_in_first_pipeline_stage,
--num_layers_in_last_pipeline_stage) are treated as required but are optional
and the current layer_id_in_this_rank() math assumes uniform per-rank spans
causing incorrect global IDs; make these args required via
parser.add_argument(..., required=True) or add validation after parsing to
compute per-rank layer spans from the loaded model (e.g., derive total
num_layers from model.num_layers or model.config and compute per-rank start/end
counts using world_rank/world_size and the provided first/last stage sizes),
then replace the existing ternary/offset logic with computed per-rank
start_index/end_index used by layer_id_in_this_rank() and the loops over
range(args.num_layers) so that patches target the correct global layer IDs.
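The per-rank span arithmetic this finding asks for can be sketched in pure Python (function names here are hypothetical; the real fix should query Megatron's own pipeline offset utilities, e.g. `TransformerLayer._get_layer_offset()`):

```python
def pipeline_stage_span(pp_rank, pp_size, num_layers,
                        first_stage_layers=None, last_stage_layers=None):
    """Return the [start, end) global layer indices owned by a pipeline rank,
    supporting uneven first/last stages. Sketch only."""
    first = first_stage_layers if first_stage_layers is not None else num_layers // pp_size
    last = last_stage_layers if last_stage_layers is not None else num_layers // pp_size
    middle_ranks = max(pp_size - 2, 0)
    middle = (num_layers - first - last) // middle_ranks if middle_ranks else 0
    if pp_size == 1:
        return 0, num_layers
    if pp_rank == 0:
        return 0, first
    if pp_rank == pp_size - 1:
        return num_layers - last, num_layers
    start = first + (pp_rank - 1) * middle
    return start, start + middle

def layer_id_in_this_rank(global_id, span):
    # local index within this rank's span, or -1 if the layer lives elsewhere
    start, end = span
    return global_id - start if start <= global_id < end else -1

# Example from the PR command line: 62 layers, PP=2, stages of 31/31
span0 = pipeline_stage_span(0, 2, 62, 31, 31)
span1 = pipeline_stage_span(1, 2, 62, 31, 31)
print(span0, span1)                      # (0, 31) (31, 62)
print(layer_id_in_this_rank(40, span1))  # 9
```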

Comment on lines +233 to +240

````markdown
- To estimate importance of each layer one can run `rank_layer_importance.py` script. This script computes importance of each layer by comparing the MSE between the final hidden representation with and without that layer.

```
python -m torch.distributed.run --nproc_per_node=2 /path/to/modelopt/examples/pruning/rank_layer_importance.py --hf_model_name_or_path /path/to/hf-checkpoint/nvidia/NVIDIA-Nemotron-Nano-12B-v2 --trust_remote_code --calib_dataset_name wikitext --num_layers_in_first_pipeline_stage 31 --num_layers_in_last_pipeline_stage 31 --num_layers 62
```
- One can also pass indices of layers that should be dropped always. This allows running an interative estimation e.g. in first iteration score all layers, pick 5 least important, and in the next iteration pass these 5 layers to be dropped, so that it ranks the rest of the layers assuming these 5 are dropped.
```
python -m torch.distributed.run --nproc_per_node=2 /path/to/modelopt/examples/pruning/rank_layer_importance.py --hf_model_name_or_path /path/to/hf-checkpoint/nvidia/NVIDIA-Nemotron-Nano-12B-v2 --trust_remote_code --calib_dataset_name wikitext --num_layers_in_first_pipeline_stage 31 --num_layers_in_last_pipeline_stage 31 --num_layers 62 --drop_layers 6 7 9 32 41
````

Contributor
⚠️ Potential issue | 🟡 Minor

Fix the new command fences and typo.

Both command blocks need a bash language and blank lines around the fences, and Line 238 still says interative.

🧰 Tools
🪛 markdownlint-cli2 (0.21.0)

[warning] 235-235: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


[warning] 237-237: Fenced code blocks should be surrounded by blank lines

(MD031, blanks-around-fences)


[warning] 239-239: Fenced code blocks should be surrounded by blank lines

(MD031, blanks-around-fences)


[warning] 239-239: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/pruning/README.md` around lines 233 - 240, Update the two command
fenced blocks around the usage of rank_layer_importance.py to include a blank
line before and after each fence and add the language tag "bash" after the
opening ```; also correct the typo "interative" to "iterative" in the
explanatory sentence that references iterative estimation (search for
rank_layer_importance.py and the paragraph mentioning interative/iterative).

```diff
@@ -0,0 +1,419 @@
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/tree/aa457edc3d64d81530159cd3a182932320c78f8c

# MIT License
```
Contributor


Please use the NVIDIA Apache 2.0 license header and change the year to 2026.

```python
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine

from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
```
Contributor


Does this example belong in the Megatron-Bridge repo? @kevalmorabia97
We probably want to keep MBridge examples in the MBridge repo, to reduce circular dependencies between the two repos.

Collaborator


We can move it to ModelOpt/examples/megatron_bridge, where we already have pruning and distillation for M-Bridge.
