Activation hooks redesign (reuse hooks component across both minitron and puzzletron) #1022

Open
danielkorzekwa wants to merge 9 commits into main from dkorzewa/activation_hooks_redesign_minitron_puzzletron

Conversation


@danielkorzekwa danielkorzekwa commented Mar 11, 2026

What does this PR do?

Type of change: Redesign of existing feature

This PR introduces a shared activation hooks infrastructure for minitron and puzzletron. The activation hooks framework provides a reusable component for collecting and analyzing activations during forward passes, which is used by both minitron pruning and puzzletron algorithms.

Note: the Minitron megatron.py/mcore_minitron.py ImportanceEstimatorRegistry code does not use this component yet; it will be refactored in a separate MR.

Key changes:

  • Added modelopt/torch/prune/importance_hooks module with base hooks framework:

    • base_hooks.py: Core hook infrastructure for registering and managing forward hooks
    • base_hooks_analysis.py: Analysis utilities for processing collected activations
    • megatron_hooks.py: Megatron-specific hook implementations
    • compare_module_outputs.py: Utilities for comparing module outputs
  • Added unit tests in tests/gpu/torch/prune/importance_hooks:

    • test_base_hooks.py: Tests for base hooks functionality
    • test_base_hooks_analysis.py: Tests for activation analysis utilities
  • Updated test_mcore_gpt_minitron_pruning.py to validate activation collection

  • Updated test utilities for distributed testing support
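The hook lifecycle the PR describes (register a hook, collect per forward pass, then accumulate into per-channel importance scores) can be sketched without any framework. The class and method names below mirror the PR's description (ForwardHook, L2NormHook, accumulate), but the bodies are an illustrative plain-Python stand-in, not the actual tensor-based implementation:

```python
import math


class ForwardHook:
    """Sketch of the hook interface described in the PR (names assumed)."""

    def __call__(self, inputs):
        raise NotImplementedError

    def accumulate(self):
        raise NotImplementedError


class L2NormHook(ForwardHook):
    """Accumulates squared activation sums per channel; sqrt at the end."""

    def __init__(self):
        self._sq_sums = None  # running sum of squares, one entry per channel

    def __call__(self, inputs):
        # inputs: per-channel activation values from one forward pass
        sq = [x * x for x in inputs]
        if self._sq_sums is None:
            self._sq_sums = sq
        else:
            self._sq_sums = [a + b for a, b in zip(self._sq_sums, sq)]

    def accumulate(self):
        # Convert accumulated squared sums to L2 norms, one score per channel.
        assert self._sq_sums is not None, "No activations collected."
        return [math.sqrt(s) for s in self._sq_sums]
```

The real hooks operate on torch tensors and add state_dict/to_dict APIs plus distributed aggregation, but the accumulate-squares-then-sqrt shape is the same.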

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ Yes - This is a new module that doesn't affect existing functionality
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md?: N/A - No code was copied and no new PIP dependencies were added
  • Did you write any new necessary tests?: ✅ Yes - Added comprehensive tests for the activation hooks infrastructure
  • Did you update Changelog?: ✅ N/A - This is infrastructure code that will be used by subsequent PRs

Summary by CodeRabbit

  • New Features

    • Robust JSON utilities for complex objects.
    • Flushed-print helper for synchronized output.
    • Activation-based importance estimation framework with multiple hook implementations, Megatron plugin support, and layer-output comparison tools.
    • Project-root test fixture for test suites.
  • Tests

    • Many new end-to-end tests validating hooks, evaluation metrics, and multi-layer output comparisons.
    • Improved pruning tests with deterministic initialization and additional statistical assertions.
  • Chores

    • Tests: mirror rank/size into environment for test init and disable external telemetry.


coderabbitai bot commented Mar 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a comprehensive activation-based importance-hooks framework (hooks, analysis, comparison, Megatron plugin), robust JSON utilities, a flush-print helper, expanded tests, and minor test/distributed infra updates.

Changes

  • Test infra & distributed utils (tests/conftest.py, tests/_test_utils/torch/distributed/utils.py): Added project_root_path pytest fixture; mirrored distributed rank/size into the environment (RANK, LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE) and set WANDB_DISABLED="true".
  • Pruning test updates (tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py): Added an _assert_approx helper, forced a CPU deterministic-initialization flag, and added gated assertions checking pruning scores and layer activation statistics.
  • New small utilities (modelopt/torch/utils/logging.py, modelopt/torch/utils/robust_json.py): Exported aprint() (flush-print) and added a robust_json module with RobustJSONEncoder, json_dumps, json_dump, and json_load to serialize dataclasses, Path, Enum, OmegaConf, callables, timedeltas, etc.
  • Importance hooks core (modelopt/torch/prune/importance_hooks/__init__.py, modelopt/torch/prune/importance_hooks/base_hooks.py): Added a ForwardHook base and multiple concrete hooks (L2NormHook, IndependentChannelContributionHook, IndependentKvHeadContributionHook, LayerNormContributionHook, IterativeChannelContributionHook) with accumulation, to_dict/state_dict APIs, logging/dump, and distributed aggregation helpers.
  • Hook analysis & comparison (modelopt/torch/prune/importance_hooks/base_hooks_analysis.py, modelopt/torch/prune/importance_hooks/compare_module_outputs.py): Added evaluate_importance_scores() to simulate pruning impact (RMSE, cosine similarity) and compare_module_outputs utilities with OutputSaveHook, per-layer RMSE/cosine computations, multi-layer comparison, and save/load helpers.
  • Megatron plugin (modelopt/torch/prune/importance_hooks/plugins/__init__.py, modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py): Added a plugin package and MegatronL2NormHook that gathers inputs across tensor-parallel regions (TP-aware input gathering), exported behind a plugin gate.
  • Hook unit tests (tests/gpu/torch/prune/importance_hooks/test_base_hooks.py, tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py): Added tests validating hook collection/accumulation and exact expected scores, plus end-to-end evaluation of importance scores (RMSE, cosine similarity) using synthetic and concrete hooks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main objective: introducing a shared activation hooks infrastructure for reuse across minitron and puzzletron pruning components.
Docstring Coverage ✅ Passed Docstring coverage is 96.20% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed All torch.load(..., weights_only=False) calls include proper inline comments explaining why unsafe deserialization is required and confirming loaded files are internally-generated, fully complying with SECURITY.md requirements.



Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (6)
tests/_test_utils/torch/distributed/utils.py (1)

26-30: Socket not explicitly closed after getting free port.

The socket is bound to get a free port but never closed. While Python's garbage collector will eventually close it, explicitly closing the socket ensures the port is released promptly and avoids potential resource leaks.

♻️ Suggested fix
 def get_free_port():
     sock = socket.socket()
     sock.bind(("", 0))
     port = sock.getsockname()[1]
+    sock.close()
     return port
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/_test_utils/torch/distributed/utils.py` around lines 26 - 30, The
get_free_port function creates a socket but never closes it; update
get_free_port to explicitly close the socket after binding (e.g., use a context
manager with socket.socket(...) as sock or call sock.close()) so the port is
released promptly; modify the function containing get_free_port to ensure the
socket is closed before returning the port while keeping the same behavior of
binding to ("", 0) and returning getsockname()[1].
modelopt/torch/puzzletron/tools/logger.py (3)

96-108: Fragile stack frame navigation.

The triple f_back navigation assumes a fixed call depth. If the call chain changes (e.g., adding a wrapper or calling dist_log directly), incorrect source locations will be reported. Consider using inspect.stack() with a more robust approach.

♻️ Suggested improvement
    @staticmethod
-    def get_caller_location() -> str:
+    def get_caller_location(depth: int = 3) -> str:
         """Get the caller location from the stack frame."""
-        # Get the caller's stack frame
-        frame = inspect.currentframe()
-
-        # f_back -> class method, 2 x f_back -> utils method, 3 x f_back -> original source
-        caller_frame = frame.f_back.f_back.f_back
-
-        # Get the filename and line number from the caller's stack frame
-        filename = os.path.basename(caller_frame.f_code.co_filename)
-        lineno = caller_frame.f_lineno
-        return f"{filename}:{lineno}"
+        stack = inspect.stack()
+        if len(stack) > depth:
+            frame_info = stack[depth]
+            return f"{os.path.basename(frame_info.filename)}:{frame_info.lineno}"
+        return "unknown:0"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 96 - 108, The
get_caller_location method currently walks a fixed f_back chain which is
brittle; replace that logic in get_caller_location with a robust
inspect.stack()-based search: call inspect.stack() and iterate frames to find
the first frame whose module/filename is not this logger module (or where the
function name is not get_caller_location/dist_log), then use that frame's
filename and lineno to return "filename:lineno"; ensure you fall back to the
last non-None frame if none match to avoid exceptions.

23-23: Unused import of torch.distributed.launch.

This import is marked with # noqa: F401 (unused import) but there's no apparent reason for its presence. If it's needed for a side effect, please add a comment explaining why.

♻️ Suggested fix - remove if not needed
-import torch.distributed.launch  # noqa: F401
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` at line 23, The unused import
torch.distributed.launch (currently annotated with # noqa: F401) should be
removed to eliminate dead code; if the import is intentionally required for a
side effect, replace the bare # noqa with a short explanatory comment (e.g.,
"import to register torch distributed launch entrypoint for CLI/side-effects")
next to the import so future readers understand its purpose—update the import
statement at the top of logger.py accordingly.

111-114: Global logger class modification may have unintended side effects.

logging.setLoggerClass(DistributedLogger) changes the default logger class for the entire process. Any subsequent logging.getLogger() calls in other modules will create DistributedLogger instances, which may not be intended.

Consider using a factory function or creating the logger directly without modifying the global logger class.

♻️ Alternative approach
-# Initialize logger
-logging.setLoggerClass(DistributedLogger)
-logger = logging.getLogger(__name__)
+# Initialize logger without modifying global logger class
+logger = DistributedLogger(__name__)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 111 - 114, Remove the
global side-effect call to logging.setLoggerClass(DistributedLogger) and instead
instantiate or provide a factory that returns a DistributedLogger explicitly;
replace the current sequence with something like creating the module logger by
calling DistributedLogger(__name__) (or implement a get_distributed_logger(name)
helper that returns DistributedLogger(name)), assign it to the logger variable
and keep logger.propagate = False. Ensure no other global logger class is
changed so other modules calling logging.getLogger() are unaffected.
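The factory approach suggested above can be sketched as follows; DistributedLogger's real definition isn't shown in this excerpt, so a plain logging.Logger subclass stands in (names hypothetical):

```python
import logging


class DistributedLogger(logging.Logger):
    # Stand-in: the real class adds rank-aware log filtering (not shown here).
    pass


def get_distributed_logger(name: str) -> DistributedLogger:
    # Instantiate the subclass directly instead of calling
    # logging.setLoggerClass(), so other modules' logging.getLogger()
    # calls keep producing plain Logger instances.
    logger = DistributedLogger(name)
    logger.propagate = False
    return logger
```

This keeps the custom behavior scoped to loggers obtained through the factory rather than changing the process-global logger class.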
modelopt/torch/puzzletron/tools/robust_json.py (2)

31-31: Hard import of optional dependency at module level.

Per coding guidelines, optional dependencies should be gated. omegaconf may not be installed in all environments.

♻️ Suggested fix - gate the import
-from omegaconf import DictConfig, ListConfig, OmegaConf
+try:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+    _HAS_OMEGACONF = True
+except ImportError:
+    DictConfig = ListConfig = None
+    _HAS_OMEGACONF = False

Then in RobustJSONEncoder.default:

-        if isinstance(o, (DictConfig, ListConfig)):
+        if _HAS_OMEGACONF and isinstance(o, (DictConfig, ListConfig)):
             return OmegaConf.to_container(o, resolve=True)

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; gate features by install extras."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/robust_json.py` at line 31, The module
currently hard-imports omegaconf at top-level; gate this optional dependency by
moving the import into the code path that needs it and handling ImportError:
remove the top-level "from omegaconf ..." import and instead import omegaconf
(or specific symbols) inside RobustJSONEncoder.default (and/or any other
functions that use DictConfig/ListConfig/OmegaConf), catch ImportError and fall
back to treating those objects as regular mappings or raise a clear runtime
error instructing to install the optional extra; update references to
DictConfig/ListConfig/OmegaConf in RobustJSONEncoder.default to use the local
import or the fallback behavior.

74-78: Return type hint is too restrictive.

json.loads can return any JSON-compatible type (dict, list, str, int, etc.), not just dict. Either update the type hint or add validation.

♻️ Suggested fix
-def json_load(path: Path | str) -> dict:
+def json_load(path: Path | str) -> Any:
     """Load JSON from file and return as dictionary."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/robust_json.py` around lines 74 - 78, The
function json_load currently types its return as dict but json.loads can return
any JSON-compatible type; update json_load's signature and implementation:
either change the return type from dict to a broad JSON type (e.g., Any or a
custom JSONType alias for dict|list|str|int|float|bool|None) or keep dict but
validate the loaded value and raise a clear exception if it's not a dict; refer
to the function name json_load and the code that calls
Path.read_text()/json.loads to implement the chosen approach.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py`:
- Around line 24-103: In evaluate_importance_scores, avoid building autograd
graphs and guard against empty input: wrap the per-batch evaluation loop (the
section using linear_layer, pruned_activations, pruned_output, rmse/cosine
computation) in a torch.no_grad() context to prevent gradient graph allocation,
and add an early validation at the top that checks activations_batches is
non-empty (raise a ValueError with a clear message or return zeroed metrics)
before computing num_to_prune; reference the function name
evaluate_importance_scores and the local symbols activations_batches,
linear_layer, pruned_activations, original_output, pruned_output, rmse_values,
and cosine_values when making the edits.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 154-156: The code currently mutates the live config by calling
args.activation_hooks_kwargs.pop("model") before json_dump, which removes the
model reference for later uses; instead, create a copy of
activation_hooks_kwargs (e.g., tmp = dict(args.activation_hooks_kwargs) or use
copy.deepcopy) and remove "model" from that copy if present, then call
json_dump(OmegaConf.to_container(args_copy_or_modified_args, resolve=True),
activations_log_dir / "args.json") so the original args.activation_hooks_kwargs
remains unchanged; reference the symbols args.activation_hooks_kwargs,
json_dump, OmegaConf.to_container, and activations_log_dir / "args.json".
- Around line 162-180: save_hook_states currently calls hook.state_dict() for
every hook which fails when certain hooks (e.g.,
IndependentKvHeadContributionHook, LayerNormContributionHook) raise
NotImplementedError; modify save_hook_states to safely handle hooks that don't
support checkpointing by wrapping the state collection in a try/except (catch
NotImplementedError and skip that hook) or by checking for a supported
method/attribute before calling state_dict, and optionally record/log skipped
module names; ensure you reference save_hook_states and state_dict in your
change so only hooks that successfully return state are included in the saved
hook_states dict.
- Around line 231-259: The L2NormHook.load_state_dict currently assigns
checkpointed "_activations" directly, which can leave tensors on the wrong
device and later cause device-mismatch in L2NormHook.accumulate (and in the "+="
path). Update L2NormHook.load_state_dict to move the loaded activations to the
current module/device before assigning (follow the pattern used in
IndependentChannelContributionHook.load_state_dict): detect the target device
(e.g., from a provided module or torch.device of existing tensors), call
.to(device) on state_dict["activations"] (preserving dtype), then assign to
self._activations so subsequent accumulate() and add_ operations are
device-safe.
- Around line 456-480: The recomputed forward output (output_curr) omits the
layer bias while output_tensor includes it, skewing scaling_factor_per_token;
modify the recomputation in the block using curr_activations and
self.weight_matrix so it also adds the layer bias when present (e.g., compute
output_curr = F.linear(input=curr_activations, weight=self.weight_matrix,
bias=self.bias) or add self.bias expanded to output_curr's shape), ensuring the
bias shape matches output_curr before computing output_norms and
scaling_factor_per_token; reference symbols: output_tensor, output_curr,
curr_activations, self.weight_matrix, self.bias, scaling_factor_per_token.

In `@modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py`:
- Around line 240-244: Replace the current prints-and-return branch that handles
layer mismatches with a hard failure: instead of printing "ERROR: Layer
mismatch!" and returning, raise a RuntimeError (or other appropriate exception)
that includes the mismatched ref_layers and comp_layers in the message so the
CLI/CI exits non-zero; modify the code in compare_module_outputs.py where
ref_layers and comp_layers are compared (the current if set(ref_layers) !=
set(comp_layers): block) to raise the exception with clear context.

In `@modelopt/torch/puzzletron/tools/logger.py`:
- Around line 69-73: The NotImplementedError message in the ranks validation
(inside the function/method that broadcasts messages, referencing variables msg
and ranks) is missing the 'last' choice; update the error text to list all valid
options consistently with the check — e.g., include 'last' alongside 'all',
'main', and 'local_main' — so the raised message accurately reflects the allowed
ranks values.
- Around line 79-84: The "last" branch incorrectly compares self.local_rank to
self.world_size - 1; replace that check to target the last local rank on node 0
by testing node and local sizes: change the condition (ranks == "last" and
self.local_rank != self.world_size - 1) to require that the process is not the
last local rank on node 0 — e.g., use self.node_rank and self.local_world_size
and only allow printing when (self.node_rank == 0 and self.local_rank ==
self.local_world_size - 1); update the condition accordingly in the function
that contains the ranks logic so it references self.node_rank, self.local_rank,
and self.local_world_size instead of self.world_size.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 217-257: The test uses fragile, hardware-specific hard-coded
activation values (e.g., pruning_scores["layer_scores"], activations entries
like "decoder.layers.0.mlp" / "decoder.layers.0.self_attention") checked with
_assert_approx and abs=1e-3; replace these strict exact-value checks with more
robust assertions: validate tensor shapes and value ranges or use relative
tolerance (rtol) instead of strict absolute checks, or widen/make the tolerance
configurable via a test constant; for layer-level correctness prefer asserting
expected monotonic/relative relationships (e.g., pruned vs unpruned scores)
rather than exact floats; optionally split the large conditional block into
separate focused tests for MHA vs GQA using the same symbols (pruning_scores,
activations, _assert_approx) to improve readability and maintainability.

In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`:
- Around line 113-154: The forward hook handle created in _run_hook_and_evaluate
is only removed on the happy path; wrap the work that runs the forward passes
and calls hook.accumulate (the loop that appends to all_activations and the call
to importance_scores = hook.accumulate()) in a try/finally and call
handle.remove() in the finally block to guarantee the hook is detached even on
exceptions; re-raise the exception after cleanup if one occurred so failures are
propagated, and keep the subsequent call to evaluate_importance_scores
unchanged.
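Several of the comments above revolve around evaluate_importance_scores' RMSE and cosine-similarity metrics; stripped of tensors, the two metrics reduce to the following (pure-Python sketch, function names assumed):

```python
import math


def rmse(a: list[float], b: list[float]) -> float:
    # Root-mean-square error between two equal-length vectors.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def cosine_similarity(a: list[float], b: list[float]) -> float:
    # Cosine of the angle between two vectors; 1.0 means identical direction.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)
```

In the actual code these run over pruned vs. unpruned layer outputs (hence the torch.no_grad() and empty-input guards requested in the review); the math itself is as above.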

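The logger ranks-gating fix described above ("last" should mean the last local rank on node 0, not world_size - 1) can be captured as a small predicate; the attribute names (node_rank, local_rank, local_world_size) follow the review's wording and are assumptions about the logger's fields:

```python
def should_log(ranks: str, node_rank: int, local_rank: int, local_world_size: int) -> bool:
    """Hypothetical predicate mirroring the review's suggested ranks semantics."""
    if ranks == "all":
        return True
    if ranks == "main":
        # Global main process: node 0, local rank 0.
        return node_rank == 0 and local_rank == 0
    if ranks == "local_main":
        # Main process on each node.
        return local_rank == 0
    if ranks == "last":
        # Last local rank on node 0, not world_size - 1.
        return node_rank == 0 and local_rank == local_world_size - 1
    raise NotImplementedError(
        f"ranks must be one of 'all', 'main', 'local_main', 'last'; got {ranks!r}"
    )
```

The raise branch also lists every valid choice, addressing the first logger comment about the incomplete NotImplementedError message.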

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: afede6d7-95c6-4028-bab9-e7b8241fd750

📥 Commits

Reviewing files that changed from the base of the PR and between fe83270 and 6ce8345.

📒 Files selected for processing (14)
  • modelopt/torch/nas/plugins/megatron_hooks/__init__.py
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks_analysis.py
  • modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py
  • modelopt/torch/nas/plugins/megatron_hooks/megatron_hooks.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/tools/__init__.py
  • modelopt/torch/puzzletron/tools/logger.py
  • modelopt/torch/puzzletron/tools/robust_json.py
  • tests/_test_utils/torch/distributed/utils.py
  • tests/conftest.py
  • tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks.py
  • tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Comment on lines +154 to +156
if rank == 0:
    args.activation_hooks_kwargs.pop("model")
    json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")

⚠️ Potential issue | 🟠 Major

Don't mutate args.activation_hooks_kwargs just to dump JSON.

pop("model") edits the live config object. Reusing the same args later loses the model reference, and a second dump on rank 0 can fail on the missing key.

🛠️ Suggested fix
         if rank == 0:
-            args.activation_hooks_kwargs.pop("model")
-            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+            args_to_dump = OmegaConf.to_container(args, resolve=True)
+            args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
+            json_dump(args_to_dump, activations_log_dir / "args.json")

Comment on lines +162 to +180
def save_hook_states(
    cls: type["ForwardHook"],
    activation_hooks: dict[str, "ForwardHook"],
    activations_log_dir: Path | str,
) -> None:
    """Save hook states for checkpointing (separate from final results).

    This can be called periodically during scoring.
    Note: Synchronization should be handled at a higher level to avoid deadlocks.
    """
    activations_log_dir = Path(activations_log_dir)
    activations_log_dir.mkdir(exist_ok=True, parents=True)
    rank = dist.rank()

    hook_states_path = activations_log_dir / f"hook_states_rank_{rank}.pth"
    hook_states = {
        module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
    }
    torch.save(hook_states, hook_states_path)

⚠️ Potential issue | 🟠 Major

save_hook_states() needs to handle hooks without checkpoint support.

This helper blindly calls state_dict() on every hook, but IndependentKvHeadContributionHook and LayerNormContributionHook later raise NotImplementedError. One such hook makes periodic checkpointing fail for the whole run.

🛠️ Suggested fix
-        hook_states = {
-            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
-        }
+        hook_states = {}
+        unsupported = []
+        for module_name, hook in activation_hooks.items():
+            try:
+                hook_states[module_name] = hook.state_dict()
+            except NotImplementedError:
+                unsupported.append(module_name)
+
+        if unsupported:
+            aprint(
+                "Skipping hook checkpoint save for hooks without state support: "
+                f"{unsupported}"
+            )

Comment on lines +231 to +259
if self._activations is None:
self._activations = activations
else:
self._activations += activations

def accumulate(self) -> torch.Tensor:
"""Return the accumulated L2 norm of activations.

Returns:
Tensor of accumulated scores, one per channel

Raises:
AssertionError: If no activations have been collected yet
"""
assert self._activations is not None, "No activations collected for importance estimation."
# Convert squared sum to L2 norm
return self._activations.pow(0.5)

def to_dict(self) -> dict[str, torch.Tensor]:
"""Convert to dict format for saving."""
return {"score": self.accumulate().cpu()}

def state_dict(self) -> dict:
"""Return the state dictionary containing activations."""
return {"activations": self._activations}

def load_state_dict(self, state_dict: dict) -> None:
"""Load activations from checkpoint."""
self._activations = state_dict["activations"]

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py

Repository: NVIDIA/Model-Optimizer



🏁 Script executed:

cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | sed -n '200,280p'



🏁 Script executed:

# Find the IndependentChannelContributionHook class
rg "class IndependentChannelContributionHook" -A 50 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py



🏁 Script executed:

# Get the full L2NormHook class definition
rg "class L2NormHook" -B 5 -A 100 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py



🏁 Script executed:

# Find IndependentChannelContributionHook's load_state_dict method
rg "class IndependentChannelContributionHook" -A 200 modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | grep -A 20 "load_state_dict"



🏁 Script executed:

# Let's get more context on IndependentChannelContributionHook's state management
sed -n '262,500p' modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py



Make L2NormHook state reload device-agnostic.

After restoring from checkpoint, _activations remains on the checkpoint device. When the resumed layer runs on a different device, accumulation fails on line 234 due to device mismatch in the += operation. IndependentChannelContributionHook.load_state_dict() already demonstrates the correct pattern for device-agnostic state loading.

🛠️ Suggested fix
         if self._activations is None:
             self._activations = activations
         else:
+            if self._activations.device != activations.device:
+                self._activations = self._activations.to(activations.device)
             self._activations += activations
@@
     def state_dict(self) -> dict:
         """Return the state dictionary containing activations."""
-        return {"activations": self._activations}
+        return {
+            "activations": None
+            if self._activations is None
+            else self._activations.cpu().clone()
+        }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 231 -
259, The L2NormHook.load_state_dict currently assigns checkpointed
"_activations" directly, which can leave tensors on the wrong device and later
cause device-mismatch in L2NormHook.accumulate (and in the "+=" path). Update
L2NormHook.load_state_dict to move the loaded activations to the current
module/device before assigning (follow the pattern used in
IndependentChannelContributionHook.load_state_dict): detect the target device
(e.g., from a provided module or torch.device of existing tensors), call
.to(device) on state_dict["activations"] (preserving dtype), then assign to
self._activations so subsequent accumulate() and add_ operations are
device-safe.

Comment on lines +456 to +480
# Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear)
# TODO: Consider better design to handle RowParallelLinear and nn.Linear
if isinstance(output, tuple):
output_tensor = output[0]
else:
output_tensor = output

activations = args[0]

# Don't aggregate activations from non-max subnets (e.g. from profiling)
if self.max_size is not None and activations.shape[-1] != self.max_size:
return

n_channels_to_prune = self.pruning_schedule[self.curr_iter]

curr_activations = activations.clone() # Shape B,T,I
curr_activations[..., self.pruned_channels] = 0
output_curr = F.linear(input=curr_activations, weight=self.weight_matrix) # Shape B,T,E

if self.calibration_method is None:
scaling_factor_per_token = torch.ones_like(output_tensor[..., 0]) # Shape B,T
elif self.calibration_method == "scale_by_magnitude":
output_norms = torch.linalg.vector_norm(output_tensor, dim=-1) # Shape B,T
output_curr_norms = torch.linalg.vector_norm(output_curr, dim=-1) # Shape B,T
scaling_factor_per_token = output_curr_norms / (output_norms + self.epsilon)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | sed -n '440,500p'



🏁 Script executed:

cat -n modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py | head -100



🏁 Script executed:

ast-grep --pattern 'def $method($args) {
  $$$
}'



🏁 Script executed:

rg -A 50 "Handle case where output is a tuple" modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py



🏁 Script executed:

wc -l modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py



Include bias in the pruned forward recomputation.

For biased nn.Linear layers, output_tensor includes the bias term but output_curr does not. This mismatch causes the scaling factor computation (line 480) and subsequent channel importance ranking (line 487) to be incorrect. The bias difference must be accounted for in the recomputed forward pass.

Fix
         curr_activations = activations.clone()  # Shape B,T,I
         curr_activations[..., self.pruned_channels] = 0
-        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E
+        bias = None if isinstance(output, tuple) else getattr(module, "bias", None)
+        output_curr = F.linear(
+            input=curr_activations,
+            weight=self.weight_matrix,
+            bias=bias,
+        )  # Shape B,T,E
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 456 -
480, The recomputed forward output (output_curr) omits the layer bias while
output_tensor includes it, skewing scaling_factor_per_token; modify the
recomputation in the block using curr_activations and self.weight_matrix so it
also adds the layer bias when present (e.g., compute output_curr =
F.linear(input=curr_activations, weight=self.weight_matrix, bias=self.bias) or
add self.bias expanded to output_curr's shape), ensuring the bias shape matches
output_curr before computing output_norms and scaling_factor_per_token;
reference symbols: output_tensor, output_curr, curr_activations,
self.weight_matrix, self.bias, scaling_factor_per_token.

Comment on lines +240 to +244
if set(ref_layers) != set(comp_layers):
print("\nERROR: Layer mismatch!")
print(f"Reference layers: {ref_layers}")
print(f"Compare layers: {comp_layers}")
return

⚠️ Potential issue | 🟠 Major

Raise on layer mismatch instead of returning successfully.

A bad comparison is a hard failure. Returning here makes the CLI exit with status 0, so scripts and CI can silently accept invalid results.

🛠️ Suggested fix
     if set(ref_layers) != set(comp_layers):
-        print("\nERROR: Layer mismatch!")
-        print(f"Reference layers: {ref_layers}")
-        print(f"Compare layers: {comp_layers}")
-        return
+        raise ValueError(
+            "Layer mismatch: "
+            f"reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/compare_module_outputs.py` around
lines 240 - 244, Replace the current prints-and-return branch that handles layer
mismatches with a hard failure: instead of printing "ERROR: Layer mismatch!" and
returning, raise a RuntimeError (or other appropriate exception) that includes
the mismatched ref_layers and comp_layers in the message so the CLI/CI exits
non-zero; modify the code in compare_module_outputs.py where ref_layers and
comp_layers are compared (the current if set(ref_layers) != set(comp_layers):
block) to raise the exception with clear context.

Comment on lines +69 to +73
if ranks not in ["all", "main", "local_main", "last"]:
raise NotImplementedError(
f"Could not broadcast msg {msg} - "
f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
)

⚠️ Potential issue | 🟡 Minor

Error message is inconsistent with valid choices.

The validation checks for ["all", "main", "local_main", "last"] but the error message only lists ['all', 'main', 'local_main'], missing 'last'.

🐛 Proposed fix
         if ranks not in ["all", "main", "local_main", "last"]:
             raise NotImplementedError(
                 f"Could not broadcast msg {msg} - "
-                f"ranks parameters choices are ['all', 'main', 'local_main']. Got {ranks}"
+                f"ranks parameters choices are ['all', 'main', 'local_main', 'last']. Got {ranks}"
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 69 - 73, The
NotImplementedError message in the ranks validation (inside the function/method
that broadcasts messages, referencing variables msg and ranks) is missing the
'last' choice; update the error text to list all valid options consistently with
the check — e.g., include 'last' alongside 'all', 'main', and 'local_main' — so
the raised message accurately reflects the allowed ranks values.

Comment on lines +79 to +84
elif (
(ranks == "main" and self.global_rank != 0)
or (ranks == "last" and self.local_rank != self.world_size - 1)
or (ranks == "local_main" and self.local_rank != 0)
):
return

⚠️ Potential issue | 🟠 Major

Logic bug in 'last' rank check for multi-node environments.

The condition self.local_rank != self.world_size - 1 compares local_rank (0 to local_world_size-1) against world_size (total processes across all nodes). In a multi-node setup with 8 total GPUs across 2 nodes, local_rank ranges 0-3 but world_size - 1 is 7, so the condition local_rank != 7 is always true and no rank ever prints.

Based on the docstring at line 167 ("rank -1 in node 0"), the intent seems to be the last rank on node 0.

🐛 Proposed fix
         elif (
             (ranks == "main" and self.global_rank != 0)
-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.global_rank != self.world_size - 1)
             or (ranks == "local_main" and self.local_rank != 0)
         ):
             return

Or if the intent is truly "last local rank on each node":

-            or (ranks == "last" and self.local_rank != self.world_size - 1)
+            or (ranks == "last" and self.local_rank != int(os.environ.get("LOCAL_WORLD_SIZE", 1)) - 1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/logger.py` around lines 79 - 84, The "last"
branch incorrectly compares self.local_rank to self.world_size - 1; replace that
check to target the last local rank on node 0 by testing node and local sizes:
change the condition (ranks == "last" and self.local_rank != self.world_size -
1) to require that the process is not the last local rank on node 0 — e.g., use
self.node_rank and self.local_world_size and only allow printing when
(self.node_rank == 0 and self.local_rank == self.local_world_size - 1); update
the condition accordingly in the function that contains the ranks logic so it
references self.node_rank, self.local_rank, and self.local_world_size instead of
self.world_size.

Comment on lines +217 to +257
# TODO: Simplify it: this unit test is too long,
# hard to read (the same set of assertions across different test cases with if-else).

assert len(pruning_scores["activations_per_rank"]) == size
activations = pruning_scores["activations_per_rank"][rank]

# Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
if size == 1 and pruned_ffn_div == 4:
# Layer scores
_assert_approx(pruning_scores["layer_scores"], {1: 0.028923, 2: 0.046508})

# Validate decoder.layers.0.mlp activations
mlp_0_acts = activations["decoder.layers.0.mlp"]
_assert_approx(mlp_0_acts.min().item(), 0.000026)
_assert_approx(mlp_0_acts.max().item(), 0.000729)
_assert_approx(mlp_0_acts.mean().item(), 0.000201)

# Validate decoder.layers.1.mlp activations
mlp_1_acts = activations["decoder.layers.1.mlp"]
_assert_approx(mlp_1_acts.min().item(), 0.000022)
_assert_approx(mlp_1_acts.max().item(), 0.000762)
_assert_approx(mlp_1_acts.mean().item(), 0.000162)

# Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
elif size == 1 and pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
# Layer scores
_assert_approx(pruning_scores["layer_scores"], {1: 0.028056, 2: 0.038353})

# Validate decoder.layers.0.self_attention activations
attn_0_acts = activations["decoder.layers.0.self_attention"]
assert attn_0_acts.shape == torch.Size([hidden_size])
_assert_approx(attn_0_acts.min().item(), 0.010091)
_assert_approx(attn_0_acts.max().item(), 0.023826)
_assert_approx(attn_0_acts.mean().item(), 0.014548)

# Validate decoder.layers.1.self_attention activations
attn_1_acts = activations["decoder.layers.1.self_attention"]
assert attn_1_acts.shape == torch.Size([hidden_size])
_assert_approx(attn_1_acts.min().item(), 0.009982)
_assert_approx(attn_1_acts.max().item(), 0.035644)
_assert_approx(attn_1_acts.mean().item(), 0.020140)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there are any hook cleanup or state reset mechanisms in the test utilities
rg -n "cleanup|reset|teardown" tests/_test_utils/torch/ -A 3



🏁 Script executed:

#!/bin/bash
# Check the test file structure and imports
head -50 tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Look for fixture definitions and pytest markers in the test file
rg -n "@pytest|@fixture|def setup|def teardown" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Check if the test uses DistributedProcess or similar distributed utilities
rg -n "DistributedProcess|distributed|teardown" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Check for conftest.py and pytest configuration
find tests/gpu_megatron -name "conftest.py" -exec cat {} \;



🏁 Script executed:

#!/bin/bash
# Check pytest configuration
cat pyproject.toml | grep -A 20 "\[tool.pytest"



🏁 Script executed:

#!/bin/bash
# Look at the full test function containing lines 217-257
sed -n '180,260p' tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Get the full function signature and parametrize decorators
sed -n '110,180p' tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Search for where this test function is defined to see if it uses dist_workers fixture
grep -n "def test_mcore_gpt_minitron_pruning_and_export" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py



🏁 Script executed:

#!/bin/bash
# Check how set_seed is implemented
rg -n "def set_seed" tests/_test_utils/torch/ -A 10



🏁 Script executed:

#!/bin/bash
# Look for test function wrapper that calls _test_mcore_gpt_pruning
rg -n "test_mcore_gpt_pruning_and_export|test_mcore_gpt_pruning[^_]" tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py -B 5 -A 15 | head -100



Hard-coded activation values with tight tolerance may cause test fragility across different hardware and CUDA versions.

The test uses very specific hard-coded values (e.g., 0.028923, 0.046508) with abs=1e-3 tolerance. While proper teardown and seeding infrastructure is in place (via megatron_worker_teardown and set_seed), GPU-specific activation values can vary slightly across different hardware and CUDA versions, causing these assertions to fail even though the pruning logic is correct.

Consider:

  1. Using relative comparisons or shape/range checks instead of exact activation value assertions
  2. Documenting which GPU/CUDA version these values were captured on and re-validating periodically
  3. If exact values are needed, widen tolerance or make it configurable per environment

The TODO comment at line 217 also notes this test's complexity—splitting into focused unit tests would improve maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`
around lines 217 - 257, The test uses fragile, hardware-specific hard-coded
activation values (e.g., pruning_scores["layer_scores"], activations entries
like "decoder.layers.0.mlp" / "decoder.layers.0.self_attention") checked with
_assert_approx and abs=1e-3; replace these strict exact-value checks with more
robust assertions: validate tensor shapes and value ranges or use relative
tolerance (rtol) instead of strict absolute checks, or widen/make the tolerance
configurable via a test constant; for layer-level correctness prefer asserting
expected monotonic/relative relationships (e.g., pruned vs unpruned scores)
rather than exact floats; optionally split the large conditional block into
separate focused tests for MHA vs GQA using the same symbols (pruning_scores,
activations, _assert_approx) to improve readability and maintainability.
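A relative-tolerance helper along these lines would make the statistical checks less hardware-sensitive (a sketch: the helper name and the 5% band are illustrative, not taken from the test suite):

```python
import math

def assert_stats_close(actual: float, expected: float, rel: float = 0.05) -> None:
    """Compare a summary statistic (min/max/mean) with a relative tolerance.

    A relative band scales with the value's magnitude, unlike a fixed
    abs=1e-3 tolerance that is fairly tight for layer scores near 3e-2
    yet far looser than needed for activations near 2e-4.
    """
    assert math.isclose(actual, expected, rel_tol=rel), (
        f"{actual} deviates more than {rel:.0%} from expected {expected}"
    )
```

For tensor comparisons, `torch.testing.assert_close(actual, expected, rtol=..., atol=...)` provides the same semantics natively.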

Comment on lines +113 to +154
def _run_hook_and_evaluate(
layer: nn.Linear,
hook,
num_iterations: int,
prune_ratio: float,
) -> dict:
"""Shared helper to run hook, collect scores, and evaluate.

Args:
layer: Linear layer to test
hook: Hook instance (already created)
num_iterations: Number of forward passes
prune_ratio: Fraction of channels to prune

Returns:
Dictionary with evaluation metrics
"""
handle = layer.register_forward_hook(hook) # Store the handle

# Run forward passes
all_activations = []
for _ in range(num_iterations):
activations = torch.randn(16, 8, layer.in_features) # seq=16, batch=8, in_features=50
all_activations.append(activations)
_ = layer(activations)

# Get importance scores from hook
importance_scores = hook.accumulate()

# Remove the hook before evaluation to avoid triggering it again
handle.remove()

# Evaluate the importance scores by simulating pruning on all collected activations
# Pass the list of activations to compute averaged metrics across batches
metrics = evaluate_importance_scores(
layer,
all_activations, # List of activation batches
importance_scores,
prune_ratio=prune_ratio,
)

return metrics

⚠️ Potential issue | 🟠 Major

Always remove the forward hook in a finally block.

handle.remove() only runs on the happy path. If the forward loop or hook.accumulate() fails, the hook stays attached in the reused worker and can bleed into the next scenario.

🛠️ Suggested fix
 def _run_hook_and_evaluate(
     layer: nn.Linear,
     hook,
     num_iterations: int,
@@
-    handle = layer.register_forward_hook(hook)  # Store the handle
-
-    # Run forward passes
-    all_activations = []
-    for _ in range(num_iterations):
-        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
-        all_activations.append(activations)
-        _ = layer(activations)
-
-    # Get importance scores from hook
-    importance_scores = hook.accumulate()
-
-    # Remove the hook before evaluation to avoid triggering it again
-    handle.remove()
+    handle = layer.register_forward_hook(hook)
+    try:
+        all_activations = []
+        for _ in range(num_iterations):
+            activations = torch.randn(16, 8, layer.in_features)
+            all_activations.append(activations)
+            _ = layer(activations)
+
+        importance_scores = hook.accumulate()
+    finally:
+        handle.remove()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/nas/plugins/megatron_hooks/test_base_hooks_analysis.py`
around lines 113 - 154, The forward hook handle created in
_run_hook_and_evaluate is only removed on the happy path; wrap the work that
runs the forward passes and calls hook.accumulate (the loop that appends to
all_activations and the call to importance_scores = hook.accumulate()) in a
try/finally and call handle.remove() in the finally block to guarantee the hook
is detached even on exceptions; re-raise the exception after cleanup if one
occurred so failures are propagated, and keep the subsequent call to
evaluate_importance_scores unchanged.

@danielkorzekwa danielkorzekwa requested a review from a team as a code owner March 11, 2026 14:40

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (5)
tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)

248-259: Shape assertions are robust; consider extending this pattern.

The assert attn_0_acts.shape == torch.Size([hidden_size]) checks (lines 249, 256) are hardware-independent and reliable. Consider adopting similar structural assertions for the MLP activations in the first test case (lines 231-240) to improve robustness while maintaining the statistical checks as secondary validation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`
around lines 248 - 259, Add hardware-independent shape assertions for the MLP
activation entries similar to the attention checks: for the MLP activations
retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"]
and activations["decoder.layers.1.mlp"]) assert their .shape equals
torch.Size([hidden_size]) before or alongside the existing min/max/mean
_assert_approx checks; update the test variables corresponding to the first test
case’s MLP activations (the variables that hold those activation tensors) to
include these assert ... .shape == torch.Size([hidden_size]) statements.
modelopt/torch/utils/robust_json.py (2)

47-48: Fragile dtype detection via string comparison.

Matching type(o).__name__ == "dtype" will catch any class named "dtype", not just numpy.dtype or torch.dtype. This could lead to unexpected behavior with other types. Consider adding a comment documenting this intentional duck-typing approach, or checking o.__class__.__module__ as well (e.g., starts with numpy or torch).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` around lines 47 - 48, The current
fragile dtype detection uses type(o).__name__ == "dtype"; update it to only
match actual numpy/torch dtypes by also checking the class module (e.g., check
o.__class__.__module__ and ensure it startswith "numpy" or "torch") or
explicitly check for instances via duck-typing attributes, and add a brief
comment explaining the intent; replace the condition referencing
type(o).__name__ == "dtype" with a combined check using o.__class__.__module__
(or equivalent) so only numpy/torch dtype classes are matched in robust_json.py.

15-15: Consider removing the blanket mypy suppression.

Disabling mypy for the entire file with # mypy: ignore-errors bypasses type checking. The file already has type annotations, so targeted # type: ignore comments on specific problematic lines (if any) would be preferable to maintain type safety for the rest of the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` at line 15, Remove the top-level blanket
mypy suppression in robust_json.py (the `# mypy: ignore-errors` comment) and
instead address specific typing issues: delete that header, run mypy to find
failing lines, and add targeted `# type: ignore[...]` comments only on the
problematic expressions (or refine annotations on functions/classes such as any
helper functions in robust_json.py) so the rest of the module retains type
checking.
modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py (2)

675-681: Consider documenting why checkpointing is not supported.

The state_dict and load_state_dict methods raise NotImplementedError. While this is a valid design choice, adding a brief explanation in the docstring or error message (e.g., "KV head pruning is designed to complete in a single run") would help users understand the limitation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 675 -
681, Update the docstrings and/or raised NotImplementedError messages for the
state_dict and load_state_dict methods in base_hooks.py to briefly explain why
checkpointing isn't supported (e.g., "Checkpointing not supported because KV
head pruning is a one-time operation and does not maintain persistent mutable
state across runs"), so users understand the design choice; modify the
docstrings of state_dict and load_state_dict and the text passed to
NotImplementedError in those methods to include that brief rationale.
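For example, the refusal could carry its rationale (an illustrative stub; the real hook classes hold more state and methods):

```python
class LayerNormContributionHook:
    """Illustrative stub of a hook whose scores cannot be checkpointed."""

    def state_dict(self) -> dict:
        raise NotImplementedError(
            "LayerNormContributionHook does not support checkpointing: "
            "its contribution scores are designed to be collected in a single run."
        )

    def load_state_dict(self, state_dict: dict) -> None:
        raise NotImplementedError(
            "LayerNormContributionHook does not support checkpointing; "
            "restart activation collection from scratch instead."
        )
```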

810-815: Consider using json_dump from robust_json for consistency.

The module imports json_dump from modelopt.torch.utils.robust_json (line 30) but uses the standard json.dump here. Using json_dump would ensure consistent JSON serialization across the codebase and automatically create parent directories.

Suggested change
-        output_path = activations_log_dir / "channel_importance_results.json"
-        aprint(f"Saving channel importance data to {output_path}")
-        with open(output_path, "w") as f:
-            json.dump(output_data, f, indent=2)
+        output_path = activations_log_dir / "channel_importance_results.json"
+        aprint(f"Saving channel importance data to {output_path}")
+        json_dump(output_data, output_path)

Note: This would lose the indent=2 formatting. If pretty-printing is important, consider adding an indent parameter to json_dump or keep the current approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py` around lines 810 -
815, Replace the use of json.dump when writing output_data to output_path with
the project utility json_dump (imported from modelopt.torch.utils.robust_json)
so serialization is consistent and parent directories are created automatically;
update the write block that currently uses output_path and json.dump to call
json_dump(output_data, output_path) (or add an indent param to json_dump if
pretty-printing is required) and remove the manual open(...) context since
json_dump handles file creation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/utils/logging.py`:
- Around line 206-208: The aprint wrapper currently always passes flush=True
while also forwarding kwargs, causing a duplicate-key TypeError when callers
pass flush in kwargs; update aprint to handle/merge the flush kwarg by
extracting it from kwargs (e.g., pop 'flush' if present) and then call print
with the resolved flush value (use True as the default if not provided) so
callers can override flush without causing a duplicate keyword error; reference
the aprint function in logging.py to locate the change.

In `@modelopt/torch/utils/robust_json.py`:
- Around line 74-78: The return type of json_load is too narrow (annotated as ->
dict) while json.loads can return dict, list, str, int, float, bool, or None;
update json_load's return annotation to reflect that (e.g., use typing.Any or
define a JSONType = Union[dict, list, str, int, float, bool, None] and use it)
and adjust the docstring to say "Return parsed JSON value" instead of "return as
dictionary"; reference the json_load function and the json.loads call when
making this change.

---

Nitpick comments:
In `@modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Around line 675-681: Update the docstrings and/or raised NotImplementedError
messages for the state_dict and load_state_dict methods in base_hooks.py to
briefly explain why checkpointing isn't supported (e.g., "Checkpointing not
supported because KV head pruning is a one-time operation and does not maintain
persistent mutable state across runs"), so users understand the design choice;
modify the docstrings of state_dict and load_state_dict and the text passed to
NotImplementedError in those methods to include that brief rationale.
- Around line 810-815: Replace the use of json.dump when writing output_data to
output_path with the project utility json_dump (imported from
modelopt.torch.utils.robust_json) so serialization is consistent and parent
directories are created automatically; update the write block that currently
uses output_path and json.dump to call json_dump(output_data, output_path) (or
add an indent param to json_dump if pretty-printing is required) and remove the
manual open(...) context since json_dump handles file creation.

In `@modelopt/torch/utils/robust_json.py`:
- Around line 47-48: The current fragile dtype detection uses type(o).__name__
== "dtype"; update it to only match actual numpy/torch dtypes by also checking
the class module (e.g., check o.__class__.__module__ and ensure it startswith
"numpy" or "torch") or explicitly check for instances via duck-typing
attributes, and add a brief comment explaining the intent; replace the condition
referencing type(o).__name__ == "dtype" with a combined check using
o.__class__.__module__ (or equivalent) so only numpy/torch dtype classes are
matched in robust_json.py.
- Line 15: Remove the top-level blanket mypy suppression in robust_json.py (the
`# mypy: ignore-errors` comment) and instead address specific typing issues:
delete that header, run mypy to find failing lines, and add targeted `# type:
ignore[...]` comments only on the problematic expressions (or refine annotations
on functions/classes such as any helper functions in robust_json.py) so the rest
of the module retains type checking.
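The dtype finding above has no inline suggested fix; a minimal sketch of the combined check, with the function name chosen here for illustration:

```python
def is_framework_dtype(o: object) -> bool:
    """Return True only for numpy/torch dtype instances.

    Checking the class name alone would match any user class named
    "dtype"; also requiring the defining module to be numpy or torch
    restricts the match to the intended framework types.
    """
    cls = type(o)
    return (
        cls.__name__ == "dtype"
        and cls.__module__.split(".")[0] in ("numpy", "torch")
    )
```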

In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py`:
- Around line 248-259: Add hardware-independent shape assertions for the MLP
activation entries similar to the attention checks: for the MLP activations
retrieved from the activations dict (e.g., activations["decoder.layers.0.mlp"]
and activations["decoder.layers.1.mlp"]) assert their .shape equals
torch.Size([hidden_size]) before or alongside the existing min/max/mean
_assert_approx checks; update the test variables corresponding to the first test
case’s MLP activations (the variables that hold those activation tensors) to
include these assert ... .shape == torch.Size([hidden_size]) statements.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ec61ff4f-cb7a-4262-91ac-c9d486e2dd0f

📥 Commits

Reviewing files that changed from the base of the PR and between 6ce8345 and c3870a5.

📒 Files selected for processing (4)
  • modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py
  • modelopt/torch/utils/logging.py
  • modelopt/torch/utils/robust_json.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Comment on lines +206 to +208
def aprint(*args, **kwargs):
    """All ranks from all nodes print."""
    print(*args, **kwargs, flush=True)

⚠️ Potential issue | 🟡 Minor

Avoid duplicating the flush keyword.

aprint(..., flush=False) currently raises TypeError because flush is passed both via **kwargs and as an explicit keyword. Since this is a new public wrapper around print, it should merge or override the kwarg instead of duplicating it.

Proposed fix
 def aprint(*args, **kwargs):
     """All ranks from all nodes print."""
-    print(*args, **kwargs, flush=True)
+    kwargs = {**kwargs, "flush": True}
+    print(*args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/logging.py` around lines 206 - 208, The aprint wrapper
currently always passes flush=True while also forwarding kwargs, causing a
duplicate-key TypeError when callers pass flush in kwargs; update aprint to
handle/merge the flush kwarg by extracting it from kwargs (e.g., pop 'flush' if
present) and then call print with the resolved flush value (use True as the
default if not provided) so callers can override flush without causing a
duplicate keyword error; reference the aprint function in logging.py to locate
the change.

Comment on lines +74 to +78
def json_load(path: Path | str) -> dict:
    """Load JSON from file and return as dictionary."""
    path = Path(path)
    text = path.read_text()
    return json.loads(text)

@coderabbitai coderabbitai bot Mar 11, 2026


⚠️ Potential issue | 🟡 Minor

Return type annotation is too restrictive.

json.loads can return various types (dict, list, str, int, float, bool, or None) depending on the JSON content. The current -> dict annotation is incorrect if the JSON root is an array or primitive value. This could cause runtime type mismatches for callers.

🔧 Proposed fix
-def json_load(path: Path | str) -> dict:
-    """Load JSON from file and return as dictionary."""
+def json_load(path: Path | str) -> Any:
+    """Load JSON from file and return the deserialized object."""
     path = Path(path)
     text = path.read_text()
     return json.loads(text)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/robust_json.py` around lines 74 - 78, The return type of
json_load is too narrow (annotated as -> dict) while json.loads can return dict,
list, str, int, float, bool, or None; update json_load's return annotation to
reflect that (e.g., use typing.Any or define a JSONType = Union[dict, list, str,
int, float, bool, None] and use it) and adjust the docstring to say "Return
parsed JSON value" instead of "return as dictionary"; reference the json_load
function and the json.loads call when making this change.

Collaborator


please apply this suggestion

Contributor


@kevalmorabia97, sure! Here's the planned change:

  • Update json_load return type from -> dict to -> Any
  • Update docstring to reflect the broader return type

🧠 Learnings used
Learnt from: kevalmorabia97
Repo: NVIDIA/Model-Optimizer PR: 974
File: modelopt/torch/puzzletron/pruning/pruning_utils.py:195-201
Timestamp: 2026-03-06T14:26:21.439Z
Learning: In NVIDIA/Model-Optimizer, for PyTorch >= 2.6, torch.load() calls without an explicit weights_only argument are safe. Do not flag bare torch.load(...) as a security issue in files under the modelopt package (e.g., modelopt/torch/puzzletron/pruning/pruning_utils.py) as long as the PyTorch version constraint is maintained. If supporting PyTorch < 2.6, require an explicit weights_only argument to torch.load() to avoid potential issues.


- Add base hooks framework in modelopt/torch/nas/plugins/megatron_hooks/
  - base_hooks.py: Core hook infrastructure
  - base_hooks_analysis.py: Analysis utilities for hooks
  - megatron_hooks.py: Megatron-specific hook implementations
  - compare_module_outputs.py: Module comparison utilities
- Add tests for activation hooks
- Update test utilities for distributed testing
- Update minitron pruning tests to use new activation hooks

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
The activation hooks infrastructure depends on aprint from puzzletron.tools.logger.
Adding minimal logger module to satisfy this dependency.

Note: Some docstring linting warnings are suppressed as this is copied code.
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
… moved later outside of puzzletron module)

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
…tance_hooks

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
@danielkorzekwa force-pushed the dkorzewa/activation_hooks_redesign_minitron_puzzletron branch from 8b3fb7b to a4b8958 on March 13, 2026 10:05
Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (7)
modelopt/torch/prune/importance_hooks/base_hooks_analysis.py (1)

66-103: ⚠️ Potential issue | 🟠 Major

Avoid autograd and fail fast on empty calibration data.

evaluate_importance_scores() is analysis-only, but Lines 76-97 still build a graph for every batch, and Lines 101-102 divide by zero when activations_batches is empty. Guard the input first and run the per-batch evaluation under torch.no_grad().

Suggested fix
 def evaluate_importance_scores(
     linear_layer: nn.Linear,
     activations_batches: list[torch.Tensor],
     importance_scores: torch.Tensor,
     prune_ratio: float = 0.2,
 ) -> dict[str, float]:
@@
-    num_channels = importance_scores.shape[0]
+    if not activations_batches:
+        raise ValueError("activations_batches must be non-empty")
+
+    num_channels = importance_scores.shape[0]
     num_to_prune = int(num_channels * prune_ratio)
@@
-    for activations in activations_batches:
-        # Get original output
-        original_output = linear_layer(activations)
-
-        # Prune by zeroing out identified channels
-        pruned_activations = activations.clone()
-        pruned_activations[..., channels_to_prune] = 0
-
-        # Get pruned output
-        pruned_output = linear_layer(pruned_activations)
-
-        # Compute metrics for this batch
-        rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
-        rmse_values.append(rmse)
-
-        # Cosine similarity (flatten to vectors)
-        original_flat = original_output.reshape(-1)
-        pruned_flat = pruned_output.reshape(-1)
-        cosine = F.cosine_similarity(
-            original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
-        ).item()
-        cosine_values.append(cosine)
+    with torch.no_grad():
+        for activations in activations_batches:
+            original_output = linear_layer(activations)
+
+            pruned_activations = activations.clone()
+            pruned_activations[..., channels_to_prune] = 0
+            pruned_output = linear_layer(pruned_activations)
+
+            rmse = torch.sqrt(F.mse_loss(pruned_output, original_output)).item()
+            rmse_values.append(rmse)
+
+            original_flat = original_output.reshape(-1)
+            pruned_flat = pruned_output.reshape(-1)
+            cosine = F.cosine_similarity(
+                original_flat.unsqueeze(0), pruned_flat.unsqueeze(0), dim=1
+            ).item()
+            cosine_values.append(cosine)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks_analysis.py` around lines 66
- 103, The analysis loop in evaluate_importance_scores builds autograd graphs
per batch and will divide by zero if activations_batches is empty; first,
validate activations_batches is non-empty and raise a clear error if empty, then
wrap the per-batch pruning/evaluation (the loop that calls linear_layer on
activations, creates pruned_activations, computes pruned_output/original_output
and metrics) in a torch.no_grad() context so no autograd graph is constructed;
reference the existing symbols importance_scores, channels_to_prune,
activations_batches, linear_layer, and ensure the final averaging uses
len(activations_batches) safely after the emptiness check.
modelopt/torch/prune/importance_hooks/base_hooks.py (4)

152-154: ⚠️ Potential issue | 🟠 Major

Don't mutate the live config just to serialize it.

Line 153 removes model from args.activation_hooks_kwargs in place. Any later reuse of args now sees a different config, and a second dump can fail on the missing key.

Suggested fix
         if rank == 0:
-            args.activation_hooks_kwargs.pop("model")
-            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+            args_to_dump = OmegaConf.to_container(args, resolve=True)
+            args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
+            json_dump(args_to_dump, activations_log_dir / "args.json")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 152 - 154,
The code mutates the live config by calling
args.activation_hooks_kwargs.pop("model") before serializing; instead, create a
copy of the activation-hooks dict or a container representation and remove
"model" from that copy, then call json_dump on the modified copy so args remains
unchanged (i.e., avoid mutating args.activation_hooks_kwargs used elsewhere when
preparing the payload passed to json_dump(OmegaConf.to_container(args,
resolve=True))).

175-178: ⚠️ Potential issue | 🟠 Major

Skip hooks that don't implement checkpoint state.

save_hook_states() currently aborts the entire checkpoint when one hook raises NotImplementedError from state_dict(). IndependentKvHeadContributionHook and LayerNormContributionHook already do that later in this file, so mixed hook sets cannot be checkpointed.

Suggested fix
-        hook_states = {
-            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
-        }
+        hook_states = {}
+        skipped_modules = []
+        for module_name, hook in activation_hooks.items():
+            try:
+                hook_states[module_name] = hook.state_dict()
+            except NotImplementedError:
+                skipped_modules.append(module_name)
+
+        if skipped_modules:
+            aprint(
+                "Skipping hook checkpoint save for hooks without state support: "
+                f"{skipped_modules}"
+            )
         torch.save(hook_states, hook_states_path)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 175 - 178,
save_hook_states currently calls hook.state_dict() for every entry in
activation_hooks and fails the whole save if any hook raises
NotImplementedError; change save_hook_states to iterate activation_hooks and
call hook.state_dict() inside a try/except that catches NotImplementedError,
skipping that hook (optionally logging debug) and continuing to collect state
dicts for others so mixed sets (e.g., IndependentKvHeadContributionHook,
LayerNormContributionHook) can be checkpointed; ensure the collected hook_states
only includes entries that returned a state dict and then torch.save that
filtered mapping.

469-479: ⚠️ Potential issue | 🟠 Major

Include bias in the recomputed forward path.

For biased nn.Linear, output_tensor already includes the bias term but output_curr on Line 471 does not. That skews scaling_factor_per_token and the downstream channel ranking.

Suggested fix
         curr_activations = activations.clone()  # Shape B,T,I
         curr_activations[..., self.pruned_channels] = 0
-        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E
+        bias = None if isinstance(output, tuple) else getattr(module, "bias", None)
+        output_curr = F.linear(
+            input=curr_activations,
+            weight=self.weight_matrix,
+            bias=bias,
+        )  # Shape B,T,E
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 469 - 479,
The recomputed forward pass misses the Linear bias, causing output_curr to be
bias-free and skewing scaling_factor_per_token; update the recompute so
output_curr includes the layer bias (either by passing the bias into F.linear or
adding self.bias to output_curr) before computing output_curr_norms and
scaling_factor_per_token, referencing curr_activations, output_curr,
weight_matrix, self.bias, output_tensor, calibration_method,
scaling_factor_per_token and self.epsilon to ensure the norms compare
like-for-like.

229-233: ⚠️ Potential issue | 🟠 Major

Make L2NormHook checkpoint state device-agnostic.

state_dict() stores _activations as-is and load_state_dict() reattaches it unchanged. If a checkpoint is resumed on a different device, the next self._activations += activations on Line 232 will fail on a device mismatch.

Suggested fix
         if self._activations is None:
             self._activations = activations
         else:
+            if self._activations.device != activations.device:
+                self._activations = self._activations.to(activations.device)
             self._activations += activations
@@
     def state_dict(self) -> dict:
         """Return the state dictionary containing activations."""
-        return {"activations": self._activations}
+        return {
+            "activations": None
+            if self._activations is None
+            else self._activations.cpu().clone()
+        }

Also applies to: 251-257

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 229 - 233,
The L2NormHook's checkpointing stores _activations with device info causing
device-mismatch on resume; update state_dict() to store _activations.cpu() (if
not None) and update load_state_dict() to load the tensor in a device-agnostic
way (e.g. keep on CPU), and before accumulating in the L2NormHook (where
self._activations += activations occurs) ensure device alignment by moving the
stored tensor to the incoming activations device (e.g. self._activations =
self._activations.to(activations.device)) so addition never fails; touch the
L2NormHook methods state_dict, load_state_dict and the accumulation site that
references _activations.
modelopt/torch/prune/importance_hooks/compare_module_outputs.py (1)

240-244: ⚠️ Potential issue | 🟠 Major

Raise on layer-set mismatches.

Printing and returning here makes the CLI exit successfully even though the comparison is invalid, so scripts and CI can accept a bad run as success. Turn this into a hard failure instead.

Suggested fix
     if set(ref_layers) != set(comp_layers):
-        print("\nERROR: Layer mismatch!")
-        print(f"Reference layers: {ref_layers}")
-        print(f"Compare layers: {comp_layers}")
-        return
+        raise ValueError(
+            "Layer mismatch: "
+            f"reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py` around lines
240 - 244, The code currently prints an error and returns when the sets
ref_layers and comp_layers differ, which hides failures; change this to raise a
hard exception (e.g., RuntimeError) including both ref_layers and comp_layers in
the message so callers and CI see the failure. Locate the mismatch check that
compares set(ref_layers) and set(comp_layers) in compare_module_outputs.py and
replace the print/return block with a raised exception (include clear context
text and the two layer lists) so the process exits non‑zero on mismatch.
tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py (1)

130-143: ⚠️ Potential issue | 🟠 Major

Always remove the forward hook in a finally block.

If the forward loop or hook.accumulate() raises, Lines 142-143 never run and the hook can bleed into the next test executed in the same worker.

Suggested fix
     handle = layer.register_forward_hook(hook)  # Store the handle
-
-    # Run forward passes
-    all_activations = []
-    for _ in range(num_iterations):
-        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
-        all_activations.append(activations)
-        _ = layer(activations)
-
-    # Get importance scores from hook
-    importance_scores = hook.accumulate()
-
-    # Remove the hook before evaluation to avoid triggering it again
-    handle.remove()
+    try:
+        all_activations = []
+        for _ in range(num_iterations):
+            activations = torch.randn(16, 8, layer.in_features)
+            all_activations.append(activations)
+            _ = layer(activations)
+
+        importance_scores = hook.accumulate()
+    finally:
+        handle.remove()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py` around
lines 130 - 143, The forward hook registered via
layer.register_forward_hook(hook) must be removed inside a finally block to
ensure it is always cleaned up; wrap the forward passes and the call to
hook.accumulate() in a try/finally, call handle.remove() in the finally clause,
and keep the existing order of operations (run iterations, call
hook.accumulate()) so the hook is removed regardless of exceptions during the
forward loop or accumulate call.
🧹 Nitpick comments (1)
tests/gpu/torch/prune/importance_hooks/test_base_hooks.py (1)

24-100: Move this module out of tests/gpu if it stays CPU-only.

Every case here uses CPU nn.Linear and CPU torch.randn inputs, so it spends GPU CI budget without exercising GPU behavior. This belongs under tests/unit unless you plan to add CUDA coverage here.

As per coding guidelines, "Use pytest for all tests; organize tests into tests/unit (CPU-based), tests/gpu (GPU-based), tests/gpu_megatron (Megatron-Core GPU tests), tests/gpu_trtllm (TensorRT-LLM GPU tests), and tests/examples (integration tests)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks.py` around lines 24 -
100, This module only runs CPU code (uses nn.Linear and torch.randn) but is
located under tests/gpu; move the entire file to tests/unit and remove any
GPU-specific pytest markers so it doesn't consume GPU CI budget. Specifically,
relocate the file containing
_test_iterative_channel_contribution_hook_with_shape,
test_iterative_channel_contribution_hook_sbi,
test_iterative_channel_contribution_hook_bsi, and test_l2_norm_hook into the
tests/unit tree, update any imports or package references if they change, and
ensure no `@pytest.mark.gpu` or CI-only decorators remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py`:
- Around line 151-187: The code currently computes comparison results (variable
results) inside compare_multi_layer and/or main but only saves them when
args.output_json is provided, so running without --output-json discards metrics;
modify main and/or compare_multi_layer to return the computed results and ensure
main captures that return value (results) and prints a readable summary to
stdout when args.output_json is None, while still writing JSON when
args.output_json is set; reference the main() entry, compare_multi_layer(...)
call, the results variable, and args.output_json so you locate where to
return/capture and branch between printing vs. writing JSON.

---

Duplicate comments:
In `@modelopt/torch/prune/importance_hooks/base_hooks_analysis.py`:
- Around line 66-103: The analysis loop in evaluate_importance_scores builds
autograd graphs per batch and will divide by zero if activations_batches is
empty; first, validate activations_batches is non-empty and raise a clear error
if empty, then wrap the per-batch pruning/evaluation (the loop that calls
linear_layer on activations, creates pruned_activations, computes
pruned_output/original_output and metrics) in a torch.no_grad() context so no
autograd graph is constructed; reference the existing symbols importance_scores,
channels_to_prune, activations_batches, linear_layer, and ensure the final
averaging uses len(activations_batches) safely after the emptiness check.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py`:
- Around line 152-154: The code mutates the live config by calling
args.activation_hooks_kwargs.pop("model") before serializing; instead, create a
copy of the activation-hooks dict or a container representation and remove
"model" from that copy, then call json_dump on the modified copy so args remains
unchanged (i.e., avoid mutating args.activation_hooks_kwargs used elsewhere when
preparing the payload passed to json_dump(OmegaConf.to_container(args,
resolve=True))).
- Around line 175-178: save_hook_states currently calls hook.state_dict() for
every entry in activation_hooks and fails the whole save if any hook raises
NotImplementedError; change save_hook_states to iterate activation_hooks and
call hook.state_dict() inside a try/except that catches NotImplementedError,
skipping that hook (optionally logging debug) and continuing to collect state
dicts for others so mixed sets (e.g., IndependentKvHeadContributionHook,
LayerNormContributionHook) can be checkpointed; ensure the collected hook_states
only includes entries that returned a state dict and then torch.save that
filtered mapping.
- Around line 469-479: The recomputed forward pass misses the Linear bias,
causing output_curr to be bias-free and skewing scaling_factor_per_token; update
the recompute so output_curr includes the layer bias (either by passing the bias
into F.linear or adding self.bias to output_curr) before computing
output_curr_norms and scaling_factor_per_token, referencing curr_activations,
output_curr, weight_matrix, self.bias, output_tensor, calibration_method,
scaling_factor_per_token and self.epsilon to ensure the norms compare
like-for-like.
- Around line 229-233: The L2NormHook's checkpointing stores _activations with
device info causing device-mismatch on resume; update state_dict() to store
_activations.cpu() (if not None) and update load_state_dict() to load the tensor
in a device-agnostic way (e.g. keep on CPU), and before accumulating in the
L2NormHook (where self._activations += activations occurs) ensure device
alignment by moving the stored tensor to the incoming activations device (e.g.
self._activations = self._activations.to(activations.device)) so addition never
fails; touch the L2NormHook methods state_dict, load_state_dict and the
accumulation site that references _activations.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py`:
- Around line 240-244: The code currently prints an error and returns when the
sets ref_layers and comp_layers differ, which hides failures; change this to
raise a hard exception (e.g., RuntimeError) including both ref_layers and
comp_layers in the message so callers and CI see the failure. Locate the
mismatch check that compares set(ref_layers) and set(comp_layers) in
compare_module_outputs.py and replace the print/return block with a raised
exception (include clear context text and the two layer lists) so the process
exits non‑zero on mismatch.

In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py`:
- Around line 130-143: The forward hook registered via
layer.register_forward_hook(hook) must be removed inside a finally block to
ensure it is always cleaned up; wrap the forward passes and the call to
hook.accumulate() in a try/finally, call handle.remove() in the finally clause,
and keep the existing order of operations (run iterations, call
hook.accumulate()) so the hook is removed regardless of exceptions during the
forward loop or accumulate call.

---

Nitpick comments:
In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks.py`:
- Around line 24-100: This module only runs CPU code (uses nn.Linear and
torch.randn) but is located under tests/gpu; move the entire file to tests/unit
and remove any GPU-specific pytest markers so it doesn't consume GPU CI budget.
Specifically, relocate the file containing
_test_iterative_channel_contribution_hook_with_shape,
test_iterative_channel_contribution_hook_sbi,
test_iterative_channel_contribution_hook_bsi, and test_l2_norm_hook into the
tests/unit tree, update any imports or package references if they change, and
ensure no `@pytest.mark.gpu` or CI-only decorators remain.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b237c9fc-1932-4cbf-a572-bf2cebf1729d

📥 Commits

Reviewing files that changed from the base of the PR and between c3870a5 and 8b3fb7b.

📒 Files selected for processing (8)
  • modelopt/torch/prune/importance_hooks/__init__.py
  • modelopt/torch/prune/importance_hooks/base_hooks.py
  • modelopt/torch/prune/importance_hooks/base_hooks_analysis.py
  • modelopt/torch/prune/importance_hooks/compare_module_outputs.py
  • modelopt/torch/prune/importance_hooks/plugins/__init__.py
  • modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py
  • tests/gpu/torch/prune/importance_hooks/test_base_hooks.py
  • tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/prune/importance_hooks/plugins/__init__.py

Comment on lines +151 to +187
def main():
"""Compare module output tensors from different model variants."""
parser = argparse.ArgumentParser(
description="Compare module output tensors from different model variants",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--reference",
type=str,
required=True,
help="Path to reference output tensor (e.g., unpruned model)",
)
parser.add_argument(
"--compare",
type=str,
required=True,
help="Path to output tensor to compare against reference",
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save comparison statistics as JSON",
)
args = parser.parse_args()

# Load reference data
print(f"\nLoading reference: {args.reference}")
ref_data = torch.load(args.reference, map_location="cpu")

# Load comparison data
print(f"Loading compare: {args.compare}")
comp_data = torch.load(args.compare, map_location="cpu")

# Compare multi-layer outputs
compare_multi_layer(ref_data, comp_data, args.output_json)

⚠️ Potential issue | 🟠 Major

Don't drop the computed metrics when --output-json is omitted.

Lines 187 and 246-287 compute results and then discard them. python ... --reference ... --compare ... therefore emits no comparison output unless --output-json is set, even though that flag is optional.

Suggested fix
 def main():
@@
-    compare_multi_layer(ref_data, comp_data, args.output_json)
+    results = compare_multi_layer(ref_data, comp_data, args.output_json)
+    if not args.output_json:
+        print(results["aggregated"])
@@
-def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None):
+def compare_multi_layer(
+    ref_data: dict, comp_data: dict, output_json: str | None = None
+) -> dict:
@@
     if output_json:
         # Remove raw lists for JSON serialization
         results["aggregated"].pop("rmse", None)
         results["aggregated"].pop("cosine_sim_mean", None)
@@
         with open(output_json, "w") as f:
             json.dump(results, f, indent=2)
         print(f"Saved comparison results to {output_json}")
+
+    return results

Also applies to: 233-287

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py` around lines
151 - 187, The code currently computes comparison results (variable results)
inside compare_multi_layer and/or main but only saves them when args.output_json
is provided, so running without --output-json discards metrics; modify main
and/or compare_multi_layer to return the computed results and ensure main
captures that return value (results) and prints a readable summary to stdout
when args.output_json is None, while still writing JSON when args.output_json is
set; reference the main() entry, compare_multi_layer(...) call, the results
variable, and args.output_json so you locate where to return/capture and branch
between printing vs. writing JSON.

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (7)
modelopt/torch/prune/importance_hooks/base_hooks.py (4)

469-471: ⚠️ Potential issue | 🟠 Major

Include bias in the pruned forward recomputation.

For biased nn.Linear layers, output_tensor includes the bias term but output_curr (line 471) does not. This mismatch causes the scaling factor computation and subsequent channel importance ranking to be incorrect.

🛠️ Suggested fix
         curr_activations = activations.clone()  # Shape B,T,I
         curr_activations[..., self.pruned_channels] = 0
-        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix)  # Shape B,T,E
+        # Include bias if present (for nn.Linear with bias=True)
+        bias = getattr(module, "bias", None)
+        output_curr = F.linear(input=curr_activations, weight=self.weight_matrix, bias=bias)  # Shape B,T,E
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 469 - 471,
The recomputed output_curr used for scaling lacks the linear layer bias, causing
mismatch with output_tensor; update the pruned forward recomputation (around
activations.clone(), self.pruned_channels, and the F.linear call that produces
output_curr) to include the layer bias (the same bias used to compute
output_tensor, e.g. self.bias) so output_curr and output_tensor are consistent
before computing the scaling factors and channel importance.

152-154: ⚠️ Potential issue | 🟠 Major

Don't mutate args.activation_hooks_kwargs just to dump JSON.

pop("model") edits the live config object. Reusing the same args later loses the model reference, and a second dump on rank 0 can fail on the missing key.

🛠️ Suggested fix
         if rank == 0:
-            args.activation_hooks_kwargs.pop("model")
-            json_dump(OmegaConf.to_container(args, resolve=True), activations_log_dir / "args.json")
+            args_to_dump = OmegaConf.to_container(args, resolve=True)
+            args_to_dump.get("activation_hooks_kwargs", {}).pop("model", None)
+            json_dump(args_to_dump, activations_log_dir / "args.json")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 152 - 154,
The code is mutating the live config by calling
args.activation_hooks_kwargs.pop("model") before writing JSON; instead avoid
editing args: make a shallow copy of args.activation_hooks_kwargs (or convert
args to a container via OmegaConf.to_container(args, resolve=True)), remove the
"model" key from that copy, and then call json_dump with the sanitized container
and activations_log_dir / "args.json" so the original
args.activation_hooks_kwargs remains intact (reference:
args.activation_hooks_kwargs, OmegaConf.to_container, json_dump,
activations_log_dir).

174-178: ⚠️ Potential issue | 🟠 Major

save_hook_states() needs to handle hooks without checkpoint support.

This helper blindly calls state_dict() on every hook, but IndependentKvHeadContributionHook and LayerNormContributionHook raise NotImplementedError. One such hook makes periodic checkpointing fail for the whole run.

🛠️ Suggested fix
-        hook_states = {
-            module_name: hook.state_dict() for module_name, hook in activation_hooks.items()
-        }
+        hook_states = {}
+        for module_name, hook in activation_hooks.items():
+            try:
+                hook_states[module_name] = hook.state_dict()
+            except NotImplementedError:
+                pass  # Skip hooks that don't support checkpointing
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 174 - 178,
The save_hook_states helper currently calls hook.state_dict() for every entry in
activation_hooks which crashes when some hooks (e.g.,
IndependentKvHeadContributionHook, LayerNormContributionHook) raise
NotImplementedError; update save_hook_states to catch NotImplementedError (and
optionally other exceptions) when calling hook.state_dict(), skip those hooks
instead of failing the whole checkpoint, and collect only successful state_dicts
into hook_states before torch.save; include reference to activation_hooks,
save_hook_states, hook.state_dict(), and the specific hook classes to locate and
fix the code.

251-257: ⚠️ Potential issue | 🟠 Major

Make L2NormHook state reload device-agnostic.

After restoring from checkpoint, _activations remains on the checkpoint device (typically CPU). When the resumed layer runs on a different device (GPU), accumulation at line 232 fails due to device mismatch in the += operation.

🛠️ Suggested fix
     def state_dict(self) -> dict:
         """Return the state dictionary containing activations."""
-        return {"activations": self._activations}
+        return {
+            "activations": None if self._activations is None else self._activations.cpu().clone()
+        }
 
     def load_state_dict(self, state_dict: dict) -> None:
         """Load activations from checkpoint."""
         self._activations = state_dict["activations"]
+        # Device will be corrected on first accumulation if needed

Then update __call__ to handle device mismatch:

         if self._activations is None:
             self._activations = activations
         else:
+            if self._activations.device != activations.device:
+                self._activations = self._activations.to(activations.device)
             self._activations += activations
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 251 - 257,
The L2NormHook's load_state_dict stores _activations on the checkpoint device
causing device-mismatch on resume; modify load_state_dict in class L2NormHook to
not assume device (e.g., keep activations as CPU tensors or plain lists) or to
defer device placement, and update the __call__ method to ensure
self._activations is moved to the current input/device before performing the
in-place accumulation (use self._activations = self._activations.to(x.device) or
create a device-matched tensor copy) so the += at accumulation time succeeds
without device errors; reference methods: L2NormHook.load_state_dict,
L2NormHook.state_dict, and L2NormHook.__call__, and field _activations.
tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py (1)

130-143: ⚠️ Potential issue | 🟠 Major

Always remove the forward hook in a finally block.

handle.remove() only runs on the happy path. If hook.accumulate() fails, the hook stays attached in the reused worker and can bleed into subsequent tests.

🛠️ Suggested fix
     handle = layer.register_forward_hook(hook)  # Store the handle
 
-    # Run forward passes
-    all_activations = []
-    for _ in range(num_iterations):
-        activations = torch.randn(16, 8, layer.in_features)  # seq=16, batch=8, in_features=50
-        all_activations.append(activations)
-        _ = layer(activations)
-
-    # Get importance scores from hook
-    importance_scores = hook.accumulate()
-
-    # Remove the hook before evaluation to avoid triggering it again
-    handle.remove()
+    try:
+        # Run forward passes
+        all_activations = []
+        for _ in range(num_iterations):
+            activations = torch.randn(16, 8, layer.in_features)
+            all_activations.append(activations)
+            _ = layer(activations)
+
+        # Get importance scores from hook
+        importance_scores = hook.accumulate()
+    finally:
+        # Remove the hook to avoid triggering it again (even on failure)
+        handle.remove()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py` around
lines 130 - 143, The forward hook is only removed on the happy path; wrap the
forward pass and accumulation in a try/finally so the handle is always removed:
after calling layer.register_forward_hook(hook) store the handle, run the
forward passes and call hook.accumulate() inside a try block, and call
handle.remove() in the finally block to ensure the hook is detached even if
hook.accumulate() raises (references: layer.register_forward_hook,
hook.accumulate, handle.remove).
modelopt/torch/prune/importance_hooks/compare_module_outputs.py (2)

239-294: ⚠️ Potential issue | 🟠 Major

Don't drop computed metrics when --output-json is omitted.

compare_multi_layer computes results but doesn't return them or print them when output_json is not set. Running python ... --reference ... --compare ... produces no comparison output unless --output-json is provided, even though that flag is optional.

🛠️ Suggested fix
-def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None):
+def compare_multi_layer(ref_data: dict, comp_data: dict, output_json: str | None = None) -> dict:
     """Compare multi-layer outputs."""
     import json
 
@@
     # Save to JSON if requested
     if output_json:
         # Remove raw lists for JSON serialization
         results["aggregated"].pop("rmse", None)
         results["aggregated"].pop("cosine_sim_mean", None)
 
         with open(output_json, "w") as f:
             json.dump(results, f, indent=2)
         print(f"Saved comparison results to {output_json}")
+    else:
+        # Print aggregated results to stdout when no JSON output requested
+        print("\n=== Comparison Results ===")
+        if "rmse_stats" in results["aggregated"]:
+            print(f"RMSE: {results['aggregated']['rmse_stats']}")
+        if "cosine_sim_stats" in results["aggregated"]:
+            print(f"Cosine Similarity: {results['aggregated']['cosine_sim_stats']}")
+
+    return results
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py` around lines
239 - 294, The function compare_multi_layer builds a results dict but only
writes it to disk when output_json is provided and otherwise drops it; update
compare_multi_layer to return the results dict unconditionally (i.e., at end of
the function return results) and/or print a concise human-readable summary when
output_json is None so CLI use shows output; make the return explicit (return
results) and keep the existing JSON-write branch (remove the pop-only behavior
before returning so callers still receive full metrics), referencing
compare_multi_layer, the results variable, and the output_json parameter.

246-250: ⚠️ Potential issue | 🟠 Major

Raise on layer mismatch instead of returning successfully.

This silent return on error makes the CLI exit with status 0, so scripts and CI can silently accept invalid comparisons. This should raise an exception to fail loudly.

🛠️ Suggested fix
     if set(ref_layers) != set(comp_layers):
-        print("\nERROR: Layer mismatch!")
-        print(f"Reference layers: {ref_layers}")
-        print(f"Compare layers: {comp_layers}")
-        return
+        raise ValueError(
+            f"Layer mismatch: reference={sorted(ref_layers)}, compare={sorted(comp_layers)}"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py` around lines
246 - 250, The code currently prints an error and returns when layer sets differ
(in compare_module_outputs.py where ref_layers and comp_layers are compared),
which causes a successful exit; change this to raise an exception (e.g., raise
ValueError or RuntimeError) including the mismatched layer details and context
(e.g., f"Layer mismatch: reference={ref_layers}, compare={comp_layers}") so the
process fails loudly; update the block that checks set(ref_layers) !=
set(comp_layers) to raise the error instead of printing and returning.
🧹 Nitpick comments (2)
tests/_test_utils/torch/distributed/utils.py (1)

139-143: Consider adding LOCAL_WORLD_SIZE and WANDB_DISABLED for consistency with init_process.

The init_process function now sets LOCAL_WORLD_SIZE and WANDB_DISABLED, but _worker_loop does not. This inconsistency could cause different behavior between spawn_multiprocess_job and DistributedWorkerPool.

♻️ Proposed fix for consistency
    @staticmethod
     def _worker_loop(rank, world_size, backend, port, cmd_queue, result_queue, teardown_fn):
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = str(port)
         os.environ["LOCAL_RANK"] = str(rank)
         os.environ["RANK"] = str(rank)
         os.environ["WORLD_SIZE"] = str(world_size)
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
+        os.environ["WANDB_DISABLED"] = "true"
         dist.init_process_group(backend, rank=rank, world_size=world_size)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/_test_utils/torch/distributed/utils.py` around lines 139 - 143,
_worker_loop sets essential torch distributed env vars but omits
LOCAL_WORLD_SIZE and WANDB_DISABLED, causing inconsistency with init_process and
differing behavior between spawn_multiprocess_job and DistributedWorkerPool;
update _worker_loop to also set os.environ["LOCAL_WORLD_SIZE"] =
str(world_size_local_or_computed) and os.environ["WANDB_DISABLED"] = "true" (or
mirror the same logic used in init_process) alongside
MASTER_ADDR/MASTER_PORT/LOCAL_RANK/RANK/WORLD_SIZE so spawned workers see the
same environment as init_process.
modelopt/torch/prune/importance_hooks/base_hooks.py (1)

675-681: Consider providing default state_dict implementations instead of raising.

IndependentKvHeadContributionHook and LayerNormContributionHook raise NotImplementedError for state_dict() and load_state_dict(). While this is documented, it creates compatibility issues with save_hook_states(). Consider either:

  1. Implementing minimal state dict support, or
  2. Adding a supports_checkpointing property to the base class

Also applies to: 724-730

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks.py` around lines 675 - 681,
The base hook currently raises NotImplementedError in state_dict() and
load_state_dict(), causing save_hook_states() to fail for hooks like
IndependentKvHeadContributionHook and LayerNormContributionHook; modify the base
class (the hook base where state_dict and load_state_dict are defined) to
provide a safe default: implement state_dict() to return an empty dict and
load_state_dict(state_dict) to no-op, OR add a boolean property
supports_checkpointing (default False) and have save_hook_states() check this
property before calling state_dict()/load_state_dict(); update the specific
hooks (IndependentKvHeadContributionHook, LayerNormContributionHook) only if
they need real state to set supports_checkpointing True and override state
methods.
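Option 2 can be sketched as follows; the class and function names mirror those referenced in the review, but the bodies are illustrative only, not the actual implementation:

```python
class BaseHook:
    """Base hook with an opt-in checkpointing flag and safe state defaults."""

    @property
    def supports_checkpointing(self) -> bool:
        return False  # hooks opt in by overriding

    def state_dict(self) -> dict:
        return {}  # safe default instead of raising NotImplementedError

    def load_state_dict(self, state_dict: dict) -> None:
        pass  # no-op default


class L2NormHook(BaseHook):
    """Hook with real state: opts in to checkpointing."""

    def __init__(self):
        self._activations = None

    @property
    def supports_checkpointing(self) -> bool:
        return True

    def state_dict(self) -> dict:
        return {"activations": self._activations}


def save_hook_states(activation_hooks: dict) -> dict:
    # Only serialize hooks that opted in; others are skipped silently.
    return {
        name: hook.state_dict()
        for name, hook in activation_hooks.items()
        if hook.supports_checkpointing
    }
```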
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/prune/importance_hooks/base_hooks.py`:
- Around line 469-471: The recomputed output_curr used for scaling lacks the
linear layer bias, causing mismatch with output_tensor; update the pruned
forward recomputation (around activations.clone(), self.pruned_channels, and the
F.linear call that produces output_curr) to include the layer bias (the same
bias used to compute output_tensor, e.g. self.bias) so output_curr and
output_tensor are consistent before computing the scaling factors and channel
importance.
- Around line 152-154: The code is mutating the live config by calling
args.activation_hooks_kwargs.pop("model") before writing JSON; instead avoid
editing args: make a shallow copy of args.activation_hooks_kwargs (or convert
args to a container via OmegaConf.to_container(args, resolve=True)), remove the
"model" key from that copy, and then call json_dump with the sanitized container
and activations_log_dir / "args.json" so the original
args.activation_hooks_kwargs remains intact (reference:
args.activation_hooks_kwargs, OmegaConf.to_container, json_dump,
activations_log_dir).
- Around line 174-178: The save_hook_states helper currently calls
hook.state_dict() for every entry in activation_hooks which crashes when some
hooks (e.g., IndependentKvHeadContributionHook, LayerNormContributionHook) raise
NotImplementedError; update save_hook_states to catch NotImplementedError (and
optionally other exceptions) when calling hook.state_dict(), skip those hooks
instead of failing the whole checkpoint, and collect only successful state_dicts
into hook_states before torch.save; include reference to activation_hooks,
save_hook_states, hook.state_dict(), and the specific hook classes to locate and
fix the code.
- Around line 251-257: The L2NormHook's load_state_dict stores _activations on
the checkpoint device causing device-mismatch on resume; modify load_state_dict
in class L2NormHook to not assume device (e.g., keep activations as CPU tensors
or plain lists) or to defer device placement, and update the __call__ method to
ensure self._activations is moved to the current input/device before performing
the in-place accumulation (use self._activations =
self._activations.to(x.device) or create a device-matched tensor copy) so the +=
at accumulation time succeeds without device errors; reference methods:
L2NormHook.load_state_dict, L2NormHook.state_dict, and L2NormHook.__call__, and
field _activations.

In `@modelopt/torch/prune/importance_hooks/compare_module_outputs.py`:
- Around line 239-294: The function compare_multi_layer builds a results dict
but only writes it to disk when output_json is provided and otherwise drops it;
update compare_multi_layer to return the results dict unconditionally (i.e., at
end of the function return results) and/or print a concise human-readable
summary when output_json is None so CLI use shows output; make the return
explicit (return results) and keep the existing JSON-write branch (remove the
pop-only behavior before returning so callers still receive full metrics),
referencing compare_multi_layer, the results variable, and the output_json
parameter.
- Around line 246-250: The code currently prints an error and returns when layer
sets differ (in compare_module_outputs.py where ref_layers and comp_layers are
compared), which causes a successful exit; change this to raise an exception
(e.g., raise ValueError or RuntimeError) including the mismatched layer details
and context (e.g., f"Layer mismatch: reference={ref_layers},
compare={comp_layers}") so the process fails loudly; update the block that
checks set(ref_layers) != set(comp_layers) to raise the error instead of
printing and returning.

In `@tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py`:
- Around line 130-143: The forward hook is only removed on the happy path; wrap
the forward pass and accumulation in a try/finally so the handle is always
removed: after calling layer.register_forward_hook(hook) store the handle, run
the forward passes and call hook.accumulate() inside a try block, and call
handle.remove() in the finally block to ensure the hook is detached even if
hook.accumulate() raises (references: layer.register_forward_hook,
hook.accumulate, handle.remove).

---

Nitpick comments:
In `@modelopt/torch/prune/importance_hooks/base_hooks.py`:
- Around line 675-681: The base hook currently raises NotImplementedError in
state_dict() and load_state_dict(), causing save_hook_states() to fail for hooks
like IndependentKvHeadContributionHook and LayerNormContributionHook; modify the
base class (the hook base where state_dict and load_state_dict are defined) to
provide a safe default: implement state_dict() to return an empty dict and
load_state_dict(state_dict) to no-op, OR add a boolean property
supports_checkpointing (default False) and have save_hook_states() check this
property before calling state_dict()/load_state_dict(); update the specific
hooks (IndependentKvHeadContributionHook, LayerNormContributionHook) only if
they need real state to set supports_checkpointing True and override state
methods.

In `@tests/_test_utils/torch/distributed/utils.py`:
- Around line 139-143: _worker_loop sets essential torch distributed env vars
but omits LOCAL_WORLD_SIZE and WANDB_DISABLED, causing inconsistency with
init_process and differing behavior between spawn_multiprocess_job and
DistributedWorkerPool; update _worker_loop to also set
os.environ["LOCAL_WORLD_SIZE"] = str(world_size_local_or_computed) and
os.environ["WANDB_DISABLED"] = "true" (or mirror the same logic used in
init_process) alongside MASTER_ADDR/MASTER_PORT/LOCAL_RANK/RANK/WORLD_SIZE so
spawned workers see the same environment as init_process.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d2616fa7-bfde-40af-b662-1b6921a146c1

📥 Commits

Reviewing files that changed from the base of the PR and between 8b3fb7b and 91fa36d.

📒 Files selected for processing (13)
  • modelopt/torch/prune/importance_hooks/__init__.py
  • modelopt/torch/prune/importance_hooks/base_hooks.py
  • modelopt/torch/prune/importance_hooks/base_hooks_analysis.py
  • modelopt/torch/prune/importance_hooks/compare_module_outputs.py
  • modelopt/torch/prune/importance_hooks/plugins/__init__.py
  • modelopt/torch/prune/importance_hooks/plugins/megatron_hooks.py
  • modelopt/torch/utils/logging.py
  • modelopt/torch/utils/robust_json.py
  • tests/_test_utils/torch/distributed/utils.py
  • tests/conftest.py
  • tests/gpu/torch/prune/importance_hooks/test_base_hooks.py
  • tests/gpu/torch/prune/importance_hooks/test_base_hooks_analysis.py
  • tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • modelopt/torch/utils/logging.py
  • modelopt/torch/prune/importance_hooks/plugins/__init__.py
  • tests/conftest.py
  • modelopt/torch/prune/importance_hooks/__init__.py
  • modelopt/torch/utils/robust_json.py
  • modelopt/torch/prune/importance_hooks/base_hooks_analysis.py

…alysis

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>

codecov bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 70.09%. Comparing base (fe83270) to head (6345290).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/utils/logging.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1022      +/-   ##
==========================================
- Coverage   70.25%   70.09%   -0.17%     
==========================================
  Files         220      221       +1     
  Lines       25368    25461      +93     
==========================================
+ Hits        17822    17846      +24     
- Misses       7546     7615      +69     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/prune/importance_hooks/base_hooks_analysis.py`:
- Line 29: The annotated return type dict[str, float] is inaccurate because
num_pruned is an int; update the function's return annotation to accurately
reflect types (e.g., dict[str, float | int] or better, define a TypedDict like
ImportanceMetrics with rmse: float, cosine_similarity: float, num_pruned: int
and use it as the return type). Adjust imports (from typing import TypedDict or
use PEP 604 union syntax) and update the function signature to return the new
type; reference the existing num_pruned variable and the function in
base_hooks_analysis.py when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 23444f08-981c-407b-bfc3-96efaa97330d

📥 Commits

Reviewing files that changed from the base of the PR and between 91fa36d and 6345290.

📒 Files selected for processing (1)
  • modelopt/torch/prune/importance_hooks/base_hooks_analysis.py

activations_batches: list[torch.Tensor],
importance_scores: torch.Tensor,
prune_ratio: float = 0.2,
) -> dict[str, float]:

⚠️ Potential issue | 🟡 Minor

Return type annotation is slightly inaccurate.

The return type is annotated as dict[str, float], but num_pruned is an int (line 110). Consider using a more accurate type hint.

🔧 Proposed fix
-) -> dict[str, float]:
+) -> dict[str, float | int]:

Alternatively, for more precise typing:

from typing import TypedDict

class ImportanceMetrics(TypedDict):
    rmse: float
    cosine_similarity: float
    num_pruned: int
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/prune/importance_hooks/base_hooks_analysis.py` at line 29, The
annotated return type dict[str, float] is inaccurate because num_pruned is an
int; update the function's return annotation to accurately reflect types (e.g.,
dict[str, float | int] or better, define a TypedDict like ImportanceMetrics with
rmse: float, cosine_similarity: float, num_pruned: int and use it as the return
type). Adjust imports (from typing import TypedDict or use PEP 604 union syntax)
and update the function signature to return the new type; reference the existing
num_pruned variable and the function in base_hooks_analysis.py when making the
change.
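Completing the TypedDict variant into a self-contained sketch (the `summarize` wrapper is hypothetical; the real function lives in base_hooks_analysis.py):

```python
from typing import TypedDict


class ImportanceMetrics(TypedDict):
    rmse: float
    cosine_similarity: float
    num_pruned: int


def summarize(rmse: float, cosine_similarity: float, num_pruned: int) -> ImportanceMetrics:
    # TypedDict gives per-key types, so num_pruned is correctly typed as int
    # while the float metrics stay float.
    return {
        "rmse": rmse,
        "cosine_similarity": cosine_similarity,
        "num_pruned": num_pruned,
    }
```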

"""Error for deprecated functions."""


def aprint(*args, **kwargs):
Collaborator

The aprint function seems unnecessary. Why don't we just use the standard print everywhere instead of having to import aprint? flush=True is generally not needed everywhere either.
