Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,15 @@ def main(args):
model = AutoModelForCausalLM.from_pretrained(
args.pyt_ckpt_path,
attn_implementation="eager",
- torch_dtype=torch.bfloat16,
+ torch_dtype="auto",
+ device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

# Set pad token if not set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

- # Move model to GPU if available
- if torch.cuda.is_available():
- model = model.cuda()
- print("Model moved to CUDA")
-
# Generate sample output BEFORE sparse attention
print("\nGenerating sample output before sparse attention...")
output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import torch
import torch.nn as nn
from scipy.optimize import curve_fit
- from tqdm import tqdm

from ..stats_manager import SparseAttentionStatsManager
from ..utils import get_sparse_attention_modules
Expand Down Expand Up @@ -91,9 +90,9 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
"""Calibrate a and b parameters for Exponential model.

Algorithm:
- 1. For each threshold λ_j in threshold_trials:
- - Run ALL samples, collect sparsities S_ij for each sample i
- - Compute scale_factor_ij = λ_j × L_i (where L_i is sample length)
+ 1. Set thresholds = threshold_trials on all modules, run ONE forward pass.
+ Each module returns a sparsity list (one entry per threshold) per sample.
+ Unpack to get (scale_factor_ij = λ_j × L_i, sparsity_ij) pairs.

2. Fit Exponential model to ALL (sf_ij, S_ij) pairs:
scale_factor = a * exp(b * sparsity)
Expand Down Expand Up @@ -121,29 +120,25 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
print(f"Starting Exponential model calibration ({phase} phase)")
print(f"Threshold trials: {len(self.threshold_trials)}")

- # Stage 1: Collect ALL (scale_factor, sparsity) pairs for all thresholds and samples
- print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds...")
+ # Stage 1: Collect ALL (scale_factor, sparsity) pairs in a single forward pass.
+ # All threshold_trials are passed at once; each module returns a sparsity list
+ # with one entry per threshold, eliminating the need for repeated forward passes.
+ print(f"\nStage 1: Collecting {phase} sparsity data for all thresholds in one pass...")

# Collect ALL individual data points (not averaged)
all_data_points = [] # List of {"threshold", "length", "scale_factor", "sparsity"}

- for threshold in tqdm(self.threshold_trials, desc=f"Testing thresholds ({phase})"):
- self._set_threshold(attention_modules, threshold)
- self._enable_calibration_mode(attention_modules)
- with torch.no_grad():
- forward_loop(model)
- per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
- self._disable_calibration_mode(attention_modules)
-
- if not per_sample_stats:
- continue
-
- # Collect individual (scale_factor, sparsity) pairs for each sample
- for sample_stat in per_sample_stats:
- length = sample_stat["sample_length"]
- sparsity = sample_stat["sparsity"]
+ self._set_thresholds(attention_modules, self.threshold_trials)
+ self._enable_calibration_mode(attention_modules)
+ with torch.no_grad():
+ forward_loop(model)
+ per_sample_stats = self._extract_calibration_stats(attention_modules, phase=phase)
+ self._disable_calibration_mode(attention_modules)
+
+ for sample_stat in per_sample_stats:
+ length = sample_stat["sample_length"]
+ sparsity_list = sample_stat["sparsity"]
+ for threshold, sparsity in zip(self.threshold_trials, sparsity_list):
scale_factor = threshold * length

all_data_points.append(
{
"threshold": threshold,
Expand Down Expand Up @@ -307,17 +302,26 @@ def _extract_calibration_stats(
aggregated_stats = []

for sample_idx in range(num_samples):
- sparsities = []
+ sparsity_lists = []
sample_length = 0

for module_stats in all_per_sample_stats:
if sample_idx < len(module_stats):
sample_stat = module_stats[sample_idx]
- sparsities.append(sample_stat.get("sparsity", 0.0))
+ sparsity = sample_stat.get("sparsity", [])
+ sparsity_lists.append(sparsity if isinstance(sparsity, list) else [sparsity])
if not sample_length and "sample_length" in sample_stat:
sample_length = sample_stat["sample_length"]

- avg_sparsity = float(np.mean(sparsities)) if sparsities else 0.0
+ if not sparsity_lists:
+ continue
+
+ lengths = [len(s) for s in sparsity_lists]
+ assert len(set(lengths)) == 1, (
+ f"All modules must have the same number of thresholds, got {lengths}"
+ )
+ n = lengths[0]
+ avg_sparsity = [float(np.mean([sl[i] for sl in sparsity_lists])) for i in range(n)]

aggregated_stats.append(
{
Expand All @@ -328,7 +332,7 @@ def _extract_calibration_stats(

return aggregated_stats

- def _set_threshold(self, modules: list[nn.Module], threshold: float):
- """Set threshold on sparse attention modules."""
+ def _set_thresholds(self, modules: list[nn.Module], thresholds: list[float]):
+ """Set thresholds list on sparse attention modules."""
for module in modules:
- module._sparse_method_instance.threshold = threshold
+ module._sparse_method_instance.thresholds = thresholds
47 changes: 31 additions & 16 deletions modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
description="If True, enables sparse attention. If False, bypasses sparsity.",
)

- threshold: dict[str, float] = ModeloptField(
- default={"prefill": 1e-3, "decode": 1e-4},
- title="Sparsity threshold.",
+ thresholds: dict[str, list[float]] = ModeloptField(
+ default={"prefill": [1e-3], "decode": [1e-4]},
+ title="Sparsity thresholds.",
description=(
- "Threshold for determining which attention values to skip. "
- "Must be a dict with 'prefill' and 'decode' keys."
+ "Thresholds for determining which attention values to skip. "
+ "Must be a dict with 'prefill' and/or 'decode' keys, each mapping to a list of floats. "
+ "Prefill and decode lists must have the same length. "
+ "Sparsity is computed per threshold; the first threshold's mask is applied."
),
)

Expand Down Expand Up @@ -120,10 +122,10 @@ def validate_block_size(cls, v):
raise ValueError(f"Block size must be positive, got {v}")
return v

- @field_validator("threshold")
+ @field_validator("thresholds")
@classmethod
- def validate_threshold(cls, v):
- """Validate threshold is a dict with valid phases and values in range (0, 1)."""
+ def validate_thresholds(cls, v):
+ """Validate thresholds is a dict of lists with valid phases and values in range (0, 1)."""
if not isinstance(v, dict):
raise ValueError(
f"Threshold must be a dict with 'prefill' and/or 'decode' keys, got {type(v).__name__}"
Expand All @@ -135,12 +137,25 @@ def validate_threshold(cls, v):
raise ValueError(
f"Invalid threshold phases: {invalid_keys}. Valid phases: {valid_phases}"
)
- # Validate all values are in range (0, 1)
- for phase, threshold in v.items():
- if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1:
+ # Validate all values are lists of floats in range (0, 1)
+ lengths = {}
+ for phase, threshold_list in v.items():
+ if not isinstance(threshold_list, list) or len(threshold_list) == 0:
raise ValueError(
- f"Threshold for phase '{phase}' must be in range (0, 1), got {threshold}"
+ f"Thresholds for phase '{phase}' must be a non-empty list, got {threshold_list}"
)
+ for threshold in threshold_list:
+ if not isinstance(threshold, (int, float)) or threshold <= 0 or threshold >= 1:
+ raise ValueError(
+ f"Each threshold for phase '{phase}' must be in range (0, 1), got {threshold}"
+ )
+ lengths[phase] = len(threshold_list)
+ # Validate prefill and decode lists have the same length
+ if len(lengths) == 2 and len(set(lengths.values())) != 1:
+ raise ValueError(
+ f"Prefill and decode threshold lists must have the same length, "
+ f"got prefill={lengths['prefill']}, decode={lengths['decode']}"
+ )
return v


Expand Down Expand Up @@ -356,7 +371,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
default={
"*attention*": {
"method": "flash_skip_softmax",
- "threshold": {"prefill": 1e-3, "decode": 1e-5},
+ "thresholds": {"prefill": [1e-3], "decode": [1e-5]},
"br": 128, # Flash Attention block rows
"bc": 128, # Flash Attention block columns
"backend": "pytorch", # Only pytorch backend supported
Expand All @@ -378,9 +393,9 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
"sparse_cfg": {
"*attn*": {
"method": "flash_skip_softmax",
- "threshold": {
- "prefill": 1e-3, # More aggressive during prefill
- "decode": 1e-4, # Conservative during decode
+ "thresholds": {
+ "prefill": [1e-3], # More aggressive during prefill
+ "decode": [1e-4], # Conservative during decode
},
"br": 128, # Flash Attention block rows
"bc": 128, # Flash Attention block columns
Expand Down
Loading
Loading