4 changes: 4 additions & 0 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
@@ -31,6 +31,8 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
    SKIP_SOFTMAX_CALIB,
    SKIP_SOFTMAX_DEFAULT,
    SPARSE24_SKIP_SOFTMAX,
    SPARSE24_SKIP_SOFTMAX_CALIB,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

@@ -43,6 +45,8 @@
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": SKIP_SOFTMAX_DEFAULT,
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
    "sparse24_skip_softmax": SPARSE24_SKIP_SOFTMAX,
    "sparse24_skip_softmax_calib": SPARSE24_SKIP_SOFTMAX_CALIB,
}
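A sketch of how a choices dict like this is typically wired to a CLI flag; the flag name and stand-in payload values below are assumptions for illustration, not taken from hf_sa.py itself.

```python
import argparse

# Stand-in payloads; in the real script these are the imported preset dicts.
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax": {"id": "SKIP_SOFTMAX_DEFAULT"},
    "skip_softmax_calib": {"id": "SKIP_SOFTMAX_CALIB"},
    "sparse24_skip_softmax": {"id": "SPARSE24_SKIP_SOFTMAX"},
    "sparse24_skip_softmax_calib": {"id": "SPARSE24_SKIP_SOFTMAX_CALIB"},
}

parser = argparse.ArgumentParser()
# argparse validates membership against the dict's keys.
parser.add_argument("--sparse_attn_cfg", choices=SPARSE_ATTN_CFG_CHOICES, default="skip_softmax")
args = parser.parse_args(["--sparse_attn_cfg", "sparse24_skip_softmax"])
cfg = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn_cfg]
print(cfg["id"])
```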


58 changes: 58 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/config.py
@@ -93,6 +93,23 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
        ),
    )

    skip_diagonal_blocks: bool = ModeloptField(
        default=True,
        title="Skip diagonal blocks.",
        description=(
            "When True, keep diagonal tiles dense for 2:4 sparse attention. Defaults to True."
        ),
    )
Comment on lines +96 to +102
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Scripts executed (repository: NVIDIA/Model-Optimizer):

```shell
#!/bin/bash
set -euo pipefail

# Verify whether skip_diagonal_blocks is consumed beyond config declarations/presets.
rg -n --type=py '\bskip_diagonal_blocks\b' -C2

# Find the sparse24 mask implementation and related files
find modelopt/torch/sparsity/attention_sparsity -name "*.py" -type f | head -20

# Look at the config.py file to understand the class structure
cat -n modelopt/torch/sparsity/attention_sparsity/config.py | head -120

# Search for mask construction or application logic in the sparse attention module
rg -n 'def.*mask' --type=py modelopt/torch/sparsity/attention_sparsity/ -A 5

# Read the flash_skip_softmax.py file
wc -l modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
cat -n modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
```


skip_diagonal_blocks is defined in config but not implemented.

The field is exposed in SparseAttentionAttributeConfig (lines 96-102) with documented intent to "keep diagonal tiles dense for 2:4 sparse attention," but:

1. It is never extracted from method_config in FlashSkipSoftmax.__init__.
2. It is not referenced in the sparse24 mask application logic (calc_correction_factor_and_p, lines 218-224 and 274-278).
3. sparse24_mask_along_last_dim applies uniform 2:4 sparsity with no diagonal block preservation.

Either wire the flag into the mask construction or document it as reserved for future use.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In modelopt/torch/sparsity/attention_sparsity/config.py around lines 96-102,
the config flag skip_diagonal_blocks is declared but never used. Wire it into
FlashSkipSoftmax by reading it from method_config in FlashSkipSoftmax.__init__
and storing it as self.skip_diagonal_blocks, then propagate it into the mask
construction functions (sparse24_mask_along_last_dim and
calc_correction_factor_and_p). Specifically: update FlashSkipSoftmax.__init__ to
extract SparseAttentionAttributeConfig.skip_diagonal_blocks; pass that boolean
into sparse24_mask_along_last_dim; in sparse24_mask_along_last_dim, modify the
mask generation to force full (dense) tiles for diagonal tile indices (i.e.,
where query_tile_idx == key_tile_idx) when skip_diagonal_blocks is True; and
adjust calc_correction_factor_and_p to compute correction factors and p using
the effective number of sparse elements after preserving diagonal tiles, so the
probability/count math remains correct. If you prefer not to implement the
behavior now, instead mark skip_diagonal_blocks as reserved by updating the
SparseAttentionAttributeConfig docstring and removing any expectation that code
reads it.
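The diagonal-preserving behavior the review asks for can be illustrated without torch. This is a hypothetical, dependency-free sketch (the tile-index bookkeeping and the helper name tile_keep_mask are assumptions, not FlashSkipSoftmax internals): diagonal tiles (query_tile == key_tile) stay fully dense, all other tiles receive the 2:4 mask.

```python
def tile_keep_mask(n_tiles_q, n_tiles_k, sparse24_tile_mask, skip_diagonal_blocks=True):
    """Return one keep-mask per (query_tile, key_tile) pair.

    Diagonal tiles are forced dense (all True) when skip_diagonal_blocks is set;
    off-diagonal tiles reuse the supplied 2:4 mask.
    """
    out = {}
    for qi in range(n_tiles_q):
        for ki in range(n_tiles_k):
            if skip_diagonal_blocks and qi == ki:
                out[(qi, ki)] = [True] * len(sparse24_tile_mask)  # dense diagonal tile
            else:
                out[(qi, ki)] = list(sparse24_tile_mask)
    return out

masks = tile_keep_mask(2, 2, [True, False, True, False])
print(masks[(0, 0)])  # diagonal tile: fully dense
print(masks[(0, 1)])  # off-diagonal tile: 2:4 pattern
```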


    apply_sparse24: bool = ModeloptField(
        default=False,
        title="Apply 2:4 structured sparsity.",
        description=(
            "If True, additionally apply 2:4 structured sparsity (top-2 of every 4 elements "
            "along seq_k) on top of the skip-softmax block mask. Only used by flash_skip_softmax."
        ),
    )
Comment on lines +104 to +111
⚠️ Potential issue | 🟠 Major

Add cross-field validation for apply_sparse24 and bc.

apply_sparse24=True currently accepts any positive bc, but sparse24 masking requires bc to be divisible by 4. Invalid bc values can pass config validation and then fail later at runtime.

Suggested fix:

```diff
-from pydantic import Field, field_validator
+from pydantic import Field, field_validator, model_validator
@@
 class SparseAttentionAttributeConfig(ModeloptBaseConfig):
@@
     apply_sparse24: bool = ModeloptField(
         default=False,
@@
     )
+
+    @model_validator(mode="after")
+    def validate_sparse24_requirements(self):
+        if self.apply_sparse24 and self.bc % 4 != 0:
+            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
+        return self
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In modelopt/torch/sparsity/attention_sparsity/config.py around lines 104-111:
when apply_sparse24 is True, bc must be divisible by 4. Add a cross-field
validation to the attention sparsity config (the model config class in this
file) that enforces bc % 4 == 0 (and optionally bc >= 4) whenever apply_sparse24
is True. Implement this as a Pydantic model validator (or the class's post-init
check) that inspects the apply_sparse24 and bc fields and raises a ValueError
with a clear message if the condition fails, so invalid configs are rejected
early.
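The cross-field check can be prototyped without Pydantic. A minimal sketch using a stdlib dataclass with a __post_init__ hook (the class name Sparse24Config is hypothetical; the real config would use Pydantic's model validator as shown in the suggested fix):

```python
from dataclasses import dataclass

@dataclass
class Sparse24Config:
    bc: int = 128
    apply_sparse24: bool = False

    def __post_init__(self):
        # Reject invalid combinations at construction time rather than at runtime.
        if self.apply_sparse24 and self.bc % 4 != 0:
            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")

Sparse24Config(bc=128, apply_sparse24=True)  # valid: 128 % 4 == 0
try:
    Sparse24Config(bc=130, apply_sparse24=True)
except ValueError as e:
    print(e)
```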


    @field_validator("method")
    @classmethod
    def validate_method(cls, v):

@@ -416,10 +433,51 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
    },
}

# Combined 2:4 structured sparsity + skip-softmax block mask (pytorch backend)
SPARSE24_SKIP_SOFTMAX = {
    "sparse_cfg": {
        "*attn*": {
            "method": "flash_skip_softmax",
            "threshold": {"prefill": 1e-3, "decode": 1e-4},
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "collect_stats": True,
            "apply_sparse24": True,
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Combined 2:4 + skip-softmax with RULER calibration
SPARSE24_SKIP_SOFTMAX_CALIB = {
    "sparse_cfg": {
        "calibration": {
            "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9},
            "samples": 64,
            "max_seqlen": 65536,
            "chunk_size": 4096,
        },
        "*attn*": {
            "method": "flash_skip_softmax",
            "br": 128,
            "bc": 128,
            "backend": "pytorch",
            "collect_stats": True,
            "apply_sparse24": True,
            "enable": True,
        },
        "default": {"enable": False},
    },
}
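The "*attn*" keys in sparse_cfg are glob-style module patterns. A sketch of how such patterns could select modules, using stdlib fnmatch as a stand-in (the exact modelopt lookup logic is an assumption here, for illustration only):

```python
from fnmatch import fnmatch

sparse_cfg = {"*attn*": {"enable": True}, "default": {"enable": False}}

def resolve(module_name):
    """Return the first glob pattern's config that matches, else the default."""
    for pattern, cfg in sparse_cfg.items():
        if pattern != "default" and fnmatch(module_name, pattern):
            return cfg
    return sparse_cfg["default"]

print(resolve("model.layers.0.self_attn"))  # matches "*attn*"
print(resolve("model.layers.0.mlp"))        # falls back to default
```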


__all__ = [
    "SKIP_SOFTMAX_CALIB",
    "SKIP_SOFTMAX_DEFAULT",
    "SPARSE24_SKIP_SOFTMAX",
    "SPARSE24_SKIP_SOFTMAX_CALIB",
    "CalibrationConfig",
    "FlashSkipSoftmaxConfig",
    "SparseAttentionAttributeConfig",
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
@@ -28,6 +28,24 @@
from . import SparseAttentionMethod, register_sparse_method


def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
    """Compute 2:4 mask: for every 4 elements along the last dim, keep the 2 largest.

    Args:
        scores: Tensor of shape [..., N] with N divisible by 4.

    Returns:
        Boolean mask of same shape; True where the element is kept (top-2 of 4).
    """
    *prefix, n = scores.shape
    assert n % 4 == 0, "2:4 sparsity requires last dim divisible by 4"
    grouped = scores.reshape(*prefix, n // 4, 4)
    _, top2_idx = torch.topk(grouped, k=2, dim=-1, largest=True, sorted=False)
    mask = torch.zeros_like(grouped, dtype=torch.bool)
    mask.scatter_(-1, top2_idx, True)
    return mask.reshape(*prefix, n)
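The same 2:4 rule can be restated without torch to check the semantics: in every group of 4 scores along the last dim, keep the 2 largest. A dependency-free sketch (the helper name sparse24_mask_1d is hypothetical):

```python
def sparse24_mask_1d(scores):
    """Pure-Python 2:4 mask over a flat list: keep the top-2 of every group of 4."""
    assert len(scores) % 4 == 0, "2:4 sparsity requires length divisible by 4"
    mask = [False] * len(scores)
    for g in range(0, len(scores), 4):
        group = scores[g:g + 4]
        # Indices of the two largest values within this group of 4.
        top2 = sorted(range(4), key=lambda i: group[i], reverse=True)[:2]
        for i in top2:
            mask[g + i] = True
    return mask

print(sparse24_mask_1d([0.9, 0.1, 0.5, 0.2, 3.0, 2.0, 1.0, 0.0]))
```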


@register_sparse_method("flash_skip_softmax")
class FlashSkipSoftmax(SparseAttentionMethod):
    """Flash Attention-aware softmax skip sparse attention method.

@@ -55,6 +73,7 @@ def __init__(self, method_config: dict | None = None):

        # Optional parameters not in Pydantic config
        self.phase = config.get("phase", None)
        self.apply_sparse24 = config.get("apply_sparse24", False)

        # Initialize threshold from dict config (prefill phase as default)
        self.threshold = self.threshold_config.get("prefill", 1e-3)
@@ -195,6 +214,15 @@ def calc_correction_factor_and_p(
        element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k)
        element_mask = element_mask[:, :, :seq_q, :seq_k]

        # Step 7b: Apply 2:4 structured sparsity on top of block mask (optional)
        if self.apply_sparse24:
            attn_padded = blocked_attn.reshape(
                batch_size, num_heads, padded_seq_q, padded_seq_k
            )
            sparse24_mask = sparse24_mask_along_last_dim(attn_padded)
            sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k]
            element_mask = element_mask & sparse24_mask

Comment on lines +217 to +225
⚠️ Potential issue | 🟠 Major

Reported sparsity no longer matches the final mask when sparse24 is enabled.

At Line 224 and Line 278, element_mask is further pruned with sparse24_mask, but the returned stats["sparsity"] is still computed from block_mask (pre-sparse24). This under-reports effective sparsity in sparse24 modes and can skew analysis/calibration interpretation.

Please compute and report a post-AND metric (or add separate block_sparsity and element_sparsity fields).

Also applies to: 273-279

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
around lines 217-225: the reported sparsity uses block_mask before the optional
2:4 pruning, so when apply_sparse24 is true the final mask (element_mask after
AND with sparse24_mask from sparse24_mask_along_last_dim) is sparser than
reported. Update the stats computation to reflect post-AND rates by computing
element-level sparsity from the final element_mask (or add two fields, e.g.,
block_sparsity computed from block_mask and element_sparsity computed from
element_mask) and populate stats accordingly where stats["sparsity"] is set, so
callers get accurate post-sparse24 metrics.
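The two-field metric the review suggests can be sketched without torch; flat boolean lists stand in for the tensor masks, and the function name sparsity_stats is hypothetical:

```python
def sparsity_stats(block_mask, element_mask):
    """Report block-level and element-level sparsity side by side.

    block_mask reflects the skip-softmax decision alone; element_mask is the
    final mask after any 2:4 AND, so its sparsity is the number callers need.
    """
    block_sparsity = 1.0 - sum(block_mask) / len(block_mask)
    element_sparsity = 1.0 - sum(element_mask) / len(element_mask)
    return {"block_sparsity": block_sparsity, "element_sparsity": element_sparsity}

# Block mask keeps 3 of 4 blocks; after AND with a 2:4 mask, elements are sparser.
stats = sparsity_stats(
    [True, True, True, False],
    [True, False, True, False, True, True, False, False],
)
print(stats)
```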

        # Step 8: Calculate sparsity statistics
        if self.is_causal:
            # For causal attention, only count lower triangle blocks (including diagonal)
@@ -242,6 +270,13 @@ def calc_correction_factor_and_p(
        element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k)
        element_mask = element_mask[:, :, :seq_q, :seq_k]

        # Step 6b: Apply 2:4 structured sparsity on top of block mask (optional)
        if self.apply_sparse24:
            attn_padded = blocked_attn.reshape(batch_size, num_heads, 1, padded_seq_k)
            sparse24_mask = sparse24_mask_along_last_dim(attn_padded)
            sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k]
            element_mask = element_mask & sparse24_mask

        # Step 7: Calculate sparsity statistics
        dense_blocks = block_mask.sum()
        total_valid_blocks = block_mask.numel()