From 244827eb1c5159e29a04ae94b8114246b58d3f26 Mon Sep 17 00:00:00 2001 From: Rohan Joshi Date: Fri, 6 Mar 2026 01:52:34 +0000 Subject: [PATCH] Refactor threshold -> thresholds; calibrate with all thresholds in a single pass Signed-off-by: Rohan Joshi --- .../llm_sparsity/attention_sparsity/hf_sa.py | 8 +- .../calibration/calibrator.py | 62 +++--- .../sparsity/attention_sparsity/config.py | 47 +++-- .../methods/flash_skip_softmax.py | 178 ++++++++++-------- .../attention_sparsity/model_sparsify.py | 4 +- .../attention_sparsity/stats_manager.py | 21 ++- .../torch/sparsity/sparse_attention_common.py | 4 +- .../test_calibration_gpu.py | 16 +- .../test_integration_gpu.py | 11 +- .../test_flash_skip_softmax.py | 44 ++--- .../test_sparse_attention_calibration.py | 10 +- .../test_sparse_attention_config.py | 28 +-- .../test_sparse_attention_conversion.py | 10 +- .../attention_sparsity/test_stats_manager.py | 46 ++--- .../attention_sparsity/test_threshold_info.py | 20 +- 15 files changed, 275 insertions(+), 234 deletions(-) diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 74c5e9a54..0b97298f5 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -150,7 +150,8 @@ def main(args): model = AutoModelForCausalLM.from_pretrained( args.pyt_ckpt_path, attn_implementation="eager", - torch_dtype=torch.bfloat16, + torch_dtype="auto", + device_map="auto", ) tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) @@ -158,11 +159,6 @@ def main(args): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - # Move model to GPU if available - if torch.cuda.is_available(): - model = model.cuda() - print("Model moved to CUDA") - # Generate sample output BEFORE sparse attention print("\nGenerating sample output before sparse attention...") output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index df2c05e20..682120693 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -24,7 +24,6 @@ import torch import torch.nn as nn from scipy.optimize import curve_fit -from tqdm import tqdm from ..stats_manager import SparseAttentionStatsManager from ..utils import get_sparse_attention_modules @@ -91,9 +90,9 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic """Calibrate a and b parameters for Exponential model. Algorithm: - 1. For each threshold λ_j in threshold_trials: - - Run ALL samples, collect sparsities S_ij for each sample i - - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length) + 1. Set thresholds = threshold_trials on all modules, run ONE forward pass. + Each module returns a sparsity list (one entry per threshold) per sample. + Unpack to get (scale_factor_ij = λ_j × L_i, sparsity_ij) pairs. 2. Fit Exponential model to ALL (sf_ij, S_ij) pairs: scale_factor = a * exp(b * sparsity) @@ -121,29 +120,25 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic print(f"Starting Exponential model calibration ({phase} phase)") print(f"Threshold trials: {len(self.threshold_trials)}") - # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples - print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds...") + # Stage 1: Collect ALL (scale_factor, sparsity) pairs in a single forward pass. + # All threshold_trials are passed at once; each module returns a sparsity list + # with one entry per threshold, eliminating the need for repeated forward passes. + print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds in one pass...") - # Collect ALL individual data points (not averaged) all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"} - for threshold in tqdm(self.threshold_trials, desc=f"Testing thresholds ({phase})"): - self._set_threshold(attention_modules, threshold) - self._enable_calibration_mode(attention_modules) - with torch.no_grad(): - forward_loop(model) - per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) - self._disable_calibration_mode(attention_modules) - - if not per_sample_stats: - continue - - # Collect individual (scale_factor, sparsity) pairs for each sample - for sample_stat in per_sample_stats: - length = sample_stat["sample_length"] - sparsity = sample_stat["sparsity"] + self._set_thresholds(attention_modules, self.threshold_trials) + self._enable_calibration_mode(attention_modules) + with torch.no_grad(): + forward_loop(model) + per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase) + self._disable_calibration_mode(attention_modules) + + for sample_stat in per_sample_stats: + length = sample_stat["sample_length"] + sparsity_list = sample_stat["sparsity"] + for threshold, sparsity in zip(self.threshold_trials, sparsity_list): scale_factor = threshold * length - all_data_points.append( { "threshold": threshold, @@ -307,17 +302,26 @@ def _extract_calibration_stats( aggregated_stats = [] for sample_idx in range(num_samples): - sparsities = [] + sparsity_lists = [] sample_length = 0 for module_stats in all_per_sample_stats: if sample_idx < len(module_stats): sample_stat = module_stats[sample_idx] - sparsities.append(sample_stat.get("sparsity", 0.0)) + sparsity = sample_stat.get("sparsity", []) + sparsity_lists.append(sparsity if isinstance(sparsity, list) else [sparsity]) if not sample_length and "sample_length" in sample_stat: sample_length = sample_stat["sample_length"] - avg_sparsity = float(np.mean(sparsities)) if sparsities else 0.0 + if not sparsity_lists: + continue + + lengths = [len(s) for s in sparsity_lists] + assert len(set(lengths)) == 1, ( + f"All modules must have the same number of thresholds, got {lengths}" + ) + n = lengths[0] + avg_sparsity = [float(np.mean([sl[i] for sl in sparsity_lists])) for i in range(n)] aggregated_stats.append( { @@ -328,7 +332,7 @@ def _extract_calibration_stats( return aggregated_stats - def _set_threshold(self, modules: list[nn.Module], threshold: float): - """Set threshold on sparse attention modules.""" + def _set_thresholds(self, modules: list[nn.Module], thresholds: list[float]): + """Set thresholds list on sparse attention modules.""" for module in modules: - module._sparse_method_instance.threshold = threshold + module._sparse_method_instance.thresholds = thresholds diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index 2d73f13ad..180ff89b0 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -46,12 +46,14 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description="If True, enables sparse attention. If False, bypasses sparsity.", ) - threshold: dict[str, float] = ModeloptField( - default={"prefill": 1e-3, "decode": 1e-4}, - title="Sparsity threshold.", + thresholds: dict[str, list[float]] = ModeloptField( + default={"prefill": [1e-3], "decode": [1e-4]}, + title="Sparsity thresholds.", description=( - "Threshold for determining which attention values to skip. " - "Must be a dict with 'prefill' and 'decode' keys." + "Thresholds for determining which attention values to skip. " + "Must be a dict with 'prefill' and/or 'decode' keys, each mapping to a list of floats. " + "Prefill and decode lists must have the same length. " + "Sparsity is computed per threshold; the first threshold's mask is applied." ), ) @@ -120,10 +122,10 @@ def validate_block_size(cls, v): raise ValueError(f"Block size must be positive, got {v}") return v - @field_validator("threshold") + @field_validator("thresholds") @classmethod - def validate_threshold(cls, v): - """Validate threshold is a dict with valid phases and values in range (0, 1).""" + def validate_thresholds(cls, v): + """Validate thresholds is a dict of lists with valid phases and values in range (0, 1).""" if not isinstance(v, dict): raise ValueError( f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}" @@ -135,12 +137,25 @@ def validate_threshold(cls, v): raise ValueError( f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}" ) - # Validate all values are in range (0, 1) - for phase, threshold in v.items(): - if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + # Validate all values are lists of floats in range (0, 1) + lengths = {} + for phase, threshold_list in v.items(): + if not isinstance(threshold_list, list) or len(threshold_list) == 0: raise ValueError( - f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + f"Thresholds for phase '{phase}' must be a non-empty list, got {threshold_list}" ) + for threshold in threshold_list: + if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1: + raise ValueError( + f"Each threshold for phase '{phase}' must be in range (0, 1), got {threshold}" + ) + lengths[phase] = len(threshold_list) + # Validate prefill and decode lists have the same length + if len(lengths) == 2 and len(set(lengths.values())) != 1: + raise ValueError( + f"Prefill and decode threshold lists must have the same length, " + f"got prefill={lengths['prefill']}, decode={lengths['decode']}" + ) return v @@ -356,7 +371,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): default={ "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "thresholds": {"prefill": [1e-3], "decode": [1e-5]}, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns "backend": "pytorch", # Only pytorch backend supported @@ -378,9 +393,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": { - "prefill": 1e-3, # More aggressive during prefill - "decode": 1e-4, # Conservative during decode + "thresholds": { + "prefill": [1e-3], # More aggressive during prefill + "decode": [1e-4], # Conservative during decode }, "br": 128, # Flash Attention block rows "bc": 128, # Flash Attention block columns diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index f911b95f7..9bfd6a954 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -40,14 +40,14 @@ def __init__(self, method_config: dict | None = None): """Initialize Flash softmax skip method. Args: - method_config: Configuration dict with threshold, br, bc, is_causal, etc. + method_config: Configuration dict with thresholds, br, bc, is_causal, etc. All required fields should have defaults from SparseAttentionAttributeConfig. """ super().__init__() config = method_config or {} # Extract configuration - self.threshold_config = config["threshold"] + self.thresholds_config = config["thresholds"] self.br = config["br"] self.bc = config["bc"] self.backend = config["backend"] @@ -56,19 +56,19 @@ def __init__(self, method_config: dict | None = None): # Optional parameters not in Pydantic config self.phase = config.get("phase", None) - # Initialize threshold from dict config (prefill phase as default) - self.threshold = self.threshold_config.get("prefill", 1e-3) + # Initialize thresholds from dict config (prefill phase as default) + self.thresholds = self.thresholds_config.get("prefill", [1e-3]) # Calibration mode flag (prevents threshold updates during calibration) self._calibration_mode = False def set_calibration_mode(self, enabled: bool): - """Set calibration mode to prevent _update_threshold from modifying the threshold.""" + """Set calibration mode to prevent _update_thresholds from modifying the thresholds.""" self._calibration_mode = enabled - def _update_threshold(self, phase: str): - """Update threshold based on phase.""" - self.threshold = self.threshold_config.get(phase, self.threshold) + def _update_thresholds(self, phase: str): + """Update thresholds list based on phase.""" + self.thresholds = self.thresholds_config.get(phase, self.thresholds) def _infer_phase(self, attention_scores: torch.Tensor) -> str: """Infer phase from attention scores shape.""" @@ -132,25 +132,25 @@ def calc_correction_factor_and_p( """ batch_size, num_heads, seq_q, seq_k = attn_weights.shape - # Calculate threshold + # Check whether to use calibrated single-threshold path or multi-threshold list path calibration_params = self.calibration_params target_sparse_ratio = self.target_sparse_ratio - - if ( + use_calibration_params = ( calibration_params is not None and phase in calibration_params and target_sparse_ratio is not None - ): - # Use calibrated a, b to compute dynamic threshold - # Exponential model: scale_factor = a * exp(b * target_sparsity) + ) + + if use_calibration_params: + # Calibrated dynamic threshold: bypass thresholds list entirely + assert calibration_params is not None and target_sparse_ratio is not None a = calibration_params[phase]["a"] b = calibration_params[phase]["b"] target_sparsity = target_sparse_ratio.get(phase, 0.5) scale_factor = a * np.exp(b * target_sparsity) - log_threshold = np.log(scale_factor / seq_k) + log_thresholds = [np.log(scale_factor / seq_k)] else: - # Use static threshold from config (no calibration or phase not calibrated) - log_threshold = np.log(self.threshold) + log_thresholds = [np.log(t) for t in self.thresholds] if phase == "prefill": blocked_attn, num_block_rows, num_block_cols, padded_seq_q, padded_seq_k = ( @@ -158,103 +158,115 @@ def calc_correction_factor_and_p( ) # Step 1: Compute maximum value in each block - # For each 128x128 block, find max across the 128 columns - # blocked_attn: [batch, heads, block_rows, br=128, block_cols, bc=128] - # block_max: [batch, heads, block_rows, br=128, block_cols] + # blocked_attn: [batch, heads, block_rows, br, block_cols, bc] + # block_max: [batch, heads, block_rows, br, block_cols] block_max = blocked_attn.max(dim=-1)[0] + del blocked_attn # free padded copy early; block_max holds what we need # Step 2: Track cumulative maximum across blocks (left to right) - # This simulates Flash Attention's online softmax normalization - # block_max_cummax: [batch, heads, block_rows, br=128, block_cols] block_max_cummax = block_max.cummax(dim=-1)[0] - # Step 3: Calculate correction factor (how often max changes) - # Used by Flash Attention to adjust running sum when max increases + # Step 3: Calculate correction factor block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() - del block_max, block_max_larger - - # Step 4 & 5: Compute threshold mask directly without storing p. - # Fusing the subtraction and comparison avoids allocating a second - # full attention-matrix-sized tensor alongside blocked_attn. - p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold - del block_max_cummax - - # Reduce over bc (128 cols), then br (128 rows) to get block-level decision - # Result: [batch, heads, block_rows, block_cols] - block_mask = p_larger_than_thresh.any(dim=-1).any(dim=-2) - del p_larger_than_thresh + del block_max_larger - # Step 6: Expand block mask back to element level - # All 128x128 elements in a block share the same mask value - # [batch, heads, block_rows, block_cols] -> [batch, heads, block_rows, br=128, block_cols, bc=128] - element_mask = block_mask.unsqueeze(-2).unsqueeze(-1).expand_as(blocked_attn) - - # Step 7: Reshape to original attention shape and remove padding - element_mask = element_mask.reshape(batch_size, num_heads, padded_seq_q, padded_seq_k) - element_mask = element_mask[:, :, :seq_q, :seq_k] - - # Step 8: Calculate sparsity statistics + # Pre-compute total_valid_blocks (same for all thresholds) if self.is_causal: - # For causal attention, only count lower triangle blocks (including diagonal) num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2 total_valid_blocks = batch_size * num_heads * num_causal_blocks - dense_blocks = block_mask.sum() total_blocks = num_causal_blocks else: - dense_blocks = block_mask.sum() # Keep as tensor - total_valid_blocks = block_mask.numel() + total_valid_blocks = batch_size * num_heads * num_block_rows * num_block_cols total_blocks = num_block_rows * num_block_cols - sparsity = 1.0 - dense_blocks.item() / total_valid_blocks + + # Step 4-5: Loop over thresholds, computing block mask and sparsity for each. + # Only store block_mask for the first threshold (used for element_mask). + # In calibration mode, skip element_mask entirely to save memory. + # We compare block_max to block_max_cummax directly (avoids materializing the + # full blocked_attn-sized intermediate tensor — saves ~1x attn weights per threshold). + dense_blocks_list = [] + block_mask_0 = None + block_diff = block_max - block_max_cummax + for i, log_threshold in enumerate(log_thresholds): + block_mask = (block_diff > log_threshold).any(dim=-2) + + dense_blocks_list.append(block_mask.sum().item()) + + if i == 0 and not self._calibration_mode: + block_mask_0 = block_mask + del block_mask + + del block_max, block_max_cummax + + # Step 6-7: Expand block_mask_0 to element level (skip in calibration mode) + if not self._calibration_mode and block_mask_0 is not None: + element_mask = ( + block_mask_0.unsqueeze(-2) + .unsqueeze(-1) + .expand(batch_size, num_heads, num_block_rows, self.br, num_block_cols, self.bc) + ) + del block_mask_0 + element_mask = element_mask.reshape( + batch_size, num_heads, padded_seq_q, padded_seq_k + ) + element_mask = element_mask[:, :, :seq_q, :seq_k] + else: + element_mask = None + else: # decode blocked_attn, _, num_block_cols, _, padded_seq_k = self._reshape_to_blocks( attn_weights, 1, self.bc ) - # Decode: Single query row attends to all past key blocks - # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc=128] - - # Step 1: Find maximum in each key block - # block_max: [batch, heads, 1, 1, num_block_cols] + # blocked_attn: [batch, heads, 1, 1, num_block_cols, bc] block_max = blocked_attn.max(dim=-1)[0] - - # Step 2: Track cumulative maximum across key blocks (left to right) - # Simulates Flash Attention's online softmax normalization + del blocked_attn # free early; block_max holds what we need block_max_cummax = block_max.cummax(dim=-1)[0] - # Step 3: Calculate correction factor - # Tracks how often the maximum increases (needed for Flash Attention rescaling) block_max_larger = torch.ones_like(block_max) block_max_larger[..., 1:] = block_max[..., 1:] > block_max_cummax[..., :-1] correction_factor = (block_max_larger.sum() / block_max_larger.numel()).item() - del block_max, block_max_larger + del block_max_larger - # Step 4 & 5: Compute threshold mask directly without storing p. - p_larger_than_thresh = (blocked_attn - block_max_cummax[..., None]) > log_threshold - del block_max_cummax + total_valid_blocks = batch_size * num_heads * num_block_cols + total_blocks = num_block_cols - block_mask = p_larger_than_thresh.any(dim=-1, keepdim=False) - del p_larger_than_thresh + dense_blocks_list = [] + block_mask_0 = None + for i, log_threshold in enumerate(log_thresholds): + block_mask = block_max - block_max_cummax > log_threshold - # Step 6: Expand to element level and remove padding - element_mask = block_mask[..., None].expand_as(blocked_attn) - element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) - element_mask = element_mask[:, :, :seq_q, :seq_k] + dense_blocks_list.append(block_mask.sum().item()) - # Step 7: Calculate sparsity statistics - dense_blocks = block_mask.sum() - total_valid_blocks = block_mask.numel() - sparsity = 1.0 - dense_blocks.item() / total_valid_blocks - total_blocks = num_block_cols + if i == 0 and not self._calibration_mode: + block_mask_0 = block_mask + del block_mask + + del block_max, block_max_cummax + + if not self._calibration_mode and block_mask_0 is not None: + element_mask = block_mask_0[..., None].expand( + batch_size, num_heads, 1, 1, num_block_cols, self.bc + ) + del block_mask_0 + element_mask = element_mask.reshape(batch_size, num_heads, 1, padded_seq_k) + element_mask = element_mask[:, :, :seq_q, :seq_k] + else: + element_mask = None + + sparsity_list = [1.0 - d / total_valid_blocks for d in dense_blocks_list] + + sparsity_out = sparsity_list + sparse_blocks_out = [int(s * total_blocks) for s in sparsity_list] - # Create stats dictionary stats = { "correction_factor": correction_factor, - "sparsity": sparsity, + "sparsity": sparsity_out, "phase": phase, "total_blocks": total_blocks, - "sparse_blocks": int(sparsity * total_blocks), + "sparse_blocks": sparse_blocks_out, "sample_length": seq_k, } @@ -280,9 +292,9 @@ def calculate_sparsity( # Infer phase from tensor shape phase = self._infer_phase(attention_scores) - # Update threshold for the detected phase (skip during calibration) + # Update thresholds for the detected phase (skip during calibration) if not self._calibration_mode: - self._update_threshold(phase) + self._update_thresholds(phase) # Calculate block-wise sparsity mask and stats sparse_mask, stats = self.calc_correction_factor_and_p(attention_scores, phase) @@ -347,10 +359,10 @@ def get_threshold_info(self) -> dict[str, Any]: "phases": phase_info, } else: - # Static threshold (single value or phase-specific dict) + # Static thresholds (list per phase) return { "type": "static", - "value": self.threshold_config, + "value": self.thresholds_config, } @property diff --git a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py index b79e25bd8..28c18943a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py +++ b/modelopt/torch/sparsity/attention_sparsity/model_sparsify.py @@ -52,7 +52,7 @@ def sparsify( Sparse attention configurations is a dictionary mapping wildcards or filter functions to its sparse attention attributes. The wildcards or filter functions are matched - against the module names. The sparse attention attributes include ``"threshold"``, + against the module names. The sparse attention attributes include ``"thresholds"``, ``"enable"``, and method-specific parameters. An example ``config`` dictionary is given below: @@ -64,7 +64,7 @@ def sparsify( # Phase-aware thresholds with backend selection "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "thresholds": {"prefill": [1e-3], "decode": [1e-5]}, "backend": "pytorch", # Only pytorch backend supported "enable": True, }, diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index b84a3cade..1eabdfe35 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -45,7 +45,6 @@ def __init__(self, module_name: str, enabled: bool = True): self.aggregated_stats: dict = { "total_calls": 0, "total_blocks": 0, - "sparse_blocks": 0, "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, } @@ -66,7 +65,13 @@ def collect(self, stats: dict): # Update aggregated stats self.aggregated_stats["total_calls"] += 1 self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) - self.aggregated_stats["sparse_blocks"] += stats.get("sparse_blocks", 0) + + incoming = stats["sparse_blocks"] + if "sparse_blocks" not in self.aggregated_stats: + self.aggregated_stats["sparse_blocks"] = list(incoming) + else: + for i, val in enumerate(incoming): + self.aggregated_stats["sparse_blocks"][i] += val phase = stats.get("phase", "unknown") if phase in self.aggregated_stats["phase_counts"]: @@ -91,10 +96,15 @@ def get_summary(self) -> dict: and phase distribution. """ total_blocks = self.aggregated_stats["total_blocks"] - if total_blocks > 0: - avg_sparsity = self.aggregated_stats["sparse_blocks"] / total_blocks + sparse_blocks = self.aggregated_stats.get("sparse_blocks") + if sparse_blocks is not None: + avg_sparsity = ( + [sb / total_blocks for sb in sparse_blocks] + if total_blocks > 0 + else [0.0] * len(sparse_blocks) + ) else: - avg_sparsity = 0.0 + avg_sparsity = [] return { "module": self.module_name, @@ -122,7 +132,6 @@ def reset(self): self.aggregated_stats = { "total_calls": 0, "total_blocks": 0, - "sparse_blocks": 0, "phase_counts": {"prefill": 0, "decode": 0, "unknown": 0}, } self.per_sample_stats = [] diff --git a/tests/_test_utils/torch/sparsity/sparse_attention_common.py b/tests/_test_utils/torch/sparsity/sparse_attention_common.py index 6e9ae5014..58d711aff 100644 --- a/tests/_test_utils/torch/sparsity/sparse_attention_common.py +++ b/tests/_test_utils/torch/sparsity/sparse_attention_common.py @@ -95,7 +95,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-4, "decode": 1e-4}, + "thresholds": {"prefill": [1e-4], "decode": [1e-4]}, "br": 128, "bc": 128, "enable": True, @@ -107,7 +107,7 @@ def get_input(cls, d_model=128, seq_len=10, batch_size=2): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "thresholds": {"prefill": [1e-3], "decode": [1e-5]}, "br": 128, "bc": 128, "enable": True, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py index 97296971d..39f983766 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py @@ -115,7 +115,7 @@ def test_calibration_simple_model(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 64, "bc": 64, "backend": "pytorch", @@ -157,7 +157,7 @@ def test_calibration_pytorch_backend(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "backend": "pytorch", "enable": True, "calibration": { @@ -189,7 +189,7 @@ def test_simplified_calibration(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -216,7 +216,7 @@ def test_calibration_persistence(self, simple_model): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -263,7 +263,7 @@ def test_calibrated_model_inference(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "backend": "pytorch", "enable": True, "calibration": { @@ -299,7 +299,7 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, @@ -315,7 +315,7 @@ def test_calibrated_vs_fixed_threshold(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, } }, @@ -358,7 +358,7 @@ def test_memory_usage(self, simple_model_setup): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, "calibration": { "target_sparse_ratio": {"prefill": 0.5, "decode": 0.0}, diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py index df4cfaa65..64177ad89 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py @@ -66,7 +66,7 @@ def test_load_and_sparsify(self, tinyllama_model): sparse_cfg={ "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -94,7 +94,7 @@ def test_forward_prefill(self, tinyllama_model, tinyllama_tokenizer): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "backend": "pytorch", "enable": True, } @@ -124,7 +124,10 @@ def test_forward_decode(self, tinyllama_model): config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": {"prefill": 1e-3, "decode": 1e-5}, # More conservative for decode + "thresholds": { + "prefill": [1e-3], + "decode": [1e-5], + }, # More conservative for decode "backend": "pytorch", "enable": True, } @@ -163,7 +166,7 @@ def test_gqa_attention(self, tinyllama_model): sparse_config = SparseAttentionConfig( sparse_cfg={ "*attn*": { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "backend": "pytorch", "enable": True, } diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py index ce2fa3da2..5c8b7d984 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_flash_skip_softmax.py @@ -30,7 +30,7 @@ def test_phase_inference(self): """Test phase detection from attention score shape.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -50,7 +50,7 @@ def test_threshold_update_dict_config(self): """Test threshold updates with dict config.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "thresholds": {"prefill": [1e-3], "decode": [1e-5]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -58,23 +58,23 @@ def test_threshold_update_dict_config(self): } ) - # Initially uses prefill threshold - initial_threshold = method.threshold + # Initially uses prefill thresholds + initial_thresholds = method.thresholds # Update to decode - method._update_threshold("decode") - assert method.threshold == 1e-5 - assert method.threshold != initial_threshold + method._update_thresholds("decode") + assert method.thresholds == [1e-5] + assert method.thresholds != initial_thresholds # Update back to prefill - method._update_threshold("prefill") - assert method.threshold == 1e-3 + method._update_thresholds("prefill") + assert method.thresholds == [1e-3] def test_block_reshaping_divisible(self): """Test block reshaping with divisible sequence lengths.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -97,7 +97,7 @@ def test_block_reshaping_with_padding(self): """Test block reshaping with non-divisible lengths.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -120,7 +120,7 @@ def test_correction_factor_calculation_prefill(self): """Test correction factor for prefill phase.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -140,14 +140,15 @@ def test_correction_factor_calculation_prefill(self): assert "total_blocks" in stats assert stats["phase"] == "prefill" assert 0 <= stats["correction_factor"] <= 1 - # Sparsity can be negative if threshold is too low (more blocks kept than expected) - assert -1 <= stats["sparsity"] <= 1 + # sparsity is now a list (one entry per threshold) + assert isinstance(stats["sparsity"], list) + assert all(-1 <= s <= 1 for s in stats["sparsity"]) def test_correction_factor_calculation_decode(self): """Test correction factor for decode phase.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-5}, + "thresholds": {"prefill": [1e-3], "decode": [1e-5]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -163,14 +164,15 @@ def test_correction_factor_calculation_decode(self): # Verify stats structure assert stats["phase"] == "decode" assert "correction_factor" in stats - assert 0 <= stats["sparsity"] <= 1 + assert isinstance(stats["sparsity"], list) + assert all(0 <= s <= 1 for s in stats["sparsity"]) assert mask.shape == (1, 1, 1, 256) def test_block_mask_correctness(self): """Test block mask shape and type.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -189,7 +191,7 @@ def test_block_mask_correctness(self): def test_causal_vs_noncausal(self): """Test total_blocks calculation for causal vs non-causal.""" config_base = { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -212,7 +214,7 @@ def test_calculate_sparsity_assertions(self): """Test calculate_sparsity input validation.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -228,7 +230,7 @@ def test_apply_sparsity_with_mask(self): """Test apply_sparsity with pre-computed mask.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -255,7 +257,7 @@ def test_apply_sparsity_without_mask(self): """Test apply_sparsity calculates mask internally when None.""" method = FlashSkipSoftmax( { - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 128, "bc": 128, "backend": "pytorch", diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py index b91ec40cf..0785cdf22 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py @@ -194,7 +194,7 @@ def test_calibration_disabled(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 64, "bc": 64, "enable": True, @@ -228,7 +228,7 @@ def test_sparsify_with_calibration_requires_forward_loop(self): }, "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "br": 64, "bc": 64, "enable": True, @@ -331,7 +331,7 @@ def test_calibrate_empty_stats(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.1, "decode": 0.1}, + "thresholds": {"prefill": [0.1], "decode": [0.1]}, "br": 64, "bc": 64, "enable": True, @@ -365,7 +365,7 @@ def test_calibrate_no_config(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.1, "decode": 0.1}, + "thresholds": {"prefill": [0.1], "decode": [0.1]}, "br": 64, "bc": 64, "enable": True, @@ -408,7 +408,7 @@ def test_extract_calibration_config_none(self): "sparse_cfg": { "*attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.1, "decode": 0.1}, + "thresholds": {"prefill": [0.1], "decode": [0.1]}, } }, } diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py index ddbb718f4..a53f037d8 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py @@ -30,13 +30,13 @@ def test_valid_config(self): """Test creating valid config.""" config = SparseAttentionAttributeConfig( method="flash_skip_softmax", - threshold={"prefill": 1e-4, "decode": 1e-4}, + thresholds={"prefill": [1e-4], "decode": [1e-4]}, br=128, bc=128, enable=True, ) assert config.method == "flash_skip_softmax" - assert config.threshold == {"prefill": 1e-4, "decode": 1e-4} + assert config.thresholds == {"prefill": [1e-4], "decode": [1e-4]} assert config.br == 128 assert config.bc == 128 @@ -63,44 +63,44 @@ def test_threshold_validation_range(self): """Test threshold dict values must be in range (0, 1).""" # Zero value with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": 0, "decode": 1e-4}) + SparseAttentionAttributeConfig(thresholds={"prefill": [0], "decode": [1e-4]}) # Negative value with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": -0.1, "decode": 1e-4}) + SparseAttentionAttributeConfig(thresholds={"prefill": [-0.1], "decode": [1e-4]}) # Value equals 1.0 with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": 1.0, "decode": 1e-4}) + SparseAttentionAttributeConfig(thresholds={"prefill": [1.0], "decode": [1e-4]}) # Value greater than 1.0 with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": 1.5, "decode": 1e-4}) + SparseAttentionAttributeConfig(thresholds={"prefill": [1.5], "decode": [1e-4]}) def test_threshold_validation_dict(self): """Test threshold dict validation.""" # Valid phase-aware threshold - config = SparseAttentionAttributeConfig(threshold={"prefill": 1e-3, "decode": 1e-5}) - assert config.threshold == {"prefill": 1e-3, "decode": 1e-5} + config = SparseAttentionAttributeConfig(thresholds={"prefill": [1e-3], "decode": [1e-5]}) + assert config.thresholds == {"prefill": [1e-3], "decode": [1e-5]} # Invalid phase key with pytest.raises(ValidationError, match="Invalid threshold phases"): - SparseAttentionAttributeConfig(threshold={"invalid_phase": 1e-3}) + SparseAttentionAttributeConfig(thresholds={"invalid_phase": [1e-3]}) # Invalid threshold value in dict (negative) with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": -1e-3}) + SparseAttentionAttributeConfig(thresholds={"prefill": [-1e-3]}) # Invalid threshold value in dict (>= 1.0) with pytest.raises(ValidationError, match="must be in range"): - SparseAttentionAttributeConfig(threshold={"prefill": 1.0}) + SparseAttentionAttributeConfig(thresholds={"prefill": [1.0]}) def test_threshold_validation_type(self): - """Test threshold must be a dict (not single value or string).""" + """Test thresholds must be a dict (not single value or string).""" # Single float value not allowed with pytest.raises(ValidationError, match="Input should be a valid dictionary"): - SparseAttentionAttributeConfig(threshold=1e-4) + SparseAttentionAttributeConfig(thresholds=1e-4) # String not allowed with pytest.raises(ValidationError, match="Input should be a valid dictionary"): - SparseAttentionAttributeConfig(threshold="invalid") + SparseAttentionAttributeConfig(thresholds="invalid") diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 9a9544419..eab020022 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -69,7 +69,7 @@ def test_pattern_based_replacement(self): "sparse_cfg": { "*self_attn*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-4, "decode": 1e-4}, + "thresholds": {"prefill": [1e-4], "decode": [1e-4]}, "br": 128, "bc": 128, "enable": True, @@ -100,7 +100,7 @@ def filter_func(name): "sparse_cfg": { filter_func: { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, }, }, @@ -118,7 +118,7 @@ def test_no_matching_modules(self): "sparse_cfg": { "*nonexistent*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 1e-3, "decode": 1e-4}, + "thresholds": {"prefill": [1e-3], "decode": [1e-4]}, "enable": True, }, }, @@ -192,7 +192,7 @@ def test_get_stats_with_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 64, "bc": 64, "collect_stats": True, # Enable stats collection @@ -228,7 +228,7 @@ def test_get_stats_without_stats_manager(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 64, "bc": 64, "collect_stats": False, # Disable stats collection diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py index 2a390ab1f..318e6a4b9 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_stats_manager.py @@ -34,7 +34,7 @@ def test_initialization_defaults(self): assert manager.calibration_mode is False assert manager.aggregated_stats["total_calls"] == 0 assert manager.aggregated_stats["total_blocks"] == 0 - assert manager.aggregated_stats["sparse_blocks"] == 0 + assert "sparse_blocks" not in manager.aggregated_stats assert manager.per_sample_stats == [] def test_initialization_disabled(self): @@ -56,7 +56,7 @@ def test_collect_stats_enabled(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], "sample_length": 1024, } @@ -64,7 +64,7 @@ def test_collect_stats_enabled(self): assert manager.aggregated_stats["total_calls"] == 1 assert manager.aggregated_stats["total_blocks"] == 100 - assert manager.aggregated_stats["sparse_blocks"] == 50 + assert manager.aggregated_stats["sparse_blocks"] == [50] assert manager.aggregated_stats["phase_counts"]["prefill"] == 1 assert manager.aggregated_stats["phase_counts"]["decode"] == 0 @@ -76,7 +76,7 @@ def test_collect_stats_disabled(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], } manager.collect(stats) @@ -84,7 +84,7 @@ def test_collect_stats_disabled(self): # Should remain at initial values assert manager.aggregated_stats["total_calls"] == 0 assert manager.aggregated_stats["total_blocks"] == 0 - assert manager.aggregated_stats["sparse_blocks"] == 0 + assert "sparse_blocks" not in manager.aggregated_stats def test_collect_multiple_calls(self): """Test accumulation over multiple collect calls.""" @@ -96,13 +96,13 @@ def test_collect_multiple_calls(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], } manager.collect(stats) assert manager.aggregated_stats["total_calls"] == 5 assert manager.aggregated_stats["total_blocks"] == 500 - assert manager.aggregated_stats["sparse_blocks"] == 250 + assert manager.aggregated_stats["sparse_blocks"] == [250] assert manager.aggregated_stats["phase_counts"]["prefill"] == 5 def test_collect_different_phases(self): @@ -110,11 +110,11 @@ def test_collect_different_phases(self): manager = SparseAttentionStatsManager(module_name="test", enabled=True) # Collect prefill stats - manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) - manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": [50]}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": [50]}) # Collect decode stats - manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": 5}) + manager.collect({"phase": "decode", "total_blocks": 10, "sparse_blocks": [5]}) assert manager.aggregated_stats["phase_counts"]["prefill"] == 2 assert manager.aggregated_stats["phase_counts"]["decode"] == 1 @@ -135,7 +135,7 @@ def test_calibration_mode_per_sample_collection(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], "sample_length": 1024, } @@ -153,7 +153,7 @@ def test_calibration_mode_off(self): manager = SparseAttentionStatsManager(module_name="test", enabled=True) # Calibration mode is off by default - stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": 50} + stats = {"sparsity": 0.5, "phase": "prefill", "total_blocks": 100, "sparse_blocks": [50]} manager.collect(stats) @@ -174,7 +174,7 @@ def test_set_calibration_mode_with_reset(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], "sample_length": 1024, } ) @@ -195,7 +195,7 @@ def test_set_calibration_mode_without_reset(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], "sample_length": 1024, } ) @@ -214,15 +214,15 @@ def test_get_summary_with_data(self): manager = SparseAttentionStatsManager(module_name="test_module", enabled=True) # Collect stats - manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 30}) - manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": 50}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": [30]}) + manager.collect({"phase": "prefill", "total_blocks": 100, "sparse_blocks": [50]}) summary = manager.get_summary() assert summary["module"] == "test_module" assert summary["total_calls"] == 2 # Average sparsity: (30+50) / (100+100) = 80/200 = 0.4 - assert summary["average_sparsity"] == 0.4 + assert summary["average_sparsity"] == [0.4] assert summary["phase_distribution"]["prefill"] == 2 def test_get_summary_zero_blocks(self): @@ -230,11 +230,11 @@ def test_get_summary_zero_blocks(self): manager = SparseAttentionStatsManager(module_name="test", enabled=True) # Collect stats with zero blocks - manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": 0}) + manager.collect({"phase": "prefill", "total_blocks": 0, "sparse_blocks": [0]}) summary = manager.get_summary() - assert summary["average_sparsity"] == 0.0 # Should handle division by zero + assert summary["average_sparsity"] == [0.0] # Should handle division by zero class TestGetCalibrationStats: @@ -252,7 +252,7 @@ def test_get_calibration_stats(self): "sparsity": 0.3 + i * 0.1, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 30, + "sparse_blocks": [30], "sample_length": 1024 + i * 512, } ) @@ -287,7 +287,7 @@ def test_reset(self): "sparsity": 0.5, "phase": "prefill", "total_blocks": 100, - "sparse_blocks": 50, + "sparse_blocks": [50], "sample_length": 1024, } ) @@ -296,7 +296,7 @@ def test_reset(self): "sparsity": 0.3, "phase": "decode", "total_blocks": 10, - "sparse_blocks": 3, + "sparse_blocks": [3], "sample_length": 128, } ) @@ -311,7 +311,7 @@ def test_reset(self): # All stats should be cleared assert manager.aggregated_stats["total_calls"] == 0 assert manager.aggregated_stats["total_blocks"] == 0 - assert manager.aggregated_stats["sparse_blocks"] == 0 + assert "sparse_blocks" not in manager.aggregated_stats assert manager.per_sample_stats == [] assert manager.aggregated_stats["phase_counts"]["prefill"] == 0 assert manager.aggregated_stats["phase_counts"]["decode"] == 0 diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py index 320196ccc..00622d365 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_threshold_info.py @@ -33,7 +33,7 @@ def test_phased_threshold(self): """Test threshold info for phase-specific static thresholds.""" method = FlashSkipSoftmax( method_config={ - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -43,17 +43,17 @@ def test_phased_threshold(self): info = method.get_threshold_info() - # Static phased thresholds are reported as type "static" with dict value + # Static phased thresholds are reported as type "static" with dict of lists assert info["type"] == "static" assert isinstance(info["value"], dict) - assert info["value"]["prefill"] == 0.001 - assert info["value"]["decode"] == 0.0001 + assert info["value"]["prefill"] == [0.001] + assert info["value"]["decode"] == [0.0001] def test_dynamic_calibrated_threshold(self): """Test threshold info for calibrated dynamic threshold.""" method = FlashSkipSoftmax( method_config={ - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 128, "bc": 128, "backend": "pytorch", @@ -94,7 +94,7 @@ def test_module_delegates_to_method(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.005, "decode": 0.001}, + "thresholds": {"prefill": [0.005], "decode": [0.001]}, "br": 64, "bc": 64, "enable": True, @@ -117,8 +117,8 @@ def test_module_delegates_to_method(self): info = sparse_module.get_threshold_info() assert info["type"] == "static" - assert info["value"]["prefill"] == 0.005 - assert info["value"]["decode"] == 0.001 + assert info["value"]["prefill"] == [0.005] + assert info["value"]["decode"] == [0.001] def test_module_with_calibrated_threshold(self): """Test module reports calibrated threshold correctly.""" @@ -128,7 +128,7 @@ def test_module_with_calibrated_threshold(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 64, "bc": 64, "enable": True, @@ -167,7 +167,7 @@ def test_module_without_method_instance(self): "sparse_cfg": { "*attention*": { "method": "flash_skip_softmax", - "threshold": {"prefill": 0.001, "decode": 0.0001}, + "thresholds": {"prefill": [0.001], "decode": [0.0001]}, "br": 64, "bc": 64, "enable": True,