
Added 2:4 sparsity to skip softmax method #1019

Open
rohansjoshi wants to merge 1 commit into main from rohjoshi/sparse24-plus-skipsoftmax

Conversation


@rohansjoshi rohansjoshi commented Mar 11, 2026

Summary

Adds an apply_sparse24: bool config option to the existing flash_skip_softmax method. When enabled, a 2:4 structured sparsity mask (top-2 of every 4 elements along seq_k) is AND-ed with the skip-softmax block mask in
both prefill and decode phases.

This is a pure PyTorch-level feature for research and analysis — not a performance optimization. It allows studying the interaction between block-level and 2:4 structured sparsity patterns.
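The masking described above (keep the top-2 of every 4 score elements along seq_k, then AND with the block mask) can be sketched as follows. This is a hypothetical NumPy illustration of the selection logic only, not the PR's PyTorch implementation; the function name mirrors the PR's sparse24_mask_along_last_dim, but the body here is an assumption.

```python
import numpy as np

def sparse24_mask_along_last_dim(attn: np.ndarray) -> np.ndarray:
    """Keep the top-2 of every 4 elements along the last (seq_k) axis.

    NumPy sketch of 2:4 selection; assumes the last dim is a multiple of 4.
    """
    *lead, k = attn.shape
    groups = attn.reshape(*lead, k // 4, 4)
    # Rank elements within each group of 4 (rank 3 = largest), keep the top two.
    order = np.argsort(groups, axis=-1)   # ascending sort order
    rank = np.argsort(order, axis=-1)     # per-element rank 0..3
    mask = rank >= 2                      # exactly 2 of every 4 survive
    return mask.reshape(*lead, k)

# Combine with an (assumed) boolean block mask, as the PR describes:
scores = np.array([[1.0, 4.0, 2.0, 3.0, 0.0, 9.0, 8.0, 1.0]])
block_mask = np.ones_like(scores, dtype=bool)
element_mask = block_mask & sparse24_mask_along_last_dim(scores)
```

Because exactly two of every four elements survive, AND-ing this mask with the block mask can only increase (never decrease) total sparsity.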

Changes

  • config.py — New apply_sparse24 field on SparseAttentionAttributeConfig; new SPARSE24_SKIP_SOFTMAX and SPARSE24_SKIP_SOFTMAX_CALIB preset configs.
  • methods/flash_skip_softmax.py — Reads the flag and applies the 2:4 mask inside calc_correction_factor_and_p.
  • hf_sa.py — Exposes --sparse_attn sparse24_skip_softmax and --sparse_attn sparse24_skip_softmax_calib as CLI choices.

Summary by CodeRabbit

  • New Features
    • Added two new sparse attention configuration options: sparse24_skip_softmax and sparse24_skip_softmax_calib for enhanced sparsity patterns.
    • Introduced configuration parameters enabling 2:4 structured sparsity combined with skip-softmax support.
    • Added RULER-based calibration option for sparse attention configurations.

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>

copy-pr-bot bot commented Mar 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 11, 2026

📝 Walkthrough

This PR introduces support for 2:4 structured sparsity in attention mechanisms by adding new configuration options and implementing conditional sparsity mask computation. Two new pre-defined configurations enable 2:4 sparsity with skip-softmax, optionally combined with calibration.

Changes

  • Configuration & Constants — modelopt/torch/sparsity/attention_sparsity/config.py: Added two new configuration fields (skip_diagonal_blocks, apply_sparse24) to SparseAttentionAttributeConfig and introduced two pre-defined configurations (SPARSE24_SKIP_SOFTMAX, SPARSE24_SKIP_SOFTMAX_CALIB) that combine 2:4 sparsity with skip-softmax, the latter including RULER-based calibration.
  • Implementation — modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py: Added a sparse24_mask_along_last_dim utility to compute 2:4 sparsity masks and integrated the apply_sparse24 configuration flag into FlashSkipSoftmax to conditionally apply sparse masks in both prefill and decode paths.
  • Example Integration — examples/llm_sparsity/attention_sparsity/hf_sa.py: Updated imports and extended the SPARSE_ATTN_CFG_CHOICES dictionary with two new entries mapping to the newly defined sparse attention configurations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 passed
  • Description Check — ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: The title accurately captures the main objective: adding 2:4 sparsity functionality to the skip softmax method, the core change across all modified files.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, above the required 80.00% threshold.
  • Security Anti-Patterns — ✅ Passed: No unsafe deserialization, hardcoded trust flags, eval/exec calls, nosec comments, or new dependencies.



codecov bot commented Mar 11, 2026

Codecov Report

❌ Patch coverage is 34.78261% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.22%. Comparing base (fe83270) to head (6afa360).
⚠️ Report is 8 commits behind head on main.

Files with missing lines:
  • ...y/attention_sparsity/methods/flash_skip_softmax.py — patch coverage 21.05%, 15 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1019      +/-   ##
==========================================
- Coverage   70.25%   70.22%   -0.04%     
==========================================
  Files         220      220              
  Lines       25368    25391      +23     
==========================================
+ Hits        17822    17830       +8     
- Misses       7546     7561      +15     


@rohansjoshi rohansjoshi marked this pull request as ready for review March 13, 2026 21:41
@rohansjoshi rohansjoshi requested a review from a team as a code owner March 13, 2026 21:41
@rohansjoshi rohansjoshi requested review from kaix-nv and realAsma and removed request for realAsma March 13, 2026 21:41

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1d7d1c0e-1264-43a9-a838-5e7a5c47cd32

📥 Commits

Reviewing files that changed from the base of the PR and between fe83270 and 6afa360.

📒 Files selected for processing (3)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py

Comment on lines +96 to +102

    skip_diagonal_blocks: bool = ModeloptField(
        default=True,
        title="Skip diagonal blocks.",
        description=(
            "When True, keep diagonal tiles dense for 2:4 sparse attention. Defaults to True."
        ),
    )

⚠️ Potential issue | 🟡 Minor


skip_diagonal_blocks is defined in config but not implemented.

The field is exposed in SparseAttentionAttributeConfig (line 96-102) with documented intent to "keep diagonal tiles dense for 2:4 sparse attention," but:

  1. It is never extracted from method_config in FlashSkipSoftmax.__init__
  2. It is not referenced in the sparse24 mask application logic (calc_correction_factor_and_p, lines 218–224, 274–278)
  3. sparse24_mask_along_last_dim applies uniform 2:4 sparsity without diagonal block preservation

Either wire the flag into the mask construction or document it as reserved for future use.
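One possible way to honor the flag, sketched with NumPy on a 2-D mask (the helper name and tile-size arguments are hypothetical; the PR would need the equivalent operation on its blocked PyTorch tensors):

```python
import numpy as np

def keep_diagonal_tiles_dense(mask: np.ndarray, tile_q: int, tile_k: int) -> np.ndarray:
    """Force diagonal tiles of a (seq_q, seq_k) boolean mask back to dense.

    Any element whose query tile index equals its key tile index is kept,
    regardless of what the 2:4 / block masks decided.
    """
    seq_q, seq_k = mask.shape[-2], mask.shape[-1]
    q_tile = np.arange(seq_q)[:, None] // tile_q   # (seq_q, 1)
    k_tile = np.arange(seq_k)[None, :] // tile_k   # (1, seq_k)
    diagonal = q_tile == k_tile                    # broadcasts to (seq_q, seq_k)
    return mask | diagonal
```

As the comment notes, any correction-factor math that assumes uniform 2:4 density would also need to account for the dense diagonal tiles.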


Comment on lines +104 to +111

    apply_sparse24: bool = ModeloptField(
        default=False,
        title="Apply 2:4 structured sparsity.",
        description=(
            "If True, additionally apply 2:4 structured sparsity (top-2 of every 4 elements "
            "along seq_k) on top of the skip-softmax block mask. Only used by flash_skip_softmax."
        ),
    )

⚠️ Potential issue | 🟠 Major

Add cross-field validation for apply_sparse24 and bc.

apply_sparse24=True currently accepts any positive bc, but sparse24 masking requires grouping by 4. Invalid bc values can pass config validation and then fail later at runtime.

Suggested fix:

    -from pydantic import Field, field_validator
    +from pydantic import Field, field_validator, model_validator
    @@
     class SparseAttentionAttributeConfig(ModeloptBaseConfig):
    @@
         apply_sparse24: bool = ModeloptField(
             default=False,
    @@
         )
    +
    +    @model_validator(mode="after")
    +    def validate_sparse24_requirements(self):
    +        if self.apply_sparse24 and self.bc % 4 != 0:
    +            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
    +        return self
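The same cross-field check can be illustrated with a plain dataclass — a stdlib stand-in for the Pydantic validator; the class name and defaults here are hypothetical, not from the PR:

```python
from dataclasses import dataclass

@dataclass
class Sparse24Config:
    """Illustrative config with the cross-field check the review asks for."""

    apply_sparse24: bool = False
    bc: int = 64  # key-dimension block size (assumed default)

    def __post_init__(self):
        # 2:4 masking groups seq_k in fours, so the key block size
        # must be divisible by 4 when the feature is enabled.
        if self.apply_sparse24 and self.bc % 4 != 0:
            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
```

Rejecting the combination at construction time surfaces the error in config validation rather than deep inside the mask reshape at runtime.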

Comment on lines +217 to +225

    # Step 7b: Apply 2:4 structured sparsity on top of block mask (optional)
    if self.apply_sparse24:
        attn_padded = blocked_attn.reshape(
            batch_size, num_heads, padded_seq_q, padded_seq_k
        )
        sparse24_mask = sparse24_mask_along_last_dim(attn_padded)
        sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k]
        element_mask = element_mask & sparse24_mask


⚠️ Potential issue | 🟠 Major

Reported sparsity no longer matches the final mask when sparse24 is enabled.

At Line 224 and Line 278, element_mask is further pruned with sparse24_mask, but the returned stats["sparsity"] is still computed from block_mask (pre-sparse24). This under-reports effective sparsity in sparse24 modes and can skew analysis/calibration interpretation.

Please compute and report a post-AND metric (or add separate block_sparsity and element_sparsity fields).

Also applies to: 273-279
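A minimal sketch of the dual-metric reporting the comment asks for, using NumPy and hypothetical field names (the PR's stats dict layout may differ):

```python
import numpy as np

def sparsity_stats(block_mask: np.ndarray, element_mask: np.ndarray) -> dict:
    """Report pre- and post-sparse24 sparsity separately.

    block_sparsity reflects the skip-softmax block mask alone;
    element_sparsity reflects the final mask after AND-ing in the 2:4 mask,
    so it is always >= block_sparsity.
    """
    return {
        "block_sparsity": 1.0 - block_mask.mean(),
        "element_sparsity": 1.0 - element_mask.mean(),
    }
```

Reporting both avoids the under-reporting the review describes while keeping the existing block-level metric available for comparison with non-sparse24 runs.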

