
Sequential calibrate refactor#982

Open
sugunav14 wants to merge 11 commits into main from svelury/sequential-calibrate-refactor

Conversation


@sugunav14 sugunav14 commented Mar 5, 2026

What does this PR do?

Type of change: New feature

The current sequential calibration support has O(N^2) complexity for collecting updated activations for a decoder layer. To solve this, we previously adopted a modular, plugin-based approach that uses hooks to capture updated activations by running forward on the previous decoder layer with its cached activations. This causes an issue with nested modules: logic in the parent module may need to be replicated in the lower-level modules to ensure equivalence. For example, in the Nemotron model, the parent module NemotronHModel contains logic to create and select the appropriate mask based on the decoder layer type (mamba vs. attention).

This PR implements a more generic solution for sequential calibration by collecting activations through the full model forward, ensuring that all parent-module logic is preserved. We use a "state" attribute on each module to indicate whether to recompute or skip the layer during module forward, which avoids redundant computation when collecting updated activations.

The overall flow is as follows:

  1. The user registers a get_decoder_layers() function that returns the list of layers to be calibrated sequentially.
  2. LayerActivationCollector walks this list and patches each layer's forward with a "state-aware" forward.
  3. When model.forward() is called, all the parent logic is recomputed as expected (embeddings, residual connections, attention-mask generation, etc.).
  4. Suppose we are currently calibrating layer N and want its updated activations: layer N is set to capture and layer N-1 to run (it was just calibrated, so updated activations must be regenerated from its cached inputs). Layers before N-1 are set to skip, so their computations are bypassed during model.forward(). Layer N-1 replays its cached inputs to generate new activations, and layer N's inputs are captured and cached as before so they can later produce updated activations for layer N+1.
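In toy form, the state-aware forward this flow relies on might look like the following sketch. This is a simplified illustration, not the actual LayerActivationCollector code; the skip-mode passthrough and the exception name are stand-ins:

```python
import torch
import torch.nn as nn

class _EarlyStop(Exception):
    """Raised once the capture-target layer has recorded its inputs."""

class _State:
    def __init__(self):
        self.mode = "original"      # one of: original / skip / run / capture
        self.collected_inputs = []  # filled while mode == "capture"
        self.cached_inputs = []     # replayed while mode == "run"

def patch_layer(layer: nn.Module, state: _State):
    """Replace layer.forward with a state-aware wrapper."""
    original_forward = layer.forward

    def forward(*args, **kwargs):
        if state.mode == "skip":
            # Already-calibrated layer: the real code reconstructs zeros from
            # cached output metadata; passing the input through is a stand-in.
            return args[0]
        if state.mode == "run":
            # Replay previously captured inputs through freshly calibrated weights.
            real_args, real_kwargs = state.cached_inputs.pop(0)
            return original_forward(*real_args, **real_kwargs)
        if state.mode == "capture":
            state.collected_inputs.append((args, kwargs))
            raise _EarlyStop  # no need to compute beyond the target layer
        return original_forward(*args, **kwargs)

    layer.forward = forward

# Tiny demo: capture one batch, then replay it in "run" mode.
layer = nn.Linear(4, 4)
state = _State()
patch_layer(layer, state)

state.mode = "capture"
try:
    layer(torch.randn(2, 4))
except _EarlyStop:
    pass

state.mode = "run"
state.cached_inputs = list(state.collected_inputs)
out = layer(torch.randn(2, 4))  # the fresh input is ignored; cached inputs are replayed
```

The early stop is what removes the O(N^2) cost: forward never runs past the capture target, and already-processed layers cost nothing in skip mode.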

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Layer-by-layer activation capture and sequential calibration for transformer models, with improved robustness and clearer progress logging.
    • Expanded HuggingFace model support with automatic decoder-layer discovery for more model families.
  • Bug Fixes / Reliability

    • Safer patching/unpatching during calibration to ensure cleanup on errors and reduce resource leaks.
  • Tests

    • Extensive tests added for sequential calibration, activation collection, decoder discovery, and inter-layer behaviors.


copy-pr-bot bot commented Mar 5, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.



coderabbitai bot commented Mar 5, 2026

📝 Walkthrough

Walkthrough

Adds a new LayerActivationCollector module for sequential, per-decoder-layer activation capture and calibration; integrates it into the sequential_calibrate flow, adds HuggingFace decoder discovery helpers, removes prior in-file collector utilities, and expands tests covering discovery and sequential calibration behaviors.

Changes

Cohort / File(s) Summary
Activation Collector
modelopt/torch/quantization/activation_collector.py
New module implementing LayerActivationCollector with per-layer _LayerCalibState, patched forward modes (skip/run/capture/original), discovery/registration of decoder-layer providers, output meta extraction, zero-output reconstruction, and patch/unpatch lifecycle management.
Calibration Flow
modelopt/torch/quantization/model_calib.py
Replaces prior decoder lookup with LayerActivationCollector.get_decoder_layers; patches transformer layers via _patch_all_layers, uses get_input_activations(layer, forward_loop) for per-layer capture, adds logging and ensures _unpatch_all_layers() runs in finally.
HuggingFace Plugin
modelopt/torch/quantization/plugins/huggingface.py
Adds is_nemotron_h_model, get_nemotron_h_decoder_layers, is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers and registers them with LayerActivationCollector to enable HF decoder discovery.
Utils & Network Cleanups
modelopt/torch/quantization/utils.py, modelopt/torch/utils/network.py
Removes embedded LayerActivationCollector and _EarlyStopForwardError from utils.py and imports the new module; removes get_decoder_layers from network.py, delegating discovery to LayerActivationCollector.
Tests - HuggingFace & Discovery
tests/unit/torch/quantization/plugins/test_huggingface.py
Adds tests validating HF model predicates, decoder discovery functions, and registration with LayerActivationCollector.
Tests - Calibration & Collector
tests/unit/torch/quantization/test_calib.py, tests/unit/torch/quantization/test_sequential_calibrate.py, tests/unit/torch/quantization/test_utils.py
Expands coverage for sequential_calibrate, per-layer input capture/replay semantics, inter-layer logic, patch lifecycle (_patch_all_layers/_unpatch_all_layers), output_meta handling, and LayerActivationCollector public API behaviors.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller / Test
    participant LAC as LayerActivationCollector
    participant Model as Model
    participant Layer as DecoderLayer

    Caller->>LAC: _patch_all_layers()
    LAC->>Model: attach _seq_calib to decoder layers
    LAC->>Layer: replace forward with _patched_forward

    loop per target layer
        Caller->>LAC: _set_layer_states(target=Layer)
        LAC->>Layer: set mode=capture (others=skip/run)
        Caller->>LAC: get_input_activations(Layer, forward_loop)
        LAC->>Model: run forward_loop
        Model->>Layer: forward() -> _patched_forward(mode=capture)
        Layer-->>LAC: record inputs + raise _EarlyStopForwardError
        LAC->>LAC: collect inputs, restore layer mode
        LAC-->>Caller: return captured inputs
    end

    Caller->>LAC: _unpatch_all_layers()
    LAC->>Layer: restore original forward

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.79% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Sequential calibrate refactor' is partially related to the changeset. It references the main refactoring work (sequential calibration) but lacks specificity about the key architectural change—migrating from an O(N²) approach to a modular, state-driven plugin-based system using LayerActivationCollector.
Security Anti-Patterns ✅ Passed PR does not introduce any of the five critical security anti-patterns specified in SECURITY.md. New activation_collector.py and modifications follow secure coding practices.



sugunav14 and others added 6 commits March 5, 2026 21:20
@sugunav14 sugunav14 force-pushed the svelury/sequential-calibrate-refactor branch from 13b9033 to 7f72422 Compare March 5, 2026 21:30

class LayerActivationCollector:
"""Helper class for collecting layer activations during forward passes.
@dataclass
Collaborator


can we move the patching & capture related logics to a new file?

Contributor Author


okay, will do!

Contributor

Copilot AI left a comment


Pull request overview

Refactors sequential calibration to avoid O(N²) activation collection by introducing a stateful, model-forward-based activation collection approach (skip/run/capture) that preserves parent-module forward logic.

Changes:

  • Replaces the previous decoder-layer detection + per-layer patching with a state-aware LayerActivationCollector that patches all decoder layers and uses early-stop to capture inputs efficiently.
  • Updates sequential_calibrate() to use the new collector and removes the legacy get_decoder_layers() heuristic from modelopt/torch/utils/network.py.
  • Adds/extends unit tests to validate skip/run/capture behavior, decoder-layer discoverer registration, and HuggingFace integration.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
tests/unit/torch/quantization/test_utils.py Adds unit tests for the new LayerActivationCollector registration/discovery API and skip/run/capture semantics.
tests/unit/torch/quantization/test_sequential_calibrate.py Adds deeper behavioral tests for tuple outputs, inter-layer ops, mode transitions, and cleanup for sequential calibration patching.
tests/unit/torch/quantization/test_calib.py Adds tests for the new sequential calibration support gate and activation propagation behavior.
tests/unit/torch/quantization/plugins/test_huggingface.py Adds tests validating HuggingFace decoder-layer discoverer registration and homogeneous model detection.
modelopt/torch/utils/network.py Removes the old get_decoder_layers() heuristic helper.
modelopt/torch/quantization/utils.py Introduces the new stateful LayerActivationCollector with patched forward and output-metadata-based skipping.
modelopt/torch/quantization/plugins/huggingface.py Registers HuggingFace decoder-layer discoverers with LayerActivationCollector.
modelopt/torch/quantization/model_calib.py Updates sequential_calibrate() to use the new collector lifecycle (patch once, calibrate layer-by-layer, unpatch).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1198 to +1205
def is_homogenous_hf_model(model: nn.Module) -> bool:
if is_nemotron_h_model(model):
return False
decoder_layers = get_homogeneous_hf_decoder_layers(model)
if decoder_layers is None or len(decoder_layers) == 0:
return False
layer_classes = {type(layer) for layer in decoder_layers}
return len(layer_classes) == 1

Copilot AI Mar 8, 2026


The helper is named is_homogenous_hf_model, but the correct term is “homogeneous”. Since this is a newly introduced API surface (and is referenced by tests/registration), consider renaming it (and its uses) to is_homogeneous_hf_model for clarity and to avoid baking in a misspelling.

Comment on lines +406 to +412
def _register_test_discoverer(monkeypatch):
"""Register a simple discoverer that finds model.layers on any model."""
monkeypatch.setattr(
LayerActivationCollector,
"_decoder_layer_support",
[(lambda m: hasattr(m, "layers"), lambda m: m.layers)],
)

Copilot AI Mar 8, 2026


This file now introduces _register_test_discoverer() for decoder-layer discovery, but earlier tests in the same module still call LayerActivationCollector.get_input_activations() / sequential_calibrate() without registering any decoder-layer support (and without patching collector layers). With the new LayerActivationCollector implementation, those tests will fail unless they register a discoverer (like this helper) and patch/unpatch appropriately.

Comment on lines +855 to +877
_decoder_layer_support: list[tuple[Any, Any]] = []
_LAYER_ATTR = "_seq_calib"

def __init__(self, model: nn.Module):
self.model = model
self._decoder_layers: nn.ModuleList | None = None
self._layer_to_idx: dict[nn.Module, int] = {}
self._patched = False

# ------------------------------------------------------------------
# Decoder-layer discovery
# ------------------------------------------------------------------

@staticmethod
def _patch_and_initialize_layer(layer: torch.nn.Module, stop_after_collection: bool = False):
"""Patch a layer to collect inputs during forward passes."""
def get_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
"""Return decoder layers supported by sequential calibration."""
for is_supported, discoverer in LayerActivationCollector._decoder_layer_support:
if not is_supported(model):
continue
decoder_layers = discoverer(model)
if decoder_layers is not None:
return decoder_layers
return None

Copilot AI Mar 8, 2026


LayerActivationCollector._decoder_layer_support defaults to an empty list, so LayerActivationCollector.is_supported() (and thus sequential_calibrate()) will always fail unless some plugin/module has been imported that registers discoverers. If sequential_calibrate/collector are intended to be usable when imported directly (without modelopt.torch.quantization which imports plugins), consider registering baseline discoverers for common patterns (e.g. model.layers, model.model.layers, model.decoder.layers, model.transformer.h) in utils.py, similar to the removed get_decoder_layers heuristic.
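A baseline registration along the lines this comment suggests could look like the sketch below. The class body here is a minimal stand-in reproducing the (is_supported, discoverer) tuple shape shown in this PR's diff; register_decoder_layer_support and the attribute-path helper are illustrative names, not the actual API:

```python
import torch.nn as nn

class LayerActivationCollector:
    # Mirrors the (predicate, discoverer) registry described in the PR.
    _decoder_layer_support: list = []

    @classmethod
    def register_decoder_layer_support(cls, is_supported, discoverer):
        cls._decoder_layer_support.append((is_supported, discoverer))

    @classmethod
    def get_decoder_layers(cls, model):
        for is_supported, discoverer in cls._decoder_layer_support:
            if not is_supported(model):
                continue
            layers = discoverer(model)
            if layers is not None:
                return layers
        return None

def _resolve(model, path):
    """Follow a dotted attribute path, returning None if any hop is missing."""
    obj = model
    for attr in path.split("."):
        obj = getattr(obj, attr, None)
        if obj is None:
            return None
    return obj

# Baseline discoverers for the common layouts named in the comment above.
for _path in ("layers", "model.layers", "decoder.layers", "transformer.h"):
    LayerActivationCollector.register_decoder_layer_support(
        (lambda m, p=_path: isinstance(_resolve(m, p), nn.ModuleList)),
        (lambda m, p=_path: _resolve(m, p)),
    )
```

Registering these defaults at import time in utils.py would preserve the behavior of the removed get_decoder_layers heuristic for models that match the common layouts.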

Layers before the target are skipped or re-run (if just calibrated), the
target layer captures its inputs, and an early-stop prevents unnecessary
computation beyond the target.
"""

Copilot AI Mar 8, 2026


LayerActivationCollector.get_input_activations() assumes _patch_all_layers() has already been called (it indexes into _layer_to_idx and relies on patched forwards). As written, calling this public method without patching will raise KeyError / behave incorrectly. Either make patching lazy inside get_input_activations() (patch on first use) or add a clear assertion/error telling callers to call _patch_all_layers() first.

Suggested change
"""
"""
    # Ensure layers have been patched before attempting to collect activations.
    if not hasattr(self, "_layer_to_idx") or layer not in self._layer_to_idx:
        raise RuntimeError(
            "LayerActivationCollector.get_input_activations() was called before layers were patched. "
            "Make sure to call LayerActivationCollector._patch_all_layers() (or the appropriate "
            "patching method) before requesting input activations."
        )

if info.output_meta is not None:
return LayerActivationCollector._zeros_from_meta(info.output_meta)
print_rank_0(f"Layer {info.name} is in 'skip' mode but has no output meta to return")
return args[0] if args else next(iter(kwargs.values()))

Copilot AI Mar 8, 2026


In skip mode, when output_meta is missing the code falls back to returning the first input (args[0] / first kwarg). This can break parent forwards that expect the layer’s real output structure (e.g., tuple unpacking) and can silently mask state bugs. Prefer raising an explicit error (or computing/storing output_meta earlier) instead of returning an arbitrary fallback.

Suggested change
return args[0] if args else next(iter(kwargs.values()))
    raise RuntimeError(
        f"Sequential calibration state error: layer {info.name} is in 'skip' mode "
        "but has no recorded output_meta. Ensure a prior 'run' or 'capture' phase "
        "has executed and populated output_meta before switching to 'skip' mode."
    )

Comment on lines +945 to +951
if info.mode == "run":
assert info.cached_inputs, (
f"Layer {info.name} is in 'run' mode but has no cached inputs to replay."
)
real_args, real_kwargs = info.cached_inputs.pop(0)
output = self._original_forward(*real_args, **real_kwargs)
info.output_meta = LayerActivationCollector._extract_output_meta(output)

Copilot AI Mar 8, 2026


run mode consumes cached inputs via pop(0), which is O(n) per batch (shifts the list) and can become costly for large calibration sets. Use a collections.deque for cached_inputs (and popleft()), or track an index pointer instead of mutating the front of a list.

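The complexity difference this comment points at, in miniature (illustrative stdlib example, not the PR's code):

```python
from collections import deque

# list.pop(0) shifts every remaining element, so replaying B cached batches
# that way costs O(B^2) in total; deque.popleft() is O(1) per replay.
cached_inputs = deque()
for i in range(3):
    cached_inputs.append(((f"batch{i}",), {}))

first = cached_inputs.popleft()  # FIFO order is preserved, same as pop(0)
```

An index pointer into an untouched list works equally well and additionally keeps the captured inputs available for inspection after the replay.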
@modelopt-bot

Logic Correctness Review

Overall, this is a well-designed refactor that correctly implements the state machine for sequential calibration. However, I found several logic issues that should be addressed:

🔴 Critical Issues

1. Race condition in state transitions for layer 0 and 1

In _set_layer_states(), when calibrating layer 0 or 1, the logic doesn't properly handle edge cases:

if layer_idx > 1:
    done = self._decoder_layers[layer_idx - 2]._seq_calib
    done.mode = "skip"
    done.cached_inputs = []  # Free memory

if layer_idx > 0:
    prev = self._decoder_layers[layer_idx - 1]._seq_calib
    prev.mode = "run"
    prev.cached_inputs = prev.collected_inputs  # Move collected → cached
    prev.collected_inputs = []

Issue: When layer_idx == 0:

  • No layer transitions to "skip" (correct, none should)
  • No layer transitions to "run" (correct, no previous layer)
  • Layer 0 transitions to "capture" (correct)

However, when layer_idx == 1:

  • Layer -1 doesn't exist (no issue, condition prevents access)
  • Layer 0 transitions to "run" ✓
  • Layer 1 transitions to "capture" ✓

A subtle point worth tracing: layer 0's collected_inputs must survive between iterations.

Iteration 0 (calibrating layer 0):

  • Layer 0: mode="capture", collects inputs
  • After capture, get_input_activations() copies the inputs for return and sets the mode back to "original"; collected_inputs itself persists until explicitly cleared

Iteration 1 (calibrating layer 1):

  • _set_layer_states(1) sets layer 0 to "run" and moves collected_inputs into cached_inputs, which works precisely because nothing cleared collected_inputs in between

One caveat: if layer 0 were called multiple times during iteration 0's calibration, collected_inputs would hold multiple entries that are each replayed exactly once in iteration 1. Judging by the test test_layer_activation_collector_run_uses_cached_inputs_not_parent, this one-capture-one-replay pairing per batch is intentional, so the logic is correct for the intended use case. However, there is a different issue:

2. Missing validation: What if collected_inputs is empty when transitioning to "run"?

if info.mode == "run":
    assert info.cached_inputs, (
        f"Layer {info.name} is in 'run' mode but has no cached inputs to replay."
    )

This assert is good, but it only fires during forward. A better place to check is in _set_layer_states():

if layer_idx > 0:
    prev = self._decoder_layers[layer_idx - 1]._seq_calib
    if not prev.collected_inputs:
        raise RuntimeError(
            f"Layer {layer_idx - 1} has no collected inputs to replay. "
            "This may indicate the previous layer was never called during forward."
        )
    prev.mode = "run"
    prev.cached_inputs = prev.collected_inputs
    prev.collected_inputs = []

3. Memory leak: output_meta never cleared for skip layers

When a layer transitions to "skip", we clear cached_inputs but not output_meta:

if layer_idx > 1:
    done = self._decoder_layers[layer_idx - 2]._seq_calib
    done.mode = "skip"
    done.cached_inputs = []  # Cleared
    # output_meta not cleared - minor memory leak for large models

For models with many layers, keeping output_meta for all skip layers could add up (though it's small per layer).

🟡 Medium Issues

4. Incorrect error message in sequential_calibrate

The PR description mentions this, but the error message still says:

"Could not find transformer layers in model'. "  # Note the stray single quote

Should be:

"Could not find transformer layers in model. "

5. _patched flag not reset on exception

In _patch_all_layers():

def _patch_all_layers(self):
    self._decoder_layers = self.get_decoder_layers(self.model)
    assert self._decoder_layers is not None
    # ... patches layers ...
    self._patched = True  # Set at end

If an exception occurs mid-patching (e.g., bind_forward_method fails), _patched remains False, but some layers may be partially patched. The _unpatch_all_layers() checks if not self._patched: return, so it won't clean up.

Better to set _patched = True only after all patches succeed, or handle cleanup on exception.

6. Thread safety: Class-level _decoder_layer_support is mutable

_decoder_layer_support: list[tuple[Any, Any]] = []

This is a class-level mutable default. If multiple models are calibrated concurrently (unlikely but possible), there could be race conditions. Consider:

  • Making it a frozen collection after initialization
  • Or documenting that it's not thread-safe

🟢 Minor Issues

7. Docstring typo

In LayerActivationCollector.get_decoder_layers():

"""Return decoder layers supported by sequential calibration."""

Should probably be:

"""Return decoder layers for sequential calibration, or None if not supported."""

8. Inconsistent handling of non-tensor outputs

In _extract_output_meta():

return ("other", output)

This stores the actual non-tensor value (e.g., None, a dict, etc.). In _zeros_from_meta():

return meta[1]  # Returns the original value

This is correct for None, but what if the output is a dict with mutable state? The skip layer would return the same dict instance that was captured, which could lead to unexpected mutations. Consider deep-copying or documenting this behavior.
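The aliasing hazard described here, in miniature. This is an illustration: _extract_output_meta and _zeros_from_meta are the PR's names, but the dict payload is made up:

```python
import copy

# Today's behavior: skip mode hands back the stored reference, so a downstream
# mutation leaks into the cached meta...
meta = ("other", {"use_cache": True})
returned = meta[1]             # same object as the cached one
returned["use_cache"] = False  # cached meta is now silently corrupted

# ...whereas handing out a deep copy keeps the cached meta intact.
safe_meta = ("other", {"use_cache": True})
safe_returned = copy.deepcopy(safe_meta[1])
safe_returned["use_cache"] = False
```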


Summary

The core logic is sound, but I recommend addressing:

  1. Add validation in _set_layer_states() to check for empty collected_inputs before transitioning to "run"
  2. Clear output_meta when transitioning to "skip" to free memory
  3. Fix the typo in the error message
  4. Consider exception safety in _patch_all_layers()
  5. Document or fix the mutable default in _extract_output_meta() for non-tensor outputs

Would you like me to suggest specific code changes for these issues?

* Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights).
* Layer ``i`` → **capture** (record inputs, then early-stop).
"""
assert self._decoder_layers is not None


Consider adding validation here to check if collected_inputs is non-empty before transitioning to 'run' mode:

    if layer_idx > 0:
        prev = self._decoder_layers[layer_idx - 1]._seq_calib
        if not prev.collected_inputs:
            raise RuntimeError(
                f"Layer {layer_idx - 1} has no collected inputs to replay. "
                "This may indicate the previous layer was never called during forward."
            )
        prev.mode = "run"
        prev.cached_inputs = prev.collected_inputs
        prev.collected_inputs = []

This would provide a clearer error message if something goes wrong in the capture phase.

# ------------------------------------------------------------------

def _set_layer_states(self, layer_idx: int):
"""Transition layer modes for the next calibration step.


Minor memory optimization: consider clearing output_meta when transitioning to 'skip' mode:

    if layer_idx > 1:
        done = self._decoder_layers[layer_idx - 2]._seq_calib
        done.mode = "skip"
        done.cached_inputs = []
        done.output_meta = None  # Free memory for large models with many layers

While output_meta is small per layer, it could add up for models with hundreds/thousands of layers.

layer._seq_calib = _LayerCalibState(
name=module_to_name.get(layer, type(layer).__name__),
)
bind_forward_method(layer, self._patched_forward, "_original_forward")


Exception safety concern: if an exception occurs during patching (e.g., bind_forward_method fails on layer N), the _patched flag remains False but some layers may be partially patched. The _unpatch_all_layers() won't clean up because it checks if not self._patched: return.

Consider either:
  1. Setting _patched = True only after all patches succeed, OR
  2. Cleaning up on exception in _patch_all_layers(), OR
  3. Making _unpatch_all_layers() always attempt cleanup regardless of the _patched flag

Example fix:

    def _patch_all_layers(self):
        try:
            self._decoder_layers = self.get_decoder_layers(self.model)
            assert self._decoder_layers is not None
            self._layer_to_idx = {layer: i for i, layer in enumerate(self._decoder_layers)}
            module_to_name = {m: name for name, m in self.model.named_modules()}

            for layer in self._decoder_layers:
                layer._seq_calib = _LayerCalibState(
                    name=module_to_name.get(layer, type(layer).__name__),
                )
                bind_forward_method(layer, self._patched_forward, "_original_forward")

            self._patched = True  # Only set after all patches succeed
        except Exception:
            # Cleanup on failure
            self._unpatch_all_layers()
            raise

self._patched = False

# ------------------------------------------------------------------
# Decoder-layer discovery


Potential issue with non-tensor outputs: when a layer returns a non-tensor value (e.g., a dict or mutable object), _extract_output_meta stores the actual value in meta[1], and _zeros_from_meta returns the same instance:

    # In _extract_output_meta:
    return ("other", output)  # Stores reference to original object

    # In _zeros_from_meta:
    return meta[1]  # Returns same reference!

This could lead to unexpected mutations if the skip layer's output is modified by downstream code. Consider:
  1. Deep-copying non-tensor values, OR
  2. Documenting that mutable non-tensor outputs are not supported, OR
  3. Returning a safe default (e.g., None or copy.deepcopy(meta[1]))

For most cases (None, immutable values) this is fine, but it's a potential footgun for heterogeneous architectures.

if layer_idx > 0:
prev = self._decoder_layers[layer_idx - 1]._seq_calib
prev.mode = "run"
prev.cached_inputs = prev.collected_inputs
Contributor


I'm a bit confused about why we need both cached_inputs and collected_inputs in _LayerCalibState. Is it mainly to indicate which state this data (activation) is in?

Contributor Author


We don't strictly need them to be separate. Keeping them distinct makes the state explicit and avoids accidentally reusing stale values.
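In toy form, the handoff being discussed (the State class is a stand-in for _LayerCalibState; field names mirror the PR's):

```python
# "collected_inputs" holds inputs captured while this layer was the capture
# target; on the transition to "run" mode they become the replay queue
# ("cached_inputs"), and the capture list is reset so stale entries cannot
# linger into the next calibration step.
class State:
    def __init__(self):
        self.mode = "capture"
        self.collected_inputs = []
        self.cached_inputs = []

state = State()
state.collected_inputs.append((("batch0",), {}))  # captured during this layer's calibration

# Transition performed when the next layer becomes the calibration target:
state.mode = "run"
state.cached_inputs = state.collected_inputs
state.collected_inputs = []

replayed = state.cached_inputs.pop(0)  # consumed exactly once during replay
```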

@sugunav14 sugunav14 marked this pull request as ready for review March 11, 2026 04:34
@sugunav14 sugunav14 requested a review from a team as a code owner March 11, 2026 04:34
@sugunav14 sugunav14 requested a review from a team as a code owner March 11, 2026 04:34
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/huggingface.py (1)

1199-1206: Consider making the generic HF discoverer structural instead of homogeneous-only.

LayerActivationCollector now tracks output_meta per layer, and the new regression cases cover heterogeneous layer stacks. Requiring a single layer class here seems stricter than the runtime needs and will exclude mixed-block decoders that still expose a clean model.model.layers list.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 1199 - 1206,
The current is_homogenous_hf_model function rejects decoders with mixed layer
classes; instead relax it to a structural discoverer: in is_homogenous_hf_model
(and still respect is_nemotron_h_model) return True whenever
get_homogeneous_hf_decoder_layers(model) yields a non-empty decoder layer list
and the model exposes a layers sequence (e.g. model.model.layers) whose elements
provide the runtime metadata LayerActivationCollector expects (check for the
presence of output_meta or another per-layer attribute/ability that
LayerActivationCollector uses), and remove the strict len(layer_classes) == 1
check; keep using get_homogeneous_hf_decoder_layers and is_nemotron_h_model as
anchors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 913af19a-8050-4673-b980-42453284b8d7

📥 Commits

Reviewing files that changed from the base of the PR and between 31f0783 and 50e31f0.

📒 Files selected for processing (9)
  • modelopt/torch/quantization/activation_collector.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils.py
  • modelopt/torch/utils/network.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • tests/unit/torch/quantization/test_calib.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
  • tests/unit/torch/quantization/test_utils.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/utils/network.py

Comment on lines +50 to +52
cached_inputs: deque = field(default_factory=deque)
collected_inputs: list = field(default_factory=list)
output_meta: tuple | None = None
⚠️ Potential issue | 🟠 Major

output_meta only tracks the last batch.

run overwrites a single output_meta on every replayed batch, and skip reuses that one shape for every later batch. With a normal drop_last=False loader or variable sequence lengths, a later pass can synthesize dummy outputs with the wrong dimensions for earlier batches. This needs per-batch metadata that can be replayed in order on each future pass, not one shared slot per layer.

Also applies to: 163-179

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/activation_collector.py` around lines 50 - 52,
The current implementation stores a single output_meta that gets overwritten in
ActivationCollector.run and reused in skip, causing replayed batches to get
incorrect shapes; change output_meta from a single tuple to a per-batch sequence
(e.g., a deque or list) named output_meta_list or output_meta: deque, append the
metadata for each captured batch inside ActivationCollector.run (use the same
indexing order as cached_inputs/collected_inputs), update
ActivationCollector.skip and any replay logic to consume/peek the corresponding
per-batch metadata in FIFO order (maintaining alignment with cached_inputs
entries) and remove or rotate entries consistently after replays so each
replayed batch uses its original per-batch output_meta rather than a shared slot
(also apply the same change to the other occurrence referenced around lines
163-179).

Comment on lines +241 to +265
def _set_layer_states(self, layer_idx: int):
"""Transition layer modes for the next calibration step.

When calibrating layer *i*, three transitions happen:

* Layer ``i - 2`` → **skip** (fully done, free its cached inputs).
* Layer ``i - 1`` → **run** (replay captured inputs with calibrated weights).
* Layer ``i`` → **capture** (record inputs, then early-stop).
"""
assert self._decoder_layers is not None

if layer_idx > 1:
done = self._decoder_layers[layer_idx - 2]._seq_calib
done.mode = "skip"
done.cached_inputs = deque()

if layer_idx > 0:
prev = self._decoder_layers[layer_idx - 1]._seq_calib
prev.mode = "run"
prev.cached_inputs = deque(prev.collected_inputs)
prev.collected_inputs = []

cur = self._decoder_layers[layer_idx]._seq_calib
cur.mode = "capture"
cur.collected_inputs = []
⚠️ Potential issue | 🔴 Critical

Fail fast when the sequential state machine falls out of sync.

These transitions assume the previous pass both captured inputs and replayed once. If a caller skips a decoder layer, or forward_loop() never reaches the target, you only discover it later when a skipped layer hits the RuntimeError on Line 165. Please validate done.output_meta / prev.collected_inputs before switching modes, and reject an empty capture right after forward_loop() so the failure is reported at the source.

Suggested guardrails
 def _set_layer_states(self, layer_idx: int):
     assert self._decoder_layers is not None

     if layer_idx > 1:
         done = self._decoder_layers[layer_idx - 2]._seq_calib
+        if done.output_meta is None:
+            raise RuntimeError(
+                f"Layer {done.name} cannot enter 'skip' before a successful replay."
+            )
         done.mode = "skip"
         done.cached_inputs = deque()

     if layer_idx > 0:
         prev = self._decoder_layers[layer_idx - 1]._seq_calib
+        if not prev.collected_inputs:
+            raise RuntimeError(
+                f"Layer {prev.name} has no captured inputs to replay."
+            )
         prev.mode = "run"
         prev.cached_inputs = deque(prev.collected_inputs)
         prev.collected_inputs = []

 And in the analogous capture path (lines 293-307):

      forward_loop(self.model)

         info = layer._seq_calib
         inputs = list(info.collected_inputs)
         # After capture, set to original so calib_func can call the layer's
         # real forward directly.  The layer will transition to run → skip
         # in subsequent iterations via _set_layer_states.
         info.mode = "original"
+        if not inputs:
+            raise RuntimeError(
+                f"Layer {info.name} did not capture any inputs during forward_loop()."
+            )
         return inputs

Also applies to: 293-307

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/activation_collector.py` around lines 241 - 265,
The state-transition code in _set_layer_states assumes prior stages
captured/replayed successfully; add fail-fast checks before changing modes:
verify done.output_meta is present (non-empty) before setting done.mode="skip"
and verify prev.collected_inputs is non-empty before setting prev.mode="run" and
copying to prev.cached_inputs; if either check fails, raise a clear RuntimeError
indicating the specific layer index and missing data (use layer_idx and
references to self._decoder_layers[...] ._seq_calib). After forward_loop()
completes, also validate that the current layer's collected_inputs is non-empty
and raise a descriptive error immediately if the capture is empty. Apply the
same guardrails to the analogous transition block later in the file (the other
spot that manipulates ._seq_calib, .collected_inputs and .cached_inputs).

@cjluo-nv cjluo-nv left a comment
Review Summary

The core refactoring from O(N²) to O(N) sequential calibration via skip/run/capture state machine is well-designed. The architecture, cleanup handling, and test coverage are all strong. However, there are a few issues that should be addressed before merge.

Critical: Missing model pattern registrations (regression)

The old get_decoder_layers() in network.py supported 5 model patterns. The PR only registers 2 (Nemotron-H and homogeneous HF). Missing support for:

  • Megatron/MCore: model.decoder.layers — no registration in plugins/megatron.py
  • GPT-style: model.transformer.h
  • Direct layers: model.layers (when it's an nn.ModuleList)
  • Nemotron Super/Nano without block_type: old code matched model.backbone.layers unconditionally; new get_nemotron_h_decoder_layers requires block_type on layers[0]

Models that previously worked with sequential calibration will now raise "Could not find transformer layers".

Medium Issues

  1. Typo in public API: is_homogenous_hf_model → should be is_homogeneous_hf_model (missing 'e'). Appears in function name, test name, and registration. Worth fixing now before it becomes a stable API.

  2. forward_loop is None guard removed: The old code explicitly raised ValueError("forward_loop must not be None ..."). The PR removes this check (and the corresponding test test_seq_calib_raises_on_none_forward_loop) without adding an equivalent guard. A None forward_loop will now produce an unhelpful TypeError deep in the stack.

  3. _decoder_layer_support as mutable class variable (activation_collector.py:83): This list is shared across all instances and grows monotonically. Tests work around this with monkeypatch, but in production, if plugins are re-imported or the module is reloaded, entries could accumulate. Consider documenting this as intentional or adding a guard.
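For issue 2, restoring the removed guard is a small change. The sketch below assumes the function name and signature; the real entry point may differ, but the point is to fail with an actionable ValueError at the call site rather than a TypeError deep in the stack:

```python
def sequential_calibrate(model, forward_loop=None):
    """Sketch of the restored guard (function name and signature assumed)."""
    if forward_loop is None:
        raise ValueError(
            "forward_loop must not be None for sequential calibration; "
            "pass a callable that feeds calibration data through the model."
        )
    if not callable(forward_loop):
        raise TypeError("forward_loop must be callable.")
    forward_loop(model)  # the real per-layer calibration would follow here
```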

Minor

  • License header on the new file says 2024 — should likely be 2025 for a new file.
  • copy.deepcopy(meta[1]) in _zeros_from_meta for the "other" case could be expensive for complex non-tensor outputs. A comment noting this is expected to be lightweight values (e.g. None) would help.

CI

linux and unit-pr-required-check are failing — should be investigated before merge.

What's Good

  • Clean state machine design with skip/run/capture/original modes
  • Proper try/finally cleanup in both sequential_calibrate and _patch_all_layers
  • Output metadata approach (_extract_output_meta/_zeros_from_meta) correctly handles tuple/list/tensor outputs for skip layers
  • Excellent test coverage (~470 new test lines) covering mode transitions, heterogeneous layers, tuple unpacking, inter-layer norm, error paths, and cleanup
  • Extensible plugin-based decoder discovery via register_decoder_layer_support
  • Backward-compatible re-export of LayerActivationCollector from utils.py


Comment on lines +1268 to +1274
LayerActivationCollector.register_decoder_layer_support(
is_nemotron_h_model, get_nemotron_h_decoder_layers
)

LayerActivationCollector.register_decoder_layer_support(
is_homogenous_hf_model, get_homogeneous_hf_decoder_layers
)
This does not cover the pattern for Megatron models, which have model.decoder.layers. A follow-up could add similar support for mcore.
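Extending the registry to Megatron-style models could look like the sketch below. The predicate/getter names and the registration call are hypothetical (the real registration would presumably live in plugins/megatron.py), but they follow the same shape as the two HF registrations in this hunk:

```python
from types import SimpleNamespace


def is_mcore_model(model) -> bool:
    # Structural check for the Megatron/MCore layout (model.decoder.layers).
    return hasattr(model, "decoder") and hasattr(model.decoder, "layers")


def get_mcore_decoder_layers(model):
    return list(model.decoder.layers)


# Registration would mirror the HF plugins (hypothetical placement):
# LayerActivationCollector.register_decoder_layer_support(
#     is_mcore_model, get_mcore_decoder_layers
# )

# Tiny structural stand-in for a Megatron-style model, for illustration only.
demo = SimpleNamespace(decoder=SimpleNamespace(layers=["layer0", "layer1"]))
```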

return None


def is_homogenous_hf_model(model: nn.Module) -> bool:
Is this a typo? Should it be is_homogeneous_hf_model?

Contributor Author
yeah! I will fix that. Thanks for the catch!

output_meta: tuple | None = None


class LayerActivationCollector:
Quick question: will this work in TP/EP/PP scenarios, which we usually use for mcore?

Contributor Author
It might need additional sync logic to work for these scenarios.

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
@coderabbitai coderabbitai bot left a comment
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/huggingface.py (1)

1209-1216: Consider extending pattern coverage for other model architectures.

The current implementation only checks model.model.layers. As noted in a previous review, this doesn't cover patterns like model.decoder.layers (Megatron) or other architectures.

While not blocking, consider whether additional patterns should be checked here or in a follow-up PR to improve model coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 1209 - 1216,
get_homogeneous_hf_decoder_layers currently only returns model.model.layers;
extend its checks to cover other common HF decoder patterns (e.g.,
model.decoder.layers for Megatron-style models, model.model.decoder.layers, and
model.transformer.layers) so more architectures are detected. Inside
get_homogeneous_hf_decoder_layers, add checks in order (e.g., hasattr(model,
"decoder") and hasattr(model.decoder, "layers"), then hasattr(model, "model")
and hasattr(model.model, "decoder") and hasattr(model.model.decoder, "layers"),
then hasattr(model, "transformer") and hasattr(model.transformer, "layers")) and
return the first matching nn.ModuleList; keep the early return None behavior if
none match. Ensure you reference this exact function name when implementing the
changes.
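A structural discoverer covering these paths could be sketched as follows. The function name and the exact path list are illustrative (the real discoverer would keep the plugin's existing naming); the point is a single ordered table of attribute paths instead of a hand-written hasattr chain:

```python
from types import SimpleNamespace

# Attribute paths to probe, in priority order; first hit wins. The set of
# paths here is an assumption based on the patterns named in the review.
_LAYER_PATHS = (
    ("model", "layers"),
    ("decoder", "layers"),
    ("model", "decoder", "layers"),
    ("transformer", "h"),
    ("transformer", "layers"),
    ("layers",),
)


def find_decoder_layers(model):
    """Return the first layer sequence found along a known path, else None."""
    for path in _LAYER_PATHS:
        obj = model
        for attr in path:
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        else:
            return obj
    return None
```

Adding a new architecture then means appending one tuple rather than another hasattr branch.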

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 221748c7-9696-4123-9104-58da5f79313e

📥 Commits

Reviewing files that changed from the base of the PR and between 50e31f0 and 5cf716a.

📒 Files selected for processing (6)
  • modelopt/torch/quantization/activation_collector.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/quantization/activation_collector.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py

@sugunav14 can we make ‎modelopt/torch/quantization/utils a folder and move this file to the utils folder? We should preserve the backward compatibility (so we should do __init__ properly)

Contributor Author

Okay, I can do that!

Comment on lines +154 to +156
# ------------------------------------------------------------------
# Patched forward
# ------------------------------------------------------------------
obvious comment, can we remove it

Suggested change
# ------------------------------------------------------------------
# Patched forward
# ------------------------------------------------------------------

return tuple(LayerActivationCollector._zeros_from_meta(m) for m in meta[1])
if tag == "list":
return [LayerActivationCollector._zeros_from_meta(m) for m in meta[1]]
return copy.deepcopy(meta[1])
nit:

Suggested change
return copy.deepcopy(meta[1])
return meta[1]

# ------------------------------------------------------------------

@staticmethod
def _patched_forward(self, *args, **kwargs):

nit: we could define this in _patch_all_layers.

@realAsma realAsma left a comment

Looks great!

@realAsma

@sugunav14 have you tested an end to end flow with it?

As a follow up later, we can add a registry based plugin for saving the model so far (to enable resume during calibration)
