Skip softmax calibration with list of thresholds #987
rohansjoshi wants to merge 1 commit into main from
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
📝 Walkthrough

Adds multi-threshold support for attention sparsity: configs now accept per-phase lists of thresholds; the calibrator collects per-sample sparsity for all thresholds in one forward pass; the sparse method, stats aggregation, and tests are updated to handle per-threshold arrays and renamed APIs/fields.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Calibrator
participant Module
participant SparseMethod as "Sparse Method"
participant Aggregator
rect rgba(200,100,150,0.5)
Note over Calibrator,Aggregator: Old (per-threshold forwards)
loop For each threshold t
Calibrator->>Module: forward(threshold=t)
Module->>SparseMethod: compute sparsity(threshold=t)
SparseMethod-->>Module: sparsity_t
Module-->>Calibrator: per-sample sparsity_t
Calibrator->>Aggregator: aggregate(t, sparsity_t)
end
end
rect rgba(100,160,200,0.5)
Note over Calibrator,Aggregator: New (single-pass multi-threshold)
Calibrator->>Module: forward(thresholds=[t1,...,tN])
Module->>SparseMethod: compute sparsity for all thresholds
SparseMethod-->>Module: sparsity_list [s1,...,sN]
Module-->>Calibrator: per-sample sparsity_list
loop For each threshold index i
Calibrator->>Aggregator: unpack & aggregate sparsity_list[i]
end
end
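The single-pass flow in the second diagram can be sketched outside the library: evaluate every trial threshold against the same attention probabilities from one forward pass. The helper below is a hypothetical illustration (entry-level rather than block-level masking, invented function name), not the FlashSkipSoftmax implementation:

```python
import torch

def sparsity_for_thresholds(attn_probs: torch.Tensor, thresholds: list[float]) -> list[float]:
    """Fraction of attention entries that would be skipped at each threshold.

    One tensor of probabilities is reused for every threshold, so no
    additional model forward passes are needed per trial value.
    """
    flat = attn_probs.flatten()
    return [(flat < t).float().mean().item() for t in thresholds]

probs = torch.softmax(torch.randn(4, 128), dim=-1)
sparsities = sparsity_for_thresholds(probs, [1e-4, 1e-3, 1e-2])
assert len(sparsities) == 3                      # one value per trial threshold
assert all(0.0 <= s <= 1.0 for s in sparsities)  # valid sparsity range
assert sparsities == sorted(sparsities)          # looser threshold skips at least as much
```

The old flow in the first diagram would run the model once per threshold; here the per-threshold work reduces to a cheap comparison over cached probabilities.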
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)
7cb2377 to 8f455c1
Actionable comments posted: 3
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
137-140: Guard against silent data loss when pairing thresholds and sparsities.

At line 140, `zip(self.threshold_trials, sparsity_list)` silently truncates on length mismatch, which can hide calibration stat drift.

💡 Suggested fix

```diff
 for sample_stat in per_sample_stats:
     length = sample_stat["sample_length"]
     sparsity_list = sample_stat["sparsity"]
+    if len(sparsity_list) != len(self.threshold_trials):
+        raise ValueError(
+            f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+        )
     for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 137 - 140, The loop silently truncates mismatched pairs by using zip(self.threshold_trials, sparsity_list); before iterating (in the calibrator that processes per_sample_stats), validate that len(sparsity_list) == len(self.threshold_trials) and if not, raise a clear exception or log an error and skip the sample to avoid silent data loss—use the sample_stat/"sample_length" and sparsity_list context to include identifying info in the message; do not rely on zip_longest to silently fill values, explicitly enforce or handle length mismatches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 130-136: Wrap the calibration sequence in a try/finally so
calibration mode is always cleaned up on error: after calling
self._set_thresholds(...) and self._enable_calibration_mode(...), run
forward_loop(model) and self._extract_calibration_stats(...) inside a try block
and call self._disable_calibration_mode(...) (and reset any trial thresholds if
applicable) in the finally block; reference the methods _set_thresholds,
_enable_calibration_mode, forward_loop, _extract_calibration_stats, and
_disable_calibration_mode so you locate and wrap that exact sequence to ensure
modules are disabled and thresholds cleared even when exceptions occur.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 179-183: The code computes total_blocks/total_valid_blocks using
num_causal_blocks when self.is_causal but later counts dense block positions
across all blocks (producing negative sparsity); update the dense-block counting
to mask out non-causal positions or else compute both numerator and denominator
from the same masked positions: use the same causal mask used to derive
num_causal_blocks when counting dense blocks (and when computing
total_blocks/total_valid_blocks) so numerator and denominator align (apply this
fix in the block that sets total_blocks/total_valid_blocks and also in the later
dense-counting section referenced around the second occurrence at Lines
~194-197); refer to self.is_causal, num_causal_blocks, total_valid_blocks,
total_blocks and the dense-block counting logic to locate and change the code.
- Around line 60-61: The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).
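The numerator/denominator mismatch described for the causal case above can be reproduced on a toy block grid. This is an illustrative sketch (the real block-counting logic in flash_skip_softmax differs); the point is that the dense-block count and the total must use the same causal mask:

```python
import torch

def block_sparsity(dense_mask: torch.Tensor, is_causal: bool) -> float:
    """Sparsity over an (n x n) grid of dense-block flags.

    Masking only the denominator (valid causal blocks) but not the
    numerator (dense blocks) can drive sparsity negative; applying the
    same mask to both keeps the value in [0, 1].
    """
    n = dense_mask.shape[0]
    if is_causal:
        valid = torch.tril(torch.ones(n, n, dtype=torch.bool))
    else:
        valid = torch.ones(n, n, dtype=torch.bool)
    total_valid_blocks = int(valid.sum())
    dense_blocks = int((dense_mask & valid).sum())  # numerator masked the same way
    return 1.0 - dense_blocks / total_valid_blocks

# Every block dense: consistent masking gives sparsity 0. Counting dense
# blocks over the full 4x4 grid (16) against 10 causal blocks would give
# 1 - 16/10 = -0.6, the negative-sparsity symptom flagged above.
all_dense = torch.ones(4, 4, dtype=torch.bool)
assert block_sparsity(all_dense, is_causal=True) == 0.0
```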
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 137-140: The loop silently truncates mismatched pairs by using
zip(self.threshold_trials, sparsity_list); before iterating (in the calibrator
that processes per_sample_stats), validate that len(sparsity_list) ==
len(self.threshold_trials) and if not, raise a clear exception or log an error
and skip the sample to avoid silent data loss—use the
sample_stat/"sample_length" and sparsity_list context to include identifying
info in the message; do not rely on zip_longest to silently fill values,
explicitly enforce or handle length mismatches.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: aced1124-5d44-4cab-b27a-89ab4c75bffa
📒 Files selected for processing (12)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
```python
self.thresholds = self.thresholds_config.get("prefill", [1e-3])
```

Avoid phase-order-dependent threshold fallback.

At line 71, missing phase keys fall back to the previous runtime value (`self.thresholds`), so behavior depends on which phase ran first instead of config-only defaults.

💡 Suggested fix

```diff
-        # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
+        self.thresholds = list(self._fallback_thresholds)

     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
```

Also applies to: 69-72
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 60 - 61, The code currently falls back to the runtime value
self.thresholds when a phase key is missing, making behavior depend on the order
phases run; instead, when resolving per-phase thresholds use only configuration
defaults (e.g. phase_val = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3]))), so replace any use of
self.thresholds as the fallback with a config-only chain (phase -> "prefill" ->
literal default) in the code that looks up phase thresholds (references:
self.thresholds_config and self.thresholds).
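The config-only fallback chain the prompt suggests can be isolated in a few lines. This is a hypothetical stand-in class, not the FlashSkipSoftmax code; it only demonstrates the phase -> "prefill" -> literal-default resolution order:

```python
DEFAULT_THRESHOLDS = [1e-3]

class ThresholdResolver:
    """Resolve per-phase thresholds from configuration only, so the
    result never depends on which phase happened to run first."""

    def __init__(self, thresholds_config: dict):
        self.thresholds_config = thresholds_config

    def resolve(self, phase: str) -> list:
        # Deterministic chain: phase -> "prefill" -> literal default
        return self.thresholds_config.get(
            phase, self.thresholds_config.get("prefill", DEFAULT_THRESHOLDS)
        )

r = ThresholdResolver({"prefill": [1e-4, 1e-3]})
assert r.resolve("decode") == [1e-4, 1e-3]                # missing phase -> prefill values
assert r.resolve("prefill") == [1e-4, 1e-3]
assert ThresholdResolver({}).resolve("decode") == [1e-3]  # empty config -> literal default
```

Because `resolve` never reads mutable runtime state, calling it for "decode" before or after "prefill" yields the same answer.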
Pull request overview
Updates FlashSkipSoftmax “skip softmax” calibration to support evaluating multiple sparsity thresholds in a single forward pass, improving calibration throughput while keeping inference sparsity behavior unchanged.
Changes:
- Rename sparse attention config from `threshold` (scalar per phase) to `thresholds` (list per phase) and propagate through configs/tests.
- Update FlashSkipSoftmax to compute per-threshold sparsity stats in one pass (and use the first threshold for the applied mask).
- Extend stats aggregation to handle `sparse_blocks` as either a scalar or a list.
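The last bullet (scalar-or-list `sparse_blocks`) boils down to normalizing the incoming value before accumulating. A hedged sketch with an invented helper name, not the stats_manager API:

```python
def accumulate_sparse_blocks(aggregated: list, incoming) -> list:
    """Fold a sparse_blocks sample (scalar or per-threshold list) into
    running per-threshold totals, refusing silent length mismatches."""
    values = incoming if isinstance(incoming, list) else [incoming]
    if not aggregated:
        return list(values)
    if len(values) != len(aggregated):
        raise ValueError(f"Expected {len(aggregated)} values, got {len(values)}")
    return [a + b for a, b in zip(aggregated, values)]

totals = accumulate_sparse_blocks([], [10, 25, 40])  # first list-valued sample
totals = accumulate_sparse_blocks(totals, [5, 5, 5])
assert totals == [15, 30, 45]
assert accumulate_sparse_blocks([], 7) == [7]        # scalar input still accepted
```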
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py | Updates expected threshold info to thresholds dict-of-lists. |
| tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py | Updates sparse attention conversion tests to use thresholds. |
| tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py | Updates calibration tests/configs to use thresholds. |
| tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py | Updates FlashSkipSoftmax unit tests for list-based sparsity outputs. |
| tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py | Updates GPU integration configs to thresholds. |
| tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py | Updates GPU calibration configs to thresholds. |
| tests/_test_utils/torch/sparsity/sparse_attention_common.py | Updates shared test config fixtures to thresholds. |
| modelopt/torch/sparsity/attention_sparsity/stats_manager.py | Adds support for aggregating list-valued sparse_blocks and list average sparsity. |
| modelopt/torch/sparsity/attention_sparsity/model_sparsify.py | Updates public-facing doc/example to thresholds. |
| modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py | Implements multi-threshold stats collection and threshold list handling. |
| modelopt/torch/sparsity/attention_sparsity/config.py | Renames/validates thresholds as dict-of-float-lists (with length checks). |
| modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py | Switches calibration data collection to single-pass multi-threshold stats extraction. |
Comments suppressed due to low confidence (1)
modelopt/torch/sparsity/attention_sparsity/config.py:132
`validate_thresholds` still raises an error that says "Threshold must be..." even though the field is now `thresholds` and expects lists. Updating this message will make validation failures much clearer to users.

```python
def validate_thresholds(cls, v):
    """Validate thresholds is a dict of lists with valid phases and values in range (0, 1)."""
    if not isinstance(v, dict):
        raise ValueError(
            f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
        )
```
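Filling in the rest of the docstring's contract, a standalone validator along these lines would cover the dict shape, the phase keys, and the (0, 1) range. This is a sketch, not the actual Pydantic validator from config.py:

```python
VALID_PHASES = {"prefill", "decode"}

def validate_thresholds(v):
    """Validate thresholds: a dict mapping phase -> non-empty list of floats in (0, 1)."""
    if not isinstance(v, dict):
        raise ValueError(
            f"thresholds must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
        )
    for phase, values in v.items():
        if phase not in VALID_PHASES:
            raise ValueError(f"Unknown phase {phase!r}, expected one of {sorted(VALID_PHASES)}")
        if not isinstance(values, list) or not values:
            raise ValueError(f"thresholds[{phase!r}] must be a non-empty list")
        if not all(0.0 < t < 1.0 for t in values):
            raise ValueError(f"thresholds[{phase!r}] values must lie in (0, 1)")
    return v

assert validate_thresholds({"prefill": [1e-4, 1e-3]}) == {"prefill": [1e-4, 1e-3]}
```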
The change does a great job of reducing calibration time overhead. LGTM overall, left a few comments.
8f455c1 to fd82fe3
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
137-149: ⚠️ Potential issue | 🟡 Minor

Validate sparsity list length before zip.

At line 140, `zip(self.threshold_trials, sparsity_list)` silently truncates if lengths differ. If a module returns fewer sparsity values than expected, data will be silently lost.

💡 Suggested fix

```diff
 for sample_stat in per_sample_stats:
     length = sample_stat["sample_length"]
     sparsity_list = sample_stat["sparsity"]
+    if len(sparsity_list) != len(self.threshold_trials):
+        raise ValueError(
+            f"Expected {len(self.threshold_trials)} sparsity values, got {len(sparsity_list)}"
+        )
     for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 137 - 149, The loop over per_sample_stats uses zip(self.threshold_trials, sparsity_list) which silently truncates mismatched lengths; before zipping, validate that len(sparsity_list) == len(self.threshold_trials) (or handle the mismatch explicitly) inside the same function/loop in calibrator.py (the variables: per_sample_stats, sparsity_list, self.threshold_trials); if they differ, either raise a clear ValueError or log an error and skip/pad entries so no data is silently dropped, then proceed to iterate using the validated/padded lists.
🧹 Nitpick comments (1)
tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py (1)
143-146: Relaxed sparsity bounds reflect a known causal masking issue.

The sparsity check at line 145 uses `all(-1 <= s <= 1 for s in stats["sparsity"])` rather than the expected `[0, 1]` range. This appears to accommodate the causal block counting issue flagged in previous reviews, where a numerator/denominator mismatch can produce invalid sparsity values.

Consider adding a comment explaining this relaxed bound, or fixing the underlying causal masking issue in `calc_correction_factor_and_p`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py` around lines 143 - 146, The test currently asserts sparsity values allow [-1,1] which hides a known causal masking numerator/denominator mismatch; update the test at the assertion for stats["sparsity"] to either (a) tighten to the expected 0..1 range and adjust/fix the underlying calculation in calc_correction_factor_and_p (or related causal masking logic) so sparsity cannot go negative, or (b) if you keep the relaxed bound, add a concise comment above the assertion referencing the causal masking bug and pointing to calc_correction_factor_and_p so future maintainers know why [-1,1] is allowed and where to fix it. Ensure the reference is added near the assertion and that any code changes fix the denominator/numerator handling in calc_correction_factor_and_p to produce values in [0,1] before reverting the test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 150-156: Replace the hardcoded trust_remote_code=True in the
AutoModelForCausalLM.from_pretrained(...) call with a caller-configurable CLI
flag: add a new argument to the script's argument parser (e.g.,
--trust-remote-code as a store_true flag or a boolean option defaulting to
False) and pass that parsed value (e.g., args.trust_remote_code) into the
from_pretrained call (alongside existing args.pyt_ckpt_path,
attn_implementation, torch_dtype, device_map). Ensure the new flag defaults to
False and is referenced where trust_remote_code is currently used.
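The flag the prompt asks for is a standard argparse pattern; a minimal sketch (only the flag itself, the rest of hf_sa.py's parser is assumed):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--trust-remote-code",
    action="store_true",  # store_true flags default to False: opt-in only
    help="Allow executing model code shipped with the checkpoint.",
)

assert parser.parse_args([]).trust_remote_code is False
assert parser.parse_args(["--trust-remote-code"]).trust_remote_code is True
```

The parsed value would then replace the hardcoded literal, e.g. `trust_remote_code=args.trust_remote_code` in the `from_pretrained` call.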
---
Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 137-149: The loop over per_sample_stats uses
zip(self.threshold_trials, sparsity_list) which silently truncates mismatched
lengths; before zipping, validate that len(sparsity_list) ==
len(self.threshold_trials) (or handle the mismatch explicitly) inside the same
function/loop in calibrator.py (the variables: per_sample_stats, sparsity_list,
self.threshold_trials); if they differ, either raise a clear ValueError or log
an error and skip/pad entries so no data is silently dropped, then proceed to
iterate using the validated/padded lists.
---
Nitpick comments:
In `@tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py`:
- Around line 143-146: The test currently asserts sparsity values allow [-1,1]
which hides a known causal masking numerator/denominator mismatch; update the
test at the assertion for stats["sparsity"] to either (a) tighten to the
expected 0..1 range and adjust/fix the underlying calculation in
calc_correction_factor_and_p (or related causal masking logic) so sparsity
cannot go negative, or (b) if you keep the relaxed bound, add a concise comment
above the assertion referencing the causal masking bug and pointing to
calc_correction_factor_and_p so future maintainers know why [-1,1] is allowed
and where to fix it. Ensure the reference is added near the assertion and that
any code changes fix the denominator/numerator handling in
calc_correction_factor_and_p to produce values in [0,1] before reverting the
test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 507ea991-20b5-4a73-a642-d9e9d329486c
📒 Files selected for processing (14)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
fd82fe3 to 586209f
Actionable comments posted: 1
♻️ Duplicate comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)
59-71: ⚠️ Potential issue | 🟠 Major

Phase-order-dependent threshold fallback persists.

At line 71, `_update_thresholds` falls back to `self.thresholds` when the phase key is missing from `thresholds_config`. This means behavior depends on which phase ran first, rather than deterministic config-only defaults.

💡 Suggested fix

```diff
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
         # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        self.thresholds = list(self._fallback_thresholds)
     ...
     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` around lines 59 - 71, _update_thresholds currently falls back to the mutable instance field self.thresholds when a phase key is missing, causing behavior to depend on which phase ran first; change _update_thresholds to use a deterministic config-only fallback by reading from self.thresholds_config (e.g., use self.thresholds_config.get(phase, self.thresholds_config.get("prefill", [1e-3]))) so missing phases always resolve from the config (not prior runtime state) and keep set_calibration_mode/_calibration_mode logic unchanged.

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
130-136: ⚠️ Potential issue | 🟠 Major

Calibration mode may not be disabled on exception.

If `forward_loop(model)` or `_extract_calibration_stats(...)` raises an exception, `_disable_calibration_mode` is never called, leaving modules in calibration mode with trial thresholds still set.

💡 Suggested fix

```diff
 self._set_thresholds(attention_modules, self.threshold_trials)
 self._enable_calibration_mode(attention_modules)
-with torch.no_grad():
-    forward_loop(model)
-per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-self._disable_calibration_mode(attention_modules)
+try:
+    with torch.no_grad():
+        forward_loop(model)
+    per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+finally:
+    self._disable_calibration_mode(attention_modules)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 130 - 136, Wrap the execution of forward_loop(model) and _extract_calibration_stats(...) in a try/finally so that _disable_calibration_mode(attention_modules) is always called even if forward_loop or _extract_calibration_stats raises; keep the torch.no_grad() context around the try/finally block and ensure _set_thresholds(...) and _enable_calibration_mode(...) remain before the try so trial thresholds are set for the attempt and always cleared in the finally via _disable_calibration_mode(attention_modules).
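Beyond the try/finally in the suggested fix, the same cleanup guarantee can be packaged as a context manager. A sketch with stand-in enable/disable callables, not the calibrator's real methods:

```python
from contextlib import contextmanager

@contextmanager
def calibration_mode(modules, enable, disable):
    """Enable calibration on the given modules and guarantee cleanup
    on exit, even when the body raises."""
    enable(modules)
    try:
        yield modules
    finally:
        disable(modules)

state = {"enabled": False}

def enable(_mods):
    state["enabled"] = True

def disable(_mods):
    state["enabled"] = False

try:
    with calibration_mode([], enable, disable):
        raise RuntimeError("forward loop failed")
except RuntimeError:
    pass
assert state["enabled"] is False  # cleanup ran despite the exception
```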
🧹 Nitpick comments (4)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)
72-72: Consider adding a test with multiple thresholds.

All threshold configurations in this file use single-element lists. Since the PR's main feature is gathering statistics for multiple thresholds in a single forward pass (~20× faster calibration), consider adding at least one test case that uses multiple threshold values (e.g., `{"prefill": [1e-4, 1e-3, 1e-2], "decode": [1e-4, 1e-3]}`). This would validate that the multi-threshold functionality works correctly at the conversion/configuration level.

If multi-threshold behavior is tested elsewhere (e.g., in calibration tests), this can be disregarded.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py` at line 72, Add a new unit test case in test_sparse_attention_conversion.py that uses a thresholds config with multiple values (e.g., {"prefill": [1e-4, 1e-3, 1e-2], "decode": [1e-4, 1e-3]}) instead of single-element lists; exercise the same conversion/config path used by the existing tests (the code that reads the "thresholds" dict during sparse attention conversion) and assert that the resulting conversion/config contains entries for each provided threshold and that any gathered statistics or outputs are produced per-threshold (verify keys/counts match the input lists). Ensure the test targets the same functions used in the file for conversion/configuration so it validates multi-threshold handling end-to-end.

modelopt/torch/sparsity/attention_sparsity/config.py (1)
127-132: Minor: Error message uses singular "Threshold" but field is "thresholds".

The error message at line 131 says "Threshold must be a dict..." but the field is now named `thresholds` (plural).

💡 Suggested fix

```diff
 if not isinstance(v, dict):
     raise ValueError(
-        f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
+        f"thresholds must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
     )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 127 - 132, In validate_thresholds (the classmethod in config.py) update the ValueError message to use the plural field name "thresholds" (e.g., "Thresholds must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}") so the error refers to the correct field; locate the validate_thresholds function and replace the singular "Threshold" text with "Thresholds" in the raise ValueError message.

modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)
144-146: Redundant assertion after conditional check.

The `assert` on line 146 is defensive but redundant since `use_calibration_params` already guarantees that both `calibration_params` and `target_sparse_ratio` are not `None`.

💡 Suggested cleanup

```diff
 if use_calibration_params:
     # Calibrated dynamic threshold: bypass thresholds list entirely
-    assert calibration_params is not None and target_sparse_ratio is not None
     a = calibration_params[phase]["a"]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` around lines 144 - 146, Remove the redundant defensive assert inside the branch guarded by use_calibration_params: since the if use_calibration_params: check already guarantees that calibration_params and target_sparse_ratio are provided, delete the assert calibration_params is not None and target_sparse_ratio is not None to avoid unnecessary duplication in flash_skip_softmax.py (references: use_calibration_params, calibration_params, target_sparse_ratio).

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
137-149: Consider adding length validation before zipping thresholds with sparsity.

At line 140, `zip(self.threshold_trials, sparsity_list)` will silently truncate if the lengths don't match. Since `sparsity_list` comes from per-sample stats that should match `threshold_trials`, a mismatch would indicate a bug.

💡 Suggested defensive check

```diff
 for sample_stat in per_sample_stats:
     length = sample_stat["sample_length"]
     sparsity_list = sample_stat["sparsity"]
+    assert len(sparsity_list) == len(self.threshold_trials), (
+        f"Sparsity list length {len(sparsity_list)} != threshold_trials {len(self.threshold_trials)}"
+    )
     for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 137 - 149, In the loop over per_sample_stats inside the calibrator (where you iterate sample_stat and do zip(self.threshold_trials, sparsity_list)), add a defensive length check that verifies len(sparsity_list) == len(self.threshold_trials) before zipping; if they differ, raise a clear ValueError or log an error with identifying info from sample_stat (e.g., "sample_length" or an ID) and either skip that sample or fail fast so the silent truncation cannot happen — update the code paths in the same function/method where "sparsity_list" and "self.threshold_trials" are used to perform this validation and handle the mismatch.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py`:
- Around line 69-74: The code reads stats["sparse_blocks"] directly which can
raise KeyError if absent; change to use stats.get("sparse_blocks", []) (assign
to incoming) and treat empty list as no-op when updating
self.aggregated_stats["sparse_blocks"] so the loop which sums elements only runs
when incoming is non-empty; ensure you still initialize
self.aggregated_stats["sparse_blocks"] = list(incoming) when incoming is
present, mirroring the pattern used for total_blocks.
---
Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 130-136: Wrap the execution of forward_loop(model) and
_extract_calibration_stats(...) in a try/finally so that
_disable_calibration_mode(attention_modules) is always called even if
forward_loop or _extract_calibration_stats raises; keep the torch.no_grad()
context around the try/finally block and ensure _set_thresholds(...) and
_enable_calibration_mode(...) remain before the try so trial thresholds are set
for the attempt and always cleared in the finally via
_disable_calibration_mode(attention_modules).
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 59-71: _update_thresholds currently falls back to the mutable
instance field self.thresholds when a phase key is missing, causing behavior to
depend on which phase ran first; change _update_thresholds to use a
deterministic config-only fallback by reading from self.thresholds_config (e.g.,
use self.thresholds_config.get(phase, self.thresholds_config.get("prefill",
[1e-3]))) so missing phases always resolve from the config (not prior runtime
state) and keep set_calibration_mode/_calibration_mode logic unchanged.
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 137-149: In the loop over per_sample_stats inside the calibrator
(where you iterate sample_stat and do zip(self.threshold_trials,
sparsity_list)), add a defensive length check that verifies len(sparsity_list)
== len(self.threshold_trials) before zipping; if they differ, raise a clear
ValueError or log an error with identifying info from sample_stat (e.g.,
"sample_length" or an ID) and either skip that sample or fail fast so the silent
truncation cannot happen — update the code paths in the same function/method
where "sparsity_list" and "self.threshold_trials" are used to perform this
validation and handle the mismatch.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 127-132: In validate_thresholds (the classmethod in config.py)
update the ValueError message to use the plural field name "thresholds" (e.g.,
"Thresholds must be a dict with 'prefill' and/or 'decode' keys, got
{type(v).__name__}") so the error refers to the correct field; locate the
validate_thresholds function and replace the singular "Threshold" text with
"Thresholds" in the raise ValueError message.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 144-146: Remove the redundant defensive assert inside the branch
guarded by use_calibration_params: since the if use_calibration_params: check
already guarantees that calibration_params and target_sparse_ratio are provided,
delete the assert calibration_params is not None and target_sparse_ratio is not
None to avoid unnecessary duplication in flash_skip_softmax.py (references:
use_calibration_params, calibration_params, target_sparse_ratio).
In
`@tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py`:
- Line 72: Add a new unit test case in test_sparse_attention_conversion.py that
uses a thresholds config with multiple values (e.g., {"prefill": [1e-4, 1e-3,
1e-2], "decode": [1e-4, 1e-3]}) instead of single-element lists; exercise the
same conversion/config path used by the existing tests (the code that reads the
"thresholds" dict during sparse attention conversion) and assert that the
resulting conversion/config contains entries for each provided threshold and
that any gathered statistics or outputs are produced per-threshold (verify
keys/counts match the input lists). Ensure the test targets the same functions
used in the file for conversion/configuration so it validates multi-threshold
handling end-to-end.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9fef1c11-30e1-4d26-80f0-4440f3ed1d89
📒 Files selected for processing (14)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (7)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
```python
incoming = stats["sparse_blocks"]
if "sparse_blocks" not in self.aggregated_stats:
    self.aggregated_stats["sparse_blocks"] = list(incoming)
else:
    for i, val in enumerate(incoming):
        self.aggregated_stats["sparse_blocks"][i] += val
```
Potential KeyError if sparse_blocks is missing from stats.
Line 69 accesses stats["sparse_blocks"] directly without using .get(), unlike other fields (e.g., total_blocks on line 67). If a caller omits sparse_blocks, this will raise a KeyError.
Consider using .get() with a default or documenting that sparse_blocks is required:
💡 Suggested fix
```diff
-incoming = stats["sparse_blocks"]
+incoming = stats.get("sparse_blocks")
+if incoming is None:
+    return
+
 if "sparse_blocks" not in self.aggregated_stats:
     self.aggregated_stats["sparse_blocks"] = list(incoming)
 else:
```
🤖 Prompt for AI Agents
else:🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py` around lines 69
- 74, The code reads stats["sparse_blocks"] directly which can raise KeyError if
absent; change to use stats.get("sparse_blocks", []) (assign to incoming) and
treat empty list as no-op when updating self.aggregated_stats["sparse_blocks"]
so the loop which sums elements only runs when incoming is non-empty; ensure you
still initialize self.aggregated_stats["sparse_blocks"] = list(incoming) when
incoming is present, mirroring the pattern used for total_blocks.
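Stripped of the repo context, the fix the prompt describes is this defensive elementwise-accumulation pattern (a minimal standalone sketch — the real code keeps `aggregated_stats` on a stats-manager object):

```python
def aggregate_sparse_blocks(aggregated_stats, stats):
    """Elementwise-accumulate per-threshold sparse-block counts, tolerating
    callers that omit the 'sparse_blocks' field entirely."""
    incoming = stats.get("sparse_blocks", [])
    if not incoming:  # missing or empty: no-op, mirroring total_blocks handling
        return
    if "sparse_blocks" not in aggregated_stats:
        aggregated_stats["sparse_blocks"] = list(incoming)  # copy, don't alias
    else:
        for i, val in enumerate(incoming):
            aggregated_stats["sparse_blocks"][i] += val
```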
…single pass
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
586209f to 8430617
♻️ Duplicate comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)
59-71: ⚠️ Potential issue | 🟠 Major — Avoid phase-order-dependent threshold fallback.
At line 71, missing phase keys fall back to the previous runtime value (`self.thresholds`), so behavior depends on which phase ran first instead of config-only defaults.

💡 Suggested fix

```diff
-        # Initialize thresholds from dict config (prefill phase as default)
-        self.thresholds = self.thresholds_config.get("prefill", [1e-3])
+        # Deterministic fallback for configs that define only one phase
+        self._fallback_thresholds = (
+            self.thresholds_config.get("prefill")
+            or self.thresholds_config.get("decode")
+            or [1e-3]
+        )
+        self.thresholds = list(self._fallback_thresholds)

     def _update_thresholds(self, phase: str):
         """Update thresholds list based on phase."""
-        self.thresholds = self.thresholds_config.get(phase, self.thresholds)
+        self.thresholds = list(self.thresholds_config.get(phase, self._fallback_thresholds))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` around lines 59 - 71: The _update_thresholds method currently falls back to the current runtime self.thresholds when a phase key is missing, causing behavior to depend on phase order; change _update_thresholds to load from the configuration only by setting self.thresholds = self.thresholds_config.get(phase, self.thresholds_config.get("prefill", [1e-3])) (or another config-level default key) so missing phase entries always fall back to a config-defined default instead of the previous runtime value; update references to thresholds/thresholds_config and preserve the calibration flag logic in set_calibration_mode/_update_thresholds.

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
130-136: ⚠️ Potential issue | 🟠 Major — Ensure calibration mode is always disabled on failure.
An exception in `forward_loop(model)` or `_extract_calibration_stats(...)` leaves modules in calibration mode with trial thresholds still set.

💡 Suggested fix

```diff
         self._set_thresholds(attention_modules, self.threshold_trials)
         self._enable_calibration_mode(attention_modules)
-        with torch.no_grad():
-            forward_loop(model)
-        per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
-        self._disable_calibration_mode(attention_modules)
+        try:
+            with torch.no_grad():
+                forward_loop(model)
+            per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+        finally:
+            self._disable_calibration_mode(attention_modules)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 130 - 136: Wrap the forward/measurement block in a try/finally to guarantee cleanup: before calling self._set_thresholds(...) save the original thresholds, then call self._set_thresholds(attention_modules, self.threshold_trials) and self._enable_calibration_mode(attention_modules), perform forward_loop(model) and self._extract_calibration_stats(...) inside the try, and in the finally always call self._disable_calibration_mode(attention_modules) and restore/clear thresholds by calling self._set_thresholds(attention_modules, original_thresholds) (or self._set_thresholds(attention_modules, None) if no original) so modules never remain in calibration mode or with trial thresholds after an exception.
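In isolation, the cleanup guarantee being requested is the standard try/finally state-restoration pattern. A toy sketch with a hypothetical `DummyAttention` module (the real modules and `torch.no_grad()` wrapper are omitted here):

```python
class DummyAttention:
    """Hypothetical stand-in for a sparse-attention module under calibration."""
    def __init__(self):
        self.thresholds = [1e-3]
        self.calibrating = False
        self._stats = []

    def set_calibration_mode(self, on):
        self.calibrating = on

def run_calibration_pass(model, modules, forward_loop, trial_thresholds):
    """Run one calibration forward pass; the try/finally guarantees modules
    never stay in calibration mode (or keep trial thresholds) on failure."""
    originals = [list(m.thresholds) for m in modules]
    for m in modules:
        m.thresholds = list(trial_thresholds)
        m.set_calibration_mode(True)
    try:
        forward_loop(model)  # the real code wraps this in torch.no_grad()
        return [m._stats for m in modules]
    finally:
        for m, orig in zip(modules, originals):
            m.set_calibration_mode(False)
            m.thresholds = orig
```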
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
137-149: Consider validating sparsity_list length matches threshold_trials.The
zipat line 140 silently truncates ifsparsity_listhas fewer entries thanself.threshold_trials. This could mask bugs where modules don't report all thresholds.💡 Suggested validation
for sample_stat in per_sample_stats: length = sample_stat["sample_length"] sparsity_list = sample_stat["sparsity"] + if len(sparsity_list) != len(self.threshold_trials): + raise ValueError( + f"Sparsity list length {len(sparsity_list)} doesn't match " + f"threshold_trials length {len(self.threshold_trials)}" + ) for threshold, sparsity in zip(self.threshold_trials, sparsity_list):🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py` around lines 137 - 149: The code currently zips self.threshold_trials with sparsity_list (from sample_stat["sparsity"]) which silently truncates when sparsity_list is shorter; update the loop in the per_sample_stats processing to validate that len(sparsity_list) == len(self.threshold_trials) (or at least >=) before iterating, and if the lengths mismatch raise an exception or log an explicit error including the offending sample (sample_stat) and the lengths; only proceed to append to all_data_points when the validation passes to avoid masked bugs in threshold reporting.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 130-136: Wrap the forward/measurement block in a try/finally to
guarantee cleanup: before calling self._set_thresholds(...) save the original
thresholds, then call self._set_thresholds(attention_modules,
self.threshold_trials) and self._enable_calibration_mode(attention_modules),
perform forward_loop(model) and self._extract_calibration_stats(... ) inside the
try, and in the finally always call
self._disable_calibration_mode(attention_modules) and restore/clear thresholds
by calling self._set_thresholds(attention_modules, original_thresholds) (or
self._set_thresholds(attention_modules, None) if no original) so modules never
remain in calibration mode or with trial thresholds after an exception.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 59-71: The _update_thresholds method currently falls back to the
current runtime self.thresholds when a phase key is missing, causing behavior to
depend on phase order; change _update_thresholds to load from the configuration
only by setting self.thresholds = self.thresholds_config.get(phase,
self.thresholds_config.get("prefill", [1e-3])) (or another config-level default
key) so missing phase entries always fall back to a config-defined default
instead of the previous runtime value; update references to
thresholds/thresholds_config and preserve the calibration flag logic in
set_calibration_mode/_update_thresholds.
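In isolation, the deterministic-fallback behavior the prompt asks for looks like this (a sketch; the attribute names mirror the review comment, not necessarily the final code):

```python
class SkipSoftmaxThresholds:
    """Phase-dependent thresholds with a config-only fallback, so resolution
    never depends on which phase happened to run first."""
    def __init__(self, thresholds_config):
        self.thresholds_config = thresholds_config
        # Fallback comes from the config alone, never from runtime state
        self._fallback = (
            thresholds_config.get("prefill")
            or thresholds_config.get("decode")
            or [1e-3]
        )
        self.thresholds = list(self._fallback)

    def _update_thresholds(self, phase):
        self.thresholds = list(self.thresholds_config.get(phase, self._fallback))
```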
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py`:
- Around line 137-149: The code currently zips self.threshold_trials with
sparsity_list (from sample_stat["sparsity"]) which silently truncates when
sparsity_list is shorter; update the loop in the per_sample_stats processing to
validate that len(sparsity_list) == len(self.threshold_trials) (or at least >=)
before iterating, and if the lengths mismatch raise an exception or log an
explicit error including the offending sample (sample_stat) and the lengths;
only proceed to append to all_data_points when the validation passes to avoid
masked bugs in threshold reporting.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ade8fd6f-a93b-4d32-9878-61b65c4b9f23
📒 Files selected for processing (15)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py
🚧 Files skipped from review as they are similar to previous changes (8)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py
- tests/_test_utils/torch/sparsity/sparse_attention_common.py
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main     #987   +/-   ##
=======================================
  Coverage   70.09%   70.09%
=======================================
  Files         221      221
  Lines       25459    25491      +32
=======================================
+ Hits        17845    17868      +23
- Misses       7614     7623       +9
```

☔ View full report in Codecov by Sentry.
Modify skip softmax calibration to use a list of thresholds instead of a single threshold. Sparsity during inference is unchanged, but during calibration we can use the list to gather statistics about many thresholds in a single forward pass. Makes calibration 20x faster
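The speedup comes from the fact that skip-softmax sparsity for every candidate threshold can be read off the same attention probabilities, so one forward pass suffices. A toy NumPy sketch of the measurement (not the ModelOpt kernel, which operates on blocks rather than individual entries):

```python
import numpy as np

def sparsity_per_threshold(attn_probs, thresholds):
    """Given softmax attention probabilities from ONE forward pass, return the
    fraction of entries that would be skipped at each candidate threshold."""
    flat = attn_probs.reshape(-1)
    return [float((flat < t).mean()) for t in thresholds]

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 16, 16))
probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
# One scan over `probs` yields statistics for every trial threshold at once,
# instead of one full model forward per threshold.
stats = sparsity_per_threshold(probs, [1e-4, 1e-3, 1e-2])
```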
Summary by CodeRabbit
New Features
Improvements
Breaking Changes
Tests & Examples