diff --git a/examples/diffusers/sparsity/ltx2_skip_softmax.py b/examples/diffusers/sparsity/ltx2_skip_softmax.py new file mode 100644 index 0000000000..dae064e070 --- /dev/null +++ b/examples/diffusers/sparsity/ltx2_skip_softmax.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LTX-2 inference with skip-softmax sparse attention. + +This example applies skip-softmax sparse attention to the LTX-2 video +generation model using exponential model calibration +(``scale_factor = a * exp(b * target_sparsity)``). + +During calibration, ``flash_skip_softmax`` with the eager attention backend +collects sparsity statistics across multiple threshold trials. The fitted +exponential model then allows runtime control of the target sparsity ratio +without recalibration. + +Only the stage-1 backbone is sparsified. Stage 2 (spatial upsampler + +distilled LoRA) runs unmodified. 
+ +Usage:: + + # With calibration (recommended) + python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ + --calibrate --target-sparsity 0.25 + + # Disable sparsity on first/last 2 layers (higher quality, less speedup) + python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ + --calibrate --target-sparsity 0.25 --skip-first-last 2 +""" + +import argparse +import functools +import os + +import torch +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_2_STAGE_HEIGHT, + DEFAULT_2_STAGE_WIDTH, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_FRAME_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_NUM_INFERENCE_STEPS, + DEFAULT_SEED, + DEFAULT_VIDEO_GUIDER_PARAMS, +) +from ltx_pipelines.utils.media_io import encode_video + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +# ---- Model paths (edit these or override via environment variables) ---- +CHECKPOINT_PATH = os.environ.get( + "LTX2_CHECKPOINT", + "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors", +) +DISTILLED_LORA_PATH = os.environ.get( + "LTX2_DISTILLED_LORA", + "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors", +) +SPATIAL_UPSAMPLER_PATH = os.environ.get( + "LTX2_SPATIAL_UPSAMPLER", + "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors", +) +GEMMA_ROOT = os.environ.get( + "LTX2_GEMMA_ROOT", + "/home/scratch.omniml_data_2/jingyux/models/LTX-2/gemma-3-12b-it-qat-q4_0-unquantized", +) + +DEFAULT_NUM_FRAMES = 121 +NUM_TRANSFORMER_BLOCKS = 48 + +# Default threshold trials for calibration +DEFAULT_THRESHOLD_TRIALS = [ + 
1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 2e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="LTX-2 video generation with skip-softmax sparse attention" + ) + parser.add_argument("--prompt", type=str, default=None, help="Text prompt for generation") + parser.add_argument( + "--prompt-dir", + type=str, + default=None, + help="Directory of .txt prompt files (one prompt per file). Overrides --prompt.", + ) + parser.add_argument("--output", type=str, default="output.mp4", help="Output video path") + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory to save videos when using --prompt-dir", + ) + parser.add_argument( + "--num-frames", type=int, default=DEFAULT_NUM_FRAMES, help="Number of frames" + ) + parser.add_argument("--height", type=int, default=DEFAULT_2_STAGE_HEIGHT, help="Video height") + parser.add_argument("--width", type=int, default=DEFAULT_2_STAGE_WIDTH, help="Video width") + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed") + + # Sparse attention options + parser.add_argument( + "--skip-first-last", + type=int, + default=0, + help="Number of first/last transformer layers to keep dense (default: 0)", + ) + + # Calibration options + parser.add_argument( + "--calibrate", + action="store_true", + help="Calibrate threshold via exponential model (recommended)", + ) + parser.add_argument( + "--target-sparsity", + type=float, + default=0.25, + help="Target sparsity ratio for calibration (0.0-1.0)", + ) + parser.add_argument( + "--calib-steps", + type=int, + default=10, + help="Inference steps per calibration sample", + ) + parser.add_argument( + "--calib-frames", + type=int, + default=81, + help="Number of frames per calibration sample", + ) + parser.add_argument( + "--calib-size", + type=int, + default=1, + help="Number of prompts to use for calibration", + ) + return 
parser.parse_args() + + +def _patch_vae_requires_grad(pipeline: TI2VidTwoStagesPipeline): + """Ensure VAE decoder weights have requires_grad=False to avoid autograd issues.""" + for ledger_attr in ("stage_1_model_ledger", "stage_2_model_ledger"): + ledger = getattr(pipeline, ledger_attr, None) + if ledger is None: + continue + for loader_name in ("video_decoder", "audio_decoder"): + orig_loader = getattr(ledger, loader_name, None) + if orig_loader is None: + continue + + def _make_patched(fn): + @functools.wraps(fn) + def patched(): + model = fn() + model.requires_grad_(False) + return model + + return patched + + setattr(ledger, loader_name, _make_patched(orig_loader)) + + +def build_pipeline() -> TI2VidTwoStagesPipeline: + """Build the LTX-2 two-stage video generation pipeline.""" + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps(DISTILLED_LORA_PATH, 0.8, LTXV_LORA_COMFY_RENAMING_MAP) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + _patch_vae_requires_grad(pipeline) + return pipeline + + +def build_sparse_config(args: argparse.Namespace) -> dict: + """Build sparse attention config from CLI args. + + Uses flash_skip_softmax which supports both calibration (eager attention + with F.softmax patching) and inference. Calibration fits an exponential + model: scale_factor = a * exp(b * sparsity). 
+ """ + attn_cfg: dict = { + "method": "flash_skip_softmax", + "thresholds": {"prefill": [1e-3]}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": False, # Diffusion = bidirectional attention + "collect_stats": True, + "enable": True, + } + + sparse_cfg: dict = { + "*.attn1": attn_cfg, # Self-attention only + # Disable on all cross-attention and cross-modal attention + "*.attn2": {"enable": False}, + "*audio_attn1*": {"enable": False}, + "*audio_attn2*": {"enable": False}, + "*audio_to_video_attn*": {"enable": False}, + "*video_to_audio_attn*": {"enable": False}, + "default": {"enable": False}, + } + + # Keep first/last N layers dense for quality + for i in range(args.skip_first_last): + sparse_cfg[f"*transformer_blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*transformer_blocks.{NUM_TRANSFORMER_BLOCKS - 1 - i}.attn*"] = { + "enable": False + } + + config: dict = {"sparse_cfg": sparse_cfg} + + # Add calibration config with threshold trials + if args.calibrate: + sparse_cfg["calibration"] = { + "target_sparse_ratio": {"prefill": args.target_sparsity}, + "samples": args.calib_size, + "threshold_trials": DEFAULT_THRESHOLD_TRIALS, + } + + return config + + +def load_calib_prompts(calib_size: int) -> list[str]: + """Load calibration prompts from OpenVid-1M dataset.""" + from datasets import load_dataset + + dataset = load_dataset("nkp37/OpenVid-1M") + prompts = list(dataset["train"]["caption"][:calib_size]) + print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M") + return prompts + + +def build_calibration_forward_loop( + pipeline: TI2VidTwoStagesPipeline, + num_steps: int = 10, + num_frames: int = 81, + calib_size: int = 1, +): + """Build a forward loop for exponential model calibration. + + Generates short videos to exercise the attention mechanism at various + threshold trials, collecting sparsity statistics for the exponential fit. 
+ """ + calib_prompts = load_calib_prompts(calib_size) + tiling_config = TilingConfig.default() + + def forward_loop(model): + for i, prompt in enumerate(calib_prompts): + print(f"Calibration [{i + 1}/{len(calib_prompts)}]: {prompt[:60]}...") + pipeline( + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=DEFAULT_SEED, + height=DEFAULT_2_STAGE_HEIGHT, + width=DEFAULT_2_STAGE_WIDTH, + num_frames=num_frames, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=num_steps, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + return forward_loop + + +def print_sparsity_summary(transformer: torch.nn.Module) -> None: + """Print per-module sparsity statistics.""" + enabled, disabled = [], [] + for name, module in transformer.named_modules(): + if isinstance(module, SparseAttentionModule): + if module.is_enabled: + enabled.append((name, module)) + else: + disabled.append(name) + + print(f"\nSparse attention: {len(enabled)} enabled, {len(disabled)} disabled") + for name, module in enabled: + info = module.get_threshold_info() + print(f" {name}: {info}") + + +def main() -> None: + args = parse_args() + + # ---- Build pipeline ---- + print("Building LTX-2 pipeline...") + pipeline = build_pipeline() + + # ---- Get and sparsify the stage-1 transformer ---- + transformer = pipeline.stage_1_model_ledger.transformer() + # Pin transformer in memory so pipeline reuses the sparsified version + pipeline.stage_1_model_ledger.transformer = lambda: transformer + + config = build_sparse_config(args) + forward_loop = None + if args.calibrate: + forward_loop = build_calibration_forward_loop( + pipeline, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + calib_size=args.calib_size, + ) + + print("Applying skip-softmax sparse attention...") + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + # ---- Build prompt list ---- + prompts_and_outputs: 
list[tuple[str, str]] = [] + if args.prompt_dir: + output_dir = args.output_dir or "output_videos" + os.makedirs(output_dir, exist_ok=True) + prompt_files = sorted(f for f in os.listdir(args.prompt_dir) if f.endswith(".txt")) + for pf in prompt_files: + with open(os.path.join(args.prompt_dir, pf)) as f: + prompt = f.read().strip() + stem = os.path.splitext(pf)[0] + prompts_and_outputs.append((prompt, os.path.join(output_dir, f"{stem}.mp4"))) + elif args.prompt: + prompts_and_outputs.append((args.prompt, args.output)) + else: + raise ValueError("Either --prompt or --prompt-dir must be provided") + + # ---- Generate ---- + tiling_config = TilingConfig.default() + for i, (prompt, output_path) in enumerate(prompts_and_outputs): + print(f"\nGenerating [{i + 1}/{len(prompts_and_outputs)}]: {prompt[:80]}...") + + video, audio = pipeline( + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=output_path, + video_chunks_number=get_video_chunks_number(args.num_frames, tiling_config), + ) + print(f"Saved to {output_path}") + + # ---- Print stats ---- + print_sparsity_summary(transformer) + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py new file mode 100644 index 0000000000..ac2031a6d7 --- /dev/null +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wan 2.2 inference with skip-softmax sparse attention. + +This example applies skip-softmax sparse attention to the Wan 2.2 video +generation model (text-to-video) using exponential model calibration +(``scale_factor = a * exp(b * target_sparsity)``). + +During calibration, ``flash_skip_softmax`` with the eager attention backend +collects sparsity statistics across multiple threshold trials. The fitted +exponential model then allows runtime control of the target sparsity ratio +without recalibration. + +The Wan 2.2 5B model has 40 transformer blocks with self-attention (attn1) +and cross-attention (attn2). Only self-attention is sparsified. 
+ +Usage:: + + # With calibration (recommended) + python wan22_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ + --calibrate --target-sparsity 0.25 + + # Custom model path + python wan22_skip_softmax.py --model-path /path/to/Wan2.2-T2V-5B \\ + --prompt "A sunset over mountains" --output sunset.mp4 --calibrate +""" + +import argparse +import os + +import torch +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.utils import export_to_video + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-T2V-5B") +NUM_TRANSFORMER_BLOCKS = 40 + +# Default threshold trials for calibration +DEFAULT_THRESHOLD_TRIALS = [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 2e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Wan 2.2 video generation with skip-softmax sparse attention" + ) + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation") + parser.add_argument("--output", type=str, default="output.mp4", help="Output video path") + parser.add_argument( + "--model-path", type=str, default=DEFAULT_MODEL_PATH, help="Wan 2.2 model path or HF ID" + ) + parser.add_argument( + "--num-frames", type=int, default=81, help="Number of frames (must be 4k+1)" + ) + parser.add_argument("--height", type=int, default=480, help="Video height") + parser.add_argument("--width", type=int, default=832, help="Video width") + parser.add_argument("--num-steps", type=int, default=50, help="Number of inference steps") + parser.add_argument( + "--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale" + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + # Sparse attention options + 
parser.add_argument( + "--skip-first-last", + type=int, + default=0, + help="Number of first/last transformer layers to keep dense (default: 0)", + ) + + # Calibration options + parser.add_argument( + "--calibrate", + action="store_true", + help="Calibrate threshold via exponential model (recommended)", + ) + parser.add_argument( + "--target-sparsity", + type=float, + default=0.25, + help="Target sparsity ratio for calibration (0.0-1.0)", + ) + parser.add_argument( + "--calib-steps", + type=int, + default=10, + help="Inference steps for calibration", + ) + parser.add_argument( + "--calib-frames", + type=int, + default=33, + help="Number of frames for calibration (fewer = faster)", + ) + return parser.parse_args() + + +def build_pipeline(model_path: str) -> WanPipeline: + """Build the Wan 2.2 text-to-video pipeline.""" + vae = AutoencoderKLWan.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_path, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + return pipe + + +def build_sparse_config(args: argparse.Namespace) -> dict: + """Build sparse attention config from CLI args. + + Uses flash_skip_softmax which supports both calibration (eager attention + with F.softmax patching) and inference. Calibration fits an exponential + model: scale_factor = a * exp(b * sparsity). 
+ """ + attn_cfg: dict = { + "method": "flash_skip_softmax", + "thresholds": {"prefill": [1e-3]}, + "br": 128, + "bc": 128, + "backend": "pytorch", + "is_causal": False, # Diffusion = bidirectional attention + "collect_stats": True, + "enable": True, + } + + sparse_cfg: dict = { + "*.attn1*": attn_cfg, # Self-attention only + "*.attn2*": {"enable": False}, # Text cross-attention + "default": {"enable": False}, + } + + # Keep first/last N layers dense for quality + for i in range(args.skip_first_last): + sparse_cfg[f"*blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*blocks.{NUM_TRANSFORMER_BLOCKS - 1 - i}.attn*"] = {"enable": False} + + config: dict = {"sparse_cfg": sparse_cfg} + + # Add calibration config with threshold trials + if args.calibrate: + sparse_cfg["calibration"] = { + "target_sparse_ratio": {"prefill": args.target_sparsity}, + "samples": 1, + "threshold_trials": DEFAULT_THRESHOLD_TRIALS, + } + + return config + + +def build_calibration_forward_loop( + pipe: WanPipeline, + prompt: str, + num_steps: int = 10, + num_frames: int = 33, + height: int = 480, + width: int = 832, + seed: int = 42, +): + """Build a forward loop for exponential model calibration.""" + + def forward_loop(model): + print(f"Calibration: generating {num_frames} frames @ {height}x{width}...") + pipe( + prompt=prompt, + num_frames=num_frames, + height=height, + width=width, + num_inference_steps=num_steps, + guidance_scale=5.0, + generator=torch.Generator(device="cuda").manual_seed(seed), + ) + + return forward_loop + + +def print_sparsity_summary(model: torch.nn.Module) -> None: + """Print per-module sparsity statistics.""" + enabled, disabled = [], [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + if module.is_enabled: + enabled.append((name, module)) + else: + disabled.append(name) + + print(f"\nSparse attention: {len(enabled)} enabled, {len(disabled)} disabled") + for name, module in enabled: + info = module.get_threshold_info() 
+ print(f" {name}: {info}") + + +def main() -> None: + args = parse_args() + + # ---- Build pipeline ---- + print(f"Loading Wan 2.2 from {args.model_path}...") + pipe = build_pipeline(args.model_path) + + # ---- Get and sparsify the transformer ---- + transformer = pipe.transformer + + config = build_sparse_config(args) + forward_loop = None + if args.calibrate: + forward_loop = build_calibration_forward_loop( + pipe, + prompt=args.prompt, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + height=args.height, + width=args.width, + seed=args.seed, + ) + + print("Applying skip-softmax sparse attention...") + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + # ---- Generate ---- + print(f"Generating: {args.prompt[:80]}...") + output = pipe( + prompt=args.prompt, + num_frames=args.num_frames, + height=args.height, + width=args.width, + num_inference_steps=args.num_steps, + guidance_scale=args.guidance_scale, + generator=torch.Generator(device="cuda").manual_seed(args.seed), + ) + + export_to_video(output.frames[0], args.output, fps=16) + print(f"Saved to {args.output}") + + # ---- Print stats ---- + print_sparsity_summary(transformer) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index dbc4d5bc27..da64e87d64 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -255,11 +255,14 @@ def calibrate_sparse_attention( print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") - # Extract tokenizer and build calibration data if needed - tokenizer = _extract_tokenizer_from_model(model) + # Extract tokenizer and build calibration data only if no forward_loop is provided. + # When the user supplies their own forward_loop (e.g. 
for diffusion models), + # RULER dataset generation is skipped entirely. + tokenizer = None calibration_data = None - if calibrate_prefill or calibrate_decode: + if forward_loop is None and (calibrate_prefill or calibrate_decode): + tokenizer = _extract_tokenizer_from_model(model) builder = RulerDatasetBuilder( samples=calib_config.samples, max_seqlen=calib_config.max_seqlen, @@ -280,11 +283,15 @@ def calibrate_sparse_attention( print("PREFILL PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: + if forward_loop is None and calibration_data is None: raise RuntimeError("calibration_data must be built before prefill") - prefill_forward_loop = forward_loop or create_calibration_forward_loop( - calibration_data, tokenizer, chunk_size=calib_config.chunk_size - ) + if forward_loop is not None: + prefill_forward_loop = forward_loop + else: + assert calibration_data is not None and tokenizer is not None + prefill_forward_loop = create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) prefill_calibrator = DynamicThresholdCalibrator( threshold_trials=calib_config.threshold_trials, @@ -302,8 +309,8 @@ def calibrate_sparse_attention( print("DECODE PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: - raise RuntimeError("calibration_data must be built before decode") + if calibration_data is None or tokenizer is None: + raise RuntimeError("calibration_data and tokenizer must be built before decode") decode_forward_loop = create_decode_calibration_forward_loop( calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens ) diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index cdc2aed948..c8a8aea605 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -101,6 +101,45 @@ def is_attn_sparsified(model: nn.Module) -> bool: return 
any(isinstance(module, SparseAttentionModule) for module in model.modules()) +def _register_diffusers_backends_if_needed(model: nn.Module) -> None: + """Register diffusers/LTX attention backends if the model needs them. + + Called before plugin registration so that the backends are available + when ``SparseAttentionModule.forward()`` activates the skip-softmax context. + """ + # Register the diffusers eager and Triton backends if the model is a diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + from .kernels import ( + register_diffusers_eager_attention, + register_diffusers_triton_attention, + ) + + if register_diffusers_eager_attention is not None: + register_diffusers_eager_attention() + if register_diffusers_triton_attention is not None: + register_diffusers_triton_attention() + except (ImportError, Exception): + pass + + # Patch ltx_core Attention modules if present (independent of diffusers) + import contextlib + + try: + from .kernels import register_ltx_eager_attention, register_ltx_triton_attention + except (ImportError, RuntimeError): + return + + if register_ltx_eager_attention is not None: + with contextlib.suppress(Exception): + register_ltx_eager_attention(model) + if register_ltx_triton_attention is not None: + with contextlib.suppress(Exception): + register_ltx_triton_attention(model) + + def convert_to_sparse_attention_model( model: ModelLikeModule, config: SparseAttentionConfig ) -> ConvertReturnType: @@ -116,6 +155,9 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + # Register diffusers backends for diffusion models + _register_diffusers_backends_if_needed(model) + # Set the correct attn_implementation for the chosen backend _set_attn_implementation(model, config) @@ -346,32 +388,46 @@ def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | 
None: if calibration_params is None: return None - # Build threshold_scale_factor with model parameters - threshold_scale_factor: dict[str, Any] = { - "formula": "a * exp(b * target_sparsity)", - } - for phase in ["prefill", "decode"]: - if phase in calibration_params: - threshold_scale_factor[phase] = { - "a": calibration_params[phase]["a"], - "b": calibration_params[phase]["b"], - } + # Detect calibration type from params + sample_params = next(iter(calibration_params.values())) + is_percentile = "threshold" in sample_params # Build the export config export_config: dict[str, Any] = { "config_groups": { "group_0": { - "sparse_algo": "softmax_skip", + "sparse_algo": "softmax_skip_diffusion" if is_percentile else "softmax_skip", "targets": sorted(target_classes) if target_classes else ["Attention"], } }, - "threshold_scale_factor": threshold_scale_factor, "producer": { "name": "modelopt", "version": mo_version, }, } + if is_percentile: + threshold_config: dict[str, Any] = { + "formula": "skip if gap >= threshold * log(seq_k)", + } + for phase in ["prefill", "decode"]: + if phase in calibration_params: + threshold_config[phase] = { + "threshold": calibration_params[phase]["threshold"], + } + export_config["threshold_config"] = threshold_config + else: + threshold_scale_factor: dict[str, Any] = { + "formula": "a * exp(b * target_sparsity)", + } + for phase in ["prefill", "decode"]: + if phase in calibration_params: + threshold_scale_factor[phase] = { + "a": calibration_params[phase]["a"], + "b": calibration_params[phase]["b"], + } + export_config["threshold_scale_factor"] = threshold_scale_factor + return export_config @@ -443,6 +499,16 @@ def _format_threshold(info: dict) -> str: s = target.get(phase, 0.5) parts.append(f"{phase}: a={a:.4f}, b={b:.2f}, target={s:.0%}") return f"calibrated({', '.join(parts)})" + if t == "dynamic_calibrated_percentile": + params = info.get("calibration_params", {}) + target = info.get("target_sparse_ratio", {}) + parts = [] + for phase 
in ["prefill", "decode"]: + if phase in params and "threshold" in params[phase]: + th = params[phase]["threshold"] + s = target.get(phase, 0.2) + parts.append(f"{phase}: threshold={th:.4f}, target={s:.0%}") + return f"percentile({', '.join(parts)})" if t == "static": v = info.get("value") if isinstance(v, dict): @@ -470,6 +536,8 @@ def print_sparse_attention_summary(model: nn.Module): # Group by (method, threshold) groups: dict[tuple[str, str], int] = {} for _, module in sparse_modules: + if not module.is_enabled: + continue method = getattr(module, "_method", "unknown") threshold = _format_threshold(module.get_threshold_info()) groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py index dee1bc472a..81f4295bb4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py @@ -13,12 +13,61 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-exports from modelopt.torch.kernels for backward compatibility.""" +"""Kernel integrations for sparse attention: Triton FA and diffusers backends.""" +import contextlib +import threading + +# --------------------------------------------------------------------------- +# Triton FA kernel re-exports (for HuggingFace LLM integration) +# --------------------------------------------------------------------------- from modelopt.torch.kernels import IS_AVAILABLE, attention, register_triton_attention +# --------------------------------------------------------------------------- +# Thread-local context: shared by diffusers eager and Triton backends +# --------------------------------------------------------------------------- +_thread_local = threading.local() + + +def set_skip_softmax_context(active: bool) -> None: + """Set thread-local flag indicating skip-softmax eager attention is active.""" + _thread_local.skip_softmax_active = active + + +def get_skip_softmax_context() -> bool: + """Return True if skip-softmax eager attention is active in this thread.""" + return getattr(_thread_local, "skip_softmax_active", False) + + +# --------------------------------------------------------------------------- +# Optional backend registrations (depend on diffusers / ltx_core) +# --------------------------------------------------------------------------- +register_diffusers_eager_attention = None +register_diffusers_triton_attention = None +register_ltx_eager_attention = None +register_ltx_triton_attention = None + +# Suppress ImportError (missing package) and RuntimeError (triton without GPU driver) +with contextlib.suppress(ImportError, RuntimeError): + from .diffusers_eager_attention import register_diffusers_eager_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .diffusers_triton_attention import register_diffusers_triton_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .ltx_eager_attention import 
register_ltx_eager_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .ltx_triton_attention import register_ltx_triton_attention + __all__ = [ "IS_AVAILABLE", "attention", + "get_skip_softmax_context", + "register_diffusers_eager_attention", + "register_diffusers_triton_attention", + "register_ltx_eager_attention", + "register_ltx_triton_attention", "register_triton_attention", + "set_skip_softmax_context", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py new file mode 100644 index 0000000000..16dd895f27 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Eager attention backend for diffusers skip-softmax sparse attention. + +Registers a ``modelopt_skip_softmax`` backend in diffusers' +``_AttentionBackendRegistry`` that computes attention eagerly with an explicit +``F.softmax`` call. This allows the existing softmax-patching mechanism in +``SparseAttentionModule`` to intercept and apply block-wise sparsity. + +Used during **calibration only** — inference uses the Triton FA kernel. 
+""" + +import inspect +import math + +import torch +import torch.nn.functional as F +from diffusers.models.attention_dispatch import ( + AttentionBackendName, + _AttentionBackendRegistry, + attention_backend, +) + +_BACKEND_NAME = "modelopt_skip_softmax" +_BACKEND_REGISTERED = False + + +# --------------------------------------------------------------------------- +# Eager attention implementation +# --------------------------------------------------------------------------- + + +def _diffusers_eager_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, +) -> torch.Tensor: + """Compute attention eagerly on diffusers layout ``[B, S, H, D]``. + + The explicit ``F.softmax`` call is what the skip-softmax patch intercepts. + """ + # Diffusers convention: [B, S, H, D] → transpose to [B, H, S, D] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Handle GQA: repeat K/V heads to match Q heads + if enable_gqa and query.shape[1] != key.shape[1]: + num_heads_q = query.shape[1] + num_heads_kv = key.shape[1] + n_rep = num_heads_q // num_heads_kv + key = key.repeat_interleave(n_rep, dim=1) + value = value.repeat_interleave(n_rep, dim=1) + + if scale is None: + scale = 1.0 / math.sqrt(query.shape[-1]) + + # Q @ K^T * scale + scores = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply attention mask if provided + if attn_mask is not None: + scores = scores + attn_mask + + # Apply causal mask if needed + if is_causal: + seq_q, seq_k = scores.shape[-2], scores.shape[-1] + causal_mask = torch.triu( + torch.full((seq_q, seq_k), float("-inf"), device=scores.device, dtype=scores.dtype), + diagonal=seq_k - seq_q + 1, + ) + scores = scores + causal_mask + + # F.softmax — this is where the skip-softmax patch intercepts + scores = F.softmax(scores, dim=-1) + 
+ if dropout_p > 0.0: + scores = F.dropout(scores, p=dropout_p, training=True) + + # scores @ V + out = torch.matmul(scores, value) + + # Transpose back: [B, H, S, D] → [B, S, H, D] + out = out.transpose(1, 2) + return out + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def register_diffusers_eager_attention() -> None: + """Register ``modelopt_skip_softmax`` backend in diffusers. + + Safe to call multiple times; registration happens only once. + """ + global _BACKEND_REGISTERED + if _BACKEND_REGISTERED: + return + + # Extend the AttentionBackendName enum with our custom value + new_member = str.__new__(AttentionBackendName, _BACKEND_NAME) + new_member._name_ = "MODELOPT_SKIP_SOFTMAX" + new_member._value_ = _BACKEND_NAME + AttentionBackendName._member_map_["MODELOPT_SKIP_SOFTMAX"] = new_member + AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member + + # Register the backend function + _AttentionBackendRegistry._backends[new_member] = _diffusers_eager_attention + _AttentionBackendRegistry._constraints[new_member] = [] + _AttentionBackendRegistry._supported_arg_names[new_member] = set( + inspect.signature(_diffusers_eager_attention).parameters.keys() + ) + + _BACKEND_REGISTERED = True + + +def get_skip_softmax_attention_backend(): + """Return a context manager that activates the modelopt_skip_softmax backend. + + Raises RuntimeError if the backend has not been registered yet. + """ + if not _BACKEND_REGISTERED: + raise RuntimeError( + "modelopt_skip_softmax backend not registered. " + "Call register_diffusers_eager_attention() first." 
_BACKEND_NAME = "modelopt_triton"
_BACKEND_REGISTERED = False

# Thread-local storage for per-forward skip-softmax configuration.
# The method's get_sparse_context() sets these before each forward pass.
_thread_local = threading.local()


def set_triton_skip_softmax_config(threshold: float | None = None) -> None:
    """Set thread-local skip-softmax config for the next Triton attention call."""
    _thread_local.skip_threshold = threshold


def clear_triton_skip_softmax_config() -> None:
    """Clear thread-local skip-softmax config."""
    _thread_local.skip_threshold = None


def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``.

    Converts to the kernel's varlen format, calls the Triton FA kernel, and
    converts back.

    Raises:
        NotImplementedError: If ``attn_mask`` or ``dropout_p`` is requested.
            The Triton kernel supports neither; silently ignoring them would
            corrupt the output.
    """
    # BUGFIX: these arguments were previously accepted and silently dropped,
    # which produces wrong attention outputs. Fail loudly instead.
    if attn_mask is not None:
        raise NotImplementedError(
            "modelopt_triton backend does not support attn_mask; "
            "use the eager backend for masked attention."
        )
    if dropout_p > 0.0:
        raise NotImplementedError("modelopt_triton backend does not support dropout.")

    # GQA: expand K/V heads to match Q heads before flattening, mirroring the
    # eager backend — the kernel assumes matching Q/KV head counts.
    if enable_gqa and query.shape[2] != key.shape[2]:
        n_rep = query.shape[2] // key.shape[2]
        key = key.repeat_interleave(n_rep, dim=2)
        value = value.repeat_interleave(n_rep, dim=2)

    batch, seq_q, num_heads_q, head_dim = query.shape
    seq_k = key.shape[1]
    device = query.device

    # Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D]
    q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous()
    k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous()
    v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous()

    # Build varlen metadata: every batch element has the same fixed length.
    b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q
    b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32)

    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    kw: dict = {
        "b_start_loc": b_start_loc_q,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_q,
        "is_causal": is_causal,
        "softmax_scale": scale,
    }

    # If Q and KV have different sequence lengths, pass separate KV metadata
    if seq_q != seq_k:
        b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
        b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
        kw["b_start_loc_k"] = b_start_loc_k
        kw["b_seq_len_k"] = b_seq_len_k
        kw["max_input_len_k"] = seq_k

    # Read skip-softmax config from thread-local storage
    threshold = getattr(_thread_local, "skip_threshold", None)
    if threshold is not None and threshold > 0.0:
        kw["skip_softmax_threshold"] = threshold

    assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)"
    o = attention(q, k, v, **kw)

    # Reshape back: [B*S, H, D] -> [B, S, H, D]
    return o.view(batch, seq_q, num_heads_q, head_dim)
def register_diffusers_triton_attention() -> None:
    """Register ``modelopt_triton`` backend in diffusers.

    Safe to call multiple times; registration happens only once.
    """
    global _BACKEND_REGISTERED
    if _BACKEND_REGISTERED:
        return

    # Extend the AttentionBackendName enum with our custom value.
    # NOTE(review): this pokes at Enum internals (_member_map_ /
    # _value2member_map_) and diffusers' private registry attributes
    # (_backends / _constraints / _supported_arg_names) — fragile across
    # diffusers versions; re-verify on every diffusers upgrade.
    new_member = str.__new__(AttentionBackendName, _BACKEND_NAME)
    new_member._name_ = "MODELOPT_TRITON"
    new_member._value_ = _BACKEND_NAME
    AttentionBackendName._member_map_["MODELOPT_TRITON"] = new_member
    AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member

    # Register the backend function; arg names let the dispatcher filter kwargs.
    _AttentionBackendRegistry._backends[new_member] = _diffusers_triton_attention
    _AttentionBackendRegistry._constraints[new_member] = []
    _AttentionBackendRegistry._supported_arg_names[new_member] = set(
        inspect.signature(_diffusers_triton_attention).parameters.keys()
    )

    _BACKEND_REGISTERED = True


def get_triton_attention_backend():
    """Return a context manager that activates the modelopt_triton backend.

    Raises RuntimeError if the backend has not been registered yet.
    """
    if not _BACKEND_REGISTERED:
        raise RuntimeError(
            "modelopt_triton backend not registered. "
            "Call register_diffusers_triton_attention() first."
        )
    return attention_backend(_BACKEND_NAME)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Eager attention wrapper for LTX-2 (ltx_core) skip-softmax sparse attention. + +Patches ``Attention`` modules from ``ltx_core`` so that when the skip-softmax +thread-local flag is active, attention is computed eagerly with an explicit +``F.softmax`` call that the softmax-patching mechanism can intercept. + +Used during **calibration only** — inference uses the Triton FA kernel via +the diffusers Triton backend. +""" + +import math + +import torch +import torch.nn.functional as F +from ltx_core.model.transformer.attention import Attention + +from . import get_skip_softmax_context + + +def _ltx_eager_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Eager attention on LTX-2 layout ``[B, T, H*D]``. + + Mirrors the ``PytorchAttention`` class in ltx_core but uses an explicit + ``F.softmax`` instead of ``scaled_dot_product_attention``. 
+ """ + b, _, dim_total = q.shape + dim_head = dim_total // heads + + # Reshape to [B, T, H, D] then transpose to [B, H, T, D] + q = q.view(b, -1, heads, dim_head).transpose(1, 2) + k = k.view(b, -1, heads, dim_head).transpose(1, 2) + v = v.view(b, -1, heads, dim_head).transpose(1, 2) + + scale = 1.0 / math.sqrt(dim_head) + + # Q @ K^T * scale + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + + # Apply mask if provided + if mask is not None: + # Expand mask dimensions to match scores [B, H, Sq, Sk] + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + scores = scores + mask + + # F.softmax — intercepted by skip-softmax patch + scores = F.softmax(scores, dim=-1) + + # scores @ V + out = torch.matmul(scores, v) + + # [B, H, T, D] → [B, T, H*D] + out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) + return out + + +class _SkipSoftmaxLTXAttentionWrapper: + """Wraps an ``attention_function`` callable from ltx_core. + + When the thread-local skip-softmax flag is active, routes to the eager + attention path. Otherwise calls the original function. + """ + + def __init__(self, original_fn): + self._original_fn = original_fn + + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if get_skip_softmax_context(): + return _ltx_eager_attention(q, k, v, heads, mask) + return self._original_fn(q, k, v, heads, mask) + + +def register_ltx_eager_attention(model: torch.nn.Module) -> None: + """Walk *model* and patch all ``ltx_core.model.transformer.attention.Attention`` modules. + + Patches modules so their ``attention_function`` is routed through the eager wrapper. + Safe to call multiple times on the same model — already-wrapped modules are + skipped. 
def register_ltx_eager_attention(model: torch.nn.Module) -> None:
    """Walk *model* and patch all ``ltx_core.model.transformer.attention.Attention`` modules.

    Patches modules so their ``attention_function`` is routed through the eager wrapper.
    Safe to call multiple times on the same model — already-wrapped modules are
    skipped.
    """
    # Wrap in place; the isinstance check makes repeated registration a no-op.
    for module in model.modules():
        if isinstance(module, Attention):
            fn = module.attention_function
            if not isinstance(fn, _SkipSoftmaxLTXAttentionWrapper):
                module.attention_function = _SkipSoftmaxLTXAttentionWrapper(fn)
# Thread-local storage for skip-softmax configuration
_thread_local = threading.local()


def set_ltx_triton_context(
    active: bool,
    threshold: float | None = None,
) -> None:
    """Set thread-local Triton config for LTX-2 attention."""
    _thread_local.active = active
    _thread_local.threshold = threshold


def clear_ltx_triton_context() -> None:
    """Clear thread-local Triton config."""
    _thread_local.active = False
    _thread_local.threshold = None


def _get_ltx_triton_context() -> tuple[bool, float | None]:
    """Return (active, threshold)."""
    return (
        getattr(_thread_local, "active", False),
        getattr(_thread_local, "threshold", None),
    )


def _ltx_triton_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    mask: torch.Tensor | None = None,
    threshold: float | None = None,
) -> torch.Tensor:
    """Triton FA attention on LTX-2 layout ``[B, T, H*D]``.

    Converts from LTX-2's fused-head layout to the Triton kernel's varlen
    format, calls the kernel with skip-softmax, and converts back.

    Raises:
        NotImplementedError: If ``mask`` is provided — the Triton kernel has
            no mask support, so masked calls must use the original path.
    """
    # BUGFIX: the mask argument used to be accepted and silently ignored,
    # corrupting masked attention. Reject it explicitly here.
    if mask is not None:
        raise NotImplementedError(
            "Triton FA path does not support attention masks; "
            "route masked calls through the original attention function."
        )

    b, seq_q, dim_total = q.shape
    dim_head = dim_total // heads
    seq_k = k.shape[1]
    device = q.device

    # LTX-2 layout: [B, T, H*D] -> reshape to [B, T, H, D] -> flat [B*T, H, D]
    q_flat = q.view(b, seq_q, heads, dim_head).reshape(b * seq_q, heads, dim_head).contiguous()
    k_flat = k.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous()
    v_flat = v.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous()

    # Build varlen metadata: fixed-length sequences per batch element.
    b_start_loc_q = torch.arange(b, device=device, dtype=torch.int32) * seq_q
    b_seq_len_q = torch.full((b,), seq_q, device=device, dtype=torch.int32)

    scale = 1.0 / math.sqrt(dim_head)

    kw: dict = {
        "b_start_loc": b_start_loc_q,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_q,
        "is_causal": False,  # Diffusion uses bidirectional attention
        "softmax_scale": scale,
    }

    # Handle different Q/KV sequence lengths
    if seq_q != seq_k:
        b_start_loc_k = torch.arange(b, device=device, dtype=torch.int32) * seq_k
        b_seq_len_k = torch.full((b,), seq_k, device=device, dtype=torch.int32)
        kw["b_start_loc_k"] = b_start_loc_k
        kw["b_seq_len_k"] = b_seq_len_k
        kw["max_input_len_k"] = seq_k

    # Skip-softmax threshold
    if threshold is not None and threshold > 0.0:
        kw["skip_softmax_threshold"] = threshold

    assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)"
    o = attention(q_flat, k_flat, v_flat, **kw)

    # Reshape back: [B*T, H, D] -> [B, T, H*D]
    return o.view(b, seq_q, heads * dim_head)


class _TritonLTXAttentionWrapper:
    """Wraps an ``attention_function`` callable from ltx_core.

    When the thread-local Triton skip-softmax flag is active and no attention
    mask is supplied, routes to the Triton FA kernel. Otherwise calls the
    original function.
    """

    def __init__(self, original_fn):
        self._original_fn = original_fn

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        active, threshold = _get_ltx_triton_context()
        # BUGFIX: the Triton kernel cannot apply an attention mask. Masked
        # calls previously went to the kernel with the mask silently dropped;
        # fall back to the original implementation whenever a mask is given.
        if active and mask is None:
            return _ltx_triton_attention(q, k, v, heads, None, threshold)
        return self._original_fn(q, k, v, heads, mask)


def register_ltx_triton_attention(model: torch.nn.Module) -> None:
    """Walk *model* and patch all ``ltx_core.Attention`` modules for Triton dispatch.

    Safe to call multiple times -- already-wrapped modules are skipped.
    """
    for module in model.modules():
        if isinstance(module, Attention):
            fn = module.attention_function
            if not isinstance(fn, _TritonLTXAttentionWrapper):
                module.attention_function = _TritonLTXAttentionWrapper(fn)
+ """ original_softmax = F.softmax def sparse_softmax(input, dim=-1, *args, **kwargs): @@ -379,7 +384,21 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): input = self.apply_sparsity(input, sparse_mask) return original_softmax(input, dim, *args, **kwargs) - return replace_function(torch.nn.functional, "softmax", sparse_softmax) + from ..kernels import set_skip_softmax_context + + stack = ExitStack() + set_skip_softmax_context(True) + stack.callback(set_skip_softmax_context, False) + + try: + from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend + + stack.enter_context(get_skip_softmax_attention_backend()) + except (ImportError, RuntimeError): + pass + + stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax)) + return stack @property def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index 4db51e894e..b885eeaea5 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -17,6 +17,8 @@ from contextlib import contextmanager +import torch + from .registry import SparseAttentionMethod, register_sparse_method @@ -45,6 +47,18 @@ def name(self) -> str: """Method name identifier.""" return "triton_skip_softmax" + def calculate_sparsity(self, attention_scores): + """Return a no-op mask (skip decision is made inside the Triton kernel).""" + mask = torch.ones_like(attention_scores, dtype=torch.bool) + return mask, {} + + def apply_sparsity(self, attention_scores, sparse_mask=None): + """Not supported — tile skipping is fused into the Triton kernel.""" + raise NotImplementedError( + "triton_skip_softmax applies tile skipping inside the Triton kernel. " + "Use backend='triton', not backend='pytorch'." 
+ ) + def get_sparse_context(self, module): """Return context manager that activates skip-softmax during forward.""" diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 599832943d..d26b73f0b4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -16,7 +16,6 @@ """Dynamic sparse attention registration for HuggingFace models.""" import torch.nn as nn -import transformers from modelopt.torch.opt.dynamic import DynamicModule @@ -112,11 +111,22 @@ def _is_supported_model(model: nn.Module) -> bool: """ # Check for HuggingFace PreTrainedModel try: + import transformers + if isinstance(model, transformers.PreTrainedModel): return True except ImportError: pass + # Check for diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + return True + except ImportError: + pass + # Support any PyTorch model with attention modules return isinstance(model, nn.Module) diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 1eabdfe358..de70c3cadf 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -79,14 +79,15 @@ def collect(self, stats: dict): # In calibration mode, store per-sample stats if self.calibration_mode: - self.per_sample_stats.append( - { - "module": self.module_name, - "sparsity": stats.get("sparsity", 0.0), - "sample_length": stats.get("sample_length", 0), - "phase": phase, - } - ) + sample_stat = { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + if "normalized_gaps" in stats: + sample_stat["normalized_gaps"] = stats["normalized_gaps"] + 
self.per_sample_stats.append(sample_stat) def get_summary(self) -> dict: """Get aggregated statistics summary. diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py new file mode 100644 index 0000000000..b8685b8410 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for diffusers kernel backends and thread-local context.""" + +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + + +def _mock_diffusers(): + """Mock diffusers.models.attention_dispatch for testing without real diffusers.""" + m = types.ModuleType("diffusers.models.attention_dispatch") + + class FakeBackendName(str): + _member_map_: dict = {} + _value2member_map_: dict = {} + + m.AttentionBackendName = FakeBackendName + + class FakeReg: + _backends: dict = {} + _constraints: dict = {} + _supported_arg_names: dict = {} + + m._AttentionBackendRegistry = FakeReg + m.attention_backend = MagicMock() + return { + "diffusers": types.ModuleType("diffusers"), + "diffusers.models": types.ModuleType("diffusers.models"), + "diffusers.models.attention_dispatch": m, + } + + +# --------------------------------------------------------------------------- +# Tests: thread-local skip-softmax context +# --------------------------------------------------------------------------- + + +class TestSkipSoftmaxContext: + def test_default_is_false(self): + from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context + + assert get_skip_softmax_context() is False + + def test_set_and_get(self): + from modelopt.torch.sparsity.attention_sparsity.kernels import ( + get_skip_softmax_context, + set_skip_softmax_context, + ) + + set_skip_softmax_context(True) + assert get_skip_softmax_context() is True + set_skip_softmax_context(False) + assert get_skip_softmax_context() is False + + +# --------------------------------------------------------------------------- +# Tests: diffusers eager attention +# --------------------------------------------------------------------------- + + +class TestDiffusersEagerAttention: + @pytest.fixture(autouse=True) + def _setup(self): + with patch.dict(sys.modules, _mock_diffusers()): + from 
class TestDiffusersEagerAttention:
    """Shape/semantics tests for the eager diffusers backend (diffusers mocked)."""

    @pytest.fixture(autouse=True)
    def _setup(self):
        # Install fake diffusers modules so the backend module imports cleanly
        # even when real diffusers is not installed.
        with patch.dict(sys.modules, _mock_diffusers()):
            from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention import (
                _diffusers_eager_attention,
                get_skip_softmax_attention_backend,
                register_diffusers_eager_attention,
            )

            self._fn = _diffusers_eager_attention
            self._register = register_diffusers_eager_attention
            self._get_backend = get_skip_softmax_attention_backend

            # NOTE(review): this import must happen while patch.dict is active —
            # after exit the mocked diffusers modules are removed from sys.modules.
            import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention as mod

            # Reset the module-level registration flag so each test starts clean.
            mod._BACKEND_REGISTERED = False
        yield

    def test_basic_shape(self):
        # Self-attention preserves the [B, S, H, D] shape.
        q = torch.randn(2, 8, 4, 16)
        assert self._fn(q, q, q).shape == (2, 8, 4, 16)

    def test_cross_attention(self):
        # KV may have a different sequence length; output follows Q's length.
        q = torch.randn(1, 4, 2, 8)
        k = torch.randn(1, 12, 2, 8)
        assert self._fn(q, k, k).shape == (1, 4, 2, 8)

    def test_causal(self):
        q = torch.randn(1, 4, 1, 8)
        assert self._fn(q, q, q, is_causal=True).shape == (1, 4, 1, 8)

    def test_gqa(self):
        # 8 query heads over 2 KV heads — KV heads are repeated to match.
        q = torch.randn(1, 4, 8, 16)
        k = torch.randn(1, 4, 2, 16)
        assert self._fn(q, k, k, enable_gqa=True).shape == (1, 4, 8, 16)

    def test_register_idempotent(self):
        # Registration must be safe to call twice.
        self._register()
        self._register()

    def test_get_backend_before_register_raises(self):
        with pytest.raises(RuntimeError, match="not registered"):
            self._get_backend()


# ---------------------------------------------------------------------------
# Tests: diffusers triton attention
# ---------------------------------------------------------------------------


class TestDiffusersTritonAttention:
    """Config plumbing and registration tests for the Triton diffusers backend."""

    @pytest.fixture(autouse=True)
    def _setup(self):
        mocks = _mock_diffusers()
        # Also mock the modelopt Triton kernel module: an identity "attention"
        # keeps these tests runnable without CUDA/triton.
        mk = types.ModuleType("modelopt.torch.kernels")
        mk.attention = lambda q, k, v, **kw: q
        mk.IS_AVAILABLE = True
        mk.register_triton_attention = None
        mocks["modelopt.torch.kernels"] = mk

        with patch.dict(sys.modules, mocks):
            from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import (
                _diffusers_triton_attention,
                clear_triton_skip_softmax_config,
                get_triton_attention_backend,
                register_diffusers_triton_attention,
                set_triton_skip_softmax_config,
            )

            self._fn = _diffusers_triton_attention
            self._set = set_triton_skip_softmax_config
            self._clear = clear_triton_skip_softmax_config
            self._register = register_diffusers_triton_attention
            self._get_backend = get_triton_attention_backend

            # NOTE(review): import inside patch.dict — see eager fixture above.
            import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention as mod

            mod._BACKEND_REGISTERED = False
        yield

    def test_set_clear_config(self):
        self._set(threshold=0.1)
        self._clear()

    def test_register_idempotent(self):
        self._register()
        self._register()

    def test_get_backend_before_register_raises(self):
        with pytest.raises(RuntimeError, match="not registered"):
            self._get_backend()


# ---------------------------------------------------------------------------
# Tests: conversion.py _register_diffusers_backends_if_needed
# ---------------------------------------------------------------------------


class TestRegisterDiffusersBackends:
    """Registration hook fires only for diffusers models."""

    def test_no_diffusers_no_error(self):
        # A plain nn.Module with no diffusers available must be a silent no-op.
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        _register_diffusers_backends_if_needed(nn.Linear(10, 10))

    def test_with_diffusers_model(self):
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        # Fake ModelMixin so isinstance(model, ModelMixin) matches.
        mock_mixin = type("ModelMixin", (nn.Module,), {})
        mock_utils = types.ModuleType("diffusers.models.modeling_utils")
        mock_utils.ModelMixin = mock_mixin

        with (
            patch.dict(sys.modules, {"diffusers.models.modeling_utils": mock_utils}),
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention",
                MagicMock(),
            ) as mock_eager,
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_triton_attention",
                MagicMock(),
            ) as mock_triton,
        ):
            _register_diffusers_backends_if_needed(mock_mixin())

        # Both backends must have been registered exactly once.
        mock_eager.assert_called_once()
        mock_triton.assert_called_once()