-
Notifications
You must be signed in to change notification settings - Fork 299
Added 2:4 sparsity to skip softmax method #1019
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,6 +93,23 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): | |
| ), | ||
| ) | ||
|
|
||
| skip_diagonal_blocks: bool = ModeloptField( | ||
| default=True, | ||
| title="Skip diagonal blocks.", | ||
| description=( | ||
| "When True, keep diagonal tiles dense for 2:4 sparse attention. Defaults to True." | ||
| ), | ||
| ) | ||
|
|
||
| apply_sparse24: bool = ModeloptField( | ||
| default=False, | ||
| title="Apply 2:4 structured sparsity.", | ||
| description=( | ||
| "If True, additionally apply 2:4 structured sparsity (top-2 of every 4 elements " | ||
| "along seq_k) on top of the skip-softmax block mask. Only used by flash_skip_softmax." | ||
| ), | ||
| ) | ||
|
Comment on lines
+104
to
+111
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add cross-field validation for `apply_sparse24`: when it is enabled, `bc` must be divisible by 4.
Suggested fix-from pydantic import Field, field_validator
+from pydantic import Field, field_validator, model_validator
@@
class SparseAttentionAttributeConfig(ModeloptBaseConfig):
@@
apply_sparse24: bool = ModeloptField(
default=False,
@@
)
+
+ @model_validator(mode="after")
+ def validate_sparse24_requirements(self):
+ if self.apply_sparse24 and self.bc % 4 != 0:
+ raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
+ return self
🤖 Prompt for AI Agents |
||
|
|
||
| @field_validator("method") | ||
| @classmethod | ||
| def validate_method(cls, v): | ||
|
|
@@ -416,10 +433,51 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): | |
| }, | ||
| } | ||
|
|
||
| # Combined 2:4 structured sparsity + skip-softmax block mask (pytorch backend) | ||
| SPARSE24_SKIP_SOFTMAX = { | ||
| "sparse_cfg": { | ||
| "*attn*": { | ||
| "method": "flash_skip_softmax", | ||
| "threshold": {"prefill": 1e-3, "decode": 1e-4}, | ||
| "br": 128, | ||
| "bc": 128, | ||
| "backend": "pytorch", | ||
| "collect_stats": True, | ||
| "apply_sparse24": True, | ||
| "enable": True, | ||
| }, | ||
| "default": {"enable": False}, | ||
| }, | ||
| } | ||
|
|
||
| # Combined 2:4 + skip-softmax with RULER calibration | ||
| SPARSE24_SKIP_SOFTMAX_CALIB = { | ||
| "sparse_cfg": { | ||
| "calibration": { | ||
| "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9}, | ||
| "samples": 64, | ||
| "max_seqlen": 65536, | ||
| "chunk_size": 4096, | ||
| }, | ||
| "*attn*": { | ||
| "method": "flash_skip_softmax", | ||
| "br": 128, | ||
| "bc": 128, | ||
| "backend": "pytorch", | ||
| "collect_stats": True, | ||
| "apply_sparse24": True, | ||
| "enable": True, | ||
| }, | ||
| "default": {"enable": False}, | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "SKIP_SOFTMAX_CALIB", | ||
| "SKIP_SOFTMAX_DEFAULT", | ||
| "SPARSE24_SKIP_SOFTMAX", | ||
| "SPARSE24_SKIP_SOFTMAX_CALIB", | ||
| "CalibrationConfig", | ||
| "FlashSkipSoftmaxConfig", | ||
| "SparseAttentionAttributeConfig", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,24 @@ | |
| from . import SparseAttentionMethod, register_sparse_method | ||
|
|
||
|
|
||
| def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor: | ||
| """Compute 2:4 mask: for every 4 elements along the last dim, keep the 2 largest. | ||
|
|
||
| Args: | ||
| scores: Tensor of shape [..., N] with N divisible by 4. | ||
|
|
||
| Returns: | ||
| Boolean mask of same shape; True where the element is kept (top-2 of 4). | ||
| """ | ||
| *prefix, n = scores.shape | ||
| assert n % 4 == 0, "2:4 sparsity requires last dim divisible by 4" | ||
| grouped = scores.reshape(*prefix, n // 4, 4) | ||
| _, top2_idx = torch.topk(grouped, k=2, dim=-1, largest=True, sorted=False) | ||
| mask = torch.zeros_like(grouped, dtype=torch.bool) | ||
| mask.scatter_(-1, top2_idx, True) | ||
| return mask.reshape(*prefix, n) | ||
|
|
||
|
|
||
| @register_sparse_method("flash_skip_softmax") | ||
| class FlashSkipSoftmax(SparseAttentionMethod): | ||
| """Flash Attention-aware softmax skip sparse attention method. | ||
|
|
@@ -55,6 +73,7 @@ def __init__(self, method_config: dict | None = None): | |
|
|
||
| # Optional parameters not in Pydantic config | ||
| self.phase = config.get("phase", None) | ||
| self.apply_sparse24 = config.get("apply_sparse24", False) | ||
|
|
||
| # Initialize threshold from dict config (prefill phase as default) | ||
| self.threshold = self.threshold_config.get("prefill", 1e-3) | ||
|
|
@@ -195,6 +214,15 @@ def calc_correction_factor_and_p( | |
| element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) | ||
| element_mask = element_mask[:, :, :seq_q, :seq_k] | ||
|
|
||
| # Step 7b: Apply 2:4 structured sparsity on top of block mask (optional) | ||
| if self.apply_sparse24: | ||
| attn_padded = blocked_attn.reshape( | ||
| batch_size, num_heads, padded_seq_q, padded_seq_k | ||
| ) | ||
| sparse24_mask = sparse24_mask_along_last_dim(attn_padded) | ||
| sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k] | ||
| element_mask = element_mask & sparse24_mask | ||
|
|
||
|
Comment on lines
+217
to
+225
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reported sparsity no longer matches the final mask when sparse24 is enabled: at line 224 and line 278 the sparsity statistics are computed from the block mask alone, before the 2:4 mask is ANDed into the element mask. Please compute and report a post-AND metric (or add separate counters for the 2:4 contribution). Also applies to: lines 273–279. 🤖 Prompt for AI Agents |
||
| # Step 8: Calculate sparsity statistics | ||
| if self.is_causal: | ||
| # For causal attention, only count lower triangle blocks (including diagonal) | ||
|
|
@@ -242,6 +270,13 @@ def calc_correction_factor_and_p( | |
| element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) | ||
| element_mask = element_mask[:, :, :seq_q, :seq_k] | ||
|
|
||
| # Step 6b: Apply 2:4 structured sparsity on top of block mask (optional) | ||
| if self.apply_sparse24: | ||
| attn_padded = blocked_attn.reshape(batch_size, num_heads, 1, padded_seq_k) | ||
| sparse24_mask = sparse24_mask_along_last_dim(attn_padded) | ||
| sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k] | ||
| element_mask = element_mask & sparse24_mask | ||
|
|
||
| # Step 7: Calculate sparsity statistics | ||
| dense_blocks = block_mask.sum() | ||
| total_valid_blocks = block_mask.numel() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 461
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 1121
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 5112
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 755
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 142
🏁 Script executed:
Repository: NVIDIA/Model-Optimizer
Length of output: 19968
`skip_diagonal_blocks` is defined in config but not implemented. The field is exposed in
`SparseAttentionAttributeConfig` (lines 96–102) with documented intent to "keep diagonal tiles dense for 2:4 sparse attention," but: it is not read from `method_config` in `FlashSkipSoftmax.__init__`; the mask construction (`calc_correction_factor_and_p`, lines 218–224 and 274–278) never consults it; and `sparse24_mask_along_last_dim` applies uniform 2:4 sparsity without diagonal block preservation. Either wire the flag into the mask construction or document it as reserved for future use.
🤖 Prompt for AI Agents