diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py
index 74c5e9a540..96c5e71fbd 100644
--- a/examples/llm_sparsity/attention_sparsity/hf_sa.py
+++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py
@@ -31,6 +31,8 @@
 from modelopt.torch.sparsity.attention_sparsity.config import (
     SKIP_SOFTMAX_CALIB,
     SKIP_SOFTMAX_DEFAULT,
+    SPARSE24_SKIP_SOFTMAX,
+    SPARSE24_SKIP_SOFTMAX_CALIB,
 )
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
 
@@ -43,6 +45,8 @@
 SPARSE_ATTN_CFG_CHOICES = {
     "skip_softmax": SKIP_SOFTMAX_DEFAULT,
     "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
+    "sparse24_skip_softmax": SPARSE24_SKIP_SOFTMAX,
+    "sparse24_skip_softmax_calib": SPARSE24_SKIP_SOFTMAX_CALIB,
 }
 
 
diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py
index 2d73f13ad7..7759ad8214 100644
--- a/modelopt/torch/sparsity/attention_sparsity/config.py
+++ b/modelopt/torch/sparsity/attention_sparsity/config.py
@@ -93,6 +93,23 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
         ),
     )
 
+    skip_diagonal_blocks: bool = ModeloptField(
+        default=True,
+        title="Skip diagonal blocks.",
+        description=(
+            "When True, keep diagonal tiles dense for 2:4 sparse attention. Defaults to True."
+        ),
+    )
+
+    apply_sparse24: bool = ModeloptField(
+        default=False,
+        title="Apply 2:4 structured sparsity.",
+        description=(
+            "If True, additionally apply 2:4 structured sparsity (top-2 of every 4 elements "
+            "along seq_k) on top of the skip-softmax block mask. Only used by flash_skip_softmax."
+        ),
+    )
+
     @field_validator("method")
     @classmethod
     def validate_method(cls, v):
@@ -416,10 +433,51 @@
     },
 }
 
 
+# Combined 2:4 structured sparsity + skip-softmax block mask (pytorch backend)
+SPARSE24_SKIP_SOFTMAX = {
+    "sparse_cfg": {
+        "*attn*": {
+            "method": "flash_skip_softmax",
+            "threshold": {"prefill": 1e-3, "decode": 1e-4},
+            "br": 128,
+            "bc": 128,
+            "backend": "pytorch",
+            "collect_stats": True,
+            "apply_sparse24": True,
+            "enable": True,
+        },
+        "default": {"enable": False},
+    },
+}
+
+# Combined 2:4 + skip-softmax with RULER calibration
+SPARSE24_SKIP_SOFTMAX_CALIB = {
+    "sparse_cfg": {
+        "calibration": {
+            "target_sparse_ratio": {"prefill": 0.9, "decode": 0.9},
+            "samples": 64,
+            "max_seqlen": 65536,
+            "chunk_size": 4096,
+        },
+        "*attn*": {
+            "method": "flash_skip_softmax",
+            "br": 128,
+            "bc": 128,
+            "backend": "pytorch",
+            "collect_stats": True,
+            "apply_sparse24": True,
+            "enable": True,
+        },
+        "default": {"enable": False},
+    },
+}
+
 __all__ = [
     "SKIP_SOFTMAX_CALIB",
     "SKIP_SOFTMAX_DEFAULT",
+    "SPARSE24_SKIP_SOFTMAX",
+    "SPARSE24_SKIP_SOFTMAX_CALIB",
     "CalibrationConfig",
     "FlashSkipSoftmaxConfig",
     "SparseAttentionAttributeConfig",
diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
index f911b95f79..4ded7d3c79 100644
--- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
+++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
@@ -28,6 +28,24 @@
 from . import SparseAttentionMethod, register_sparse_method
 
 
+def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
+    """Compute 2:4 mask: for every 4 elements along the last dim, keep the 2 largest.
+
+    Args:
+        scores: Tensor of shape [..., N] with N divisible by 4.
+
+    Returns:
+        Boolean mask of same shape; True where the element is kept (top-2 of 4).
+ """ + *prefix, n = scores.shape + assert n % 4 == 0, "2:4 sparsity requires last dim divisible by 4" + grouped = scores.reshape(*prefix, n // 4, 4) + _, top2_idx = torch.topk(grouped, k=2, dim=-1, largest=True, sorted=False) + mask = torch.zeros_like(grouped, dtype=torch.bool) + mask.scatter_(-1, top2_idx, True) + return mask.reshape(*prefix, n) + + @register_sparse_method("flash_skip_softmax") class FlashSkipSoftmax(SparseAttentionMethod): """Flash Attention-aware softmax skip sparse attention method. @@ -55,6 +73,7 @@ def __init__(self, method_config: dict | None = None): # Optional parameters not in Pydantic config self.phase = config.get("phase", None) + self.apply_sparse24 = config.get("apply_sparse24", False) # Initialize threshold from dict config (prefill phase as default) self.threshold = self.threshold_config.get("prefill", 1e-3) @@ -195,6 +214,15 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] + # Step 7b: Apply 2:4 structured sparsity on top of block mask (optional) + if self.apply_sparse24: + attn_padded = blocked_attn.reshape( + batch_size, num_heads, padded_seq_q, padded_seq_k + ) + sparse24_mask = sparse24_mask_along_last_dim(attn_padded) + sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k] + element_mask = element_mask & sparse24_mask + # Step 8: Calculate sparsity statistics if self.is_causal: # For causal attention, only count lower triangle blocks (including diagonal) @@ -242,6 +270,13 @@ def calc_correction_factor_and_p( element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) element_mask = element_mask[:, :, :seq_q, :seq_k] + # Step 6b: Apply 2:4 structured sparsity on top of block mask (optional) + if self.apply_sparse24: + attn_padded = blocked_attn.reshape(batch_size, num_heads, 1, padded_seq_k) + sparse24_mask = sparse24_mask_along_last_dim(attn_padded) + sparse24_mask = sparse24_mask[:, :, 
:seq_q, :seq_k] + element_mask = element_mask & sparse24_mask + # Step 7: Calculate sparsity statistics dense_blocks = block_mask.sum() total_valid_blocks = block_mask.numel()