
Draft

Add the Skip softmax for diffusion #1166

jingyu-ml wants to merge 5 commits into main from jingyux/diffusion-skip-softmax

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Apr 2, 2026

What does this PR do?

Type of change: new feature, new example

Summary

  • Add skip-softmax sparse attention support for diffusion models (LTX-2, Wan 2.2) using flash_skip_softmax with exponential model calibration (scale_factor = a * exp(b * sparsity)); a fitting sketch follows this list
  • Add diffusers/LTX kernel backends so that eager attention (with F.softmax patching) works on diffusion models that normally use scaled_dot_product_attention
  • Fix calibration to skip RULER dataset generation when the user provides their own forward_loop (required for non-LLM models)
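
A minimal sketch of the exponential-model fit (helper names and observation numbers here are illustrative, not the PR's internal API):

import numpy as np

def fit_exp_model(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) by least squares in log space."""
    # log(scale) = log(a) + b * sparsity, so a degree-1 polyfit on
    # (sparsity, log(scale)) pairs recovers b (slope) and log(a) (intercept).
    s = np.asarray(sparsities, dtype=np.float64)
    log_scale = np.log(np.asarray(scale_factors, dtype=np.float64))
    b, log_a = np.polyfit(s, log_scale, 1)
    return np.exp(log_a), b

a, b = fit_exp_model([0.10, 0.25, 0.50], [1.02, 1.08, 1.31])  # made-up observations
predicted_scale = a * np.exp(b * 0.25)  # model's scale factor at 25% sparsity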

Changes

  • New kernel backends: diffusers_triton_attention.py, diffusers_eager_attention.py, ltx_triton_attention.py, ltx_eager_attention.py — route diffusers/LTX attention through explicit F.softmax for calibration (see the eager-attention sketch after this list)
  • kernels/__init__.py: Thread-local context management, lazy imports for diffusers/LTX backends
  • conversion.py: Auto-register diffusers backends on sparsify(); update export config and summary
  • calibrate.py: Skip RULER dataset when forward_loop is provided (enables diffusion model calibration)
  • flash_skip_softmax.py: Enhanced context manager activates diffusers eager backend
  • plugins/huggingface.py: Support diffusers ModelMixin in model detection
  • Example scripts: ltx2_skip_softmax.py, wan22_skip_softmax.py
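
Fused scaled_dot_product_attention never materializes the attention probabilities, so there is nothing for an F.softmax patch to intercept; the eager backends exist to expose that call. A minimal sketch of the idea, assuming (batch, heads, seq_len, head_dim) tensors (an illustration, not the PR's backend code):

import torch
import torch.nn.functional as F

def eager_attention(q, k, v, attn_mask=None, is_causal=False):
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if is_causal:
        causal = torch.ones(scores.shape[-2], scores.shape[-1],
                            dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        scores = scores + attn_mask
    # Explicit softmax: patching F.softmax lets the skip-softmax method
    # observe and sparsify attention probabilities during calibration.
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, v)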

Usage

import modelopt.torch.sparsity.attention_sparsity as mtsa

# 1. Build your diffusion pipeline and get the transformer
transformer = pipeline.transformer  # or pipeline.stage_1_model_ledger.transformer()

# 2. Define sparse config
config = {
    "sparse_cfg": {
        "calibration": {
            "target_sparse_ratio": {"prefill": 0.25},
            "threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3,
                                 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1],
        },
        "*.attn1": {
            "method": "flash_skip_softmax",
            "thresholds": {"prefill": [1e-3]},
            "br": 128, "bc": 128,
            "backend": "pytorch",
            "is_causal": False,
            "collect_stats": True,
            "enable": True,
        },
        "*.attn2": {"enable": False},      # skip cross-attention
        "default": {"enable": False},
    },
}

# 3. Define a calibration forward loop. It runs the diffusion pipeline end to
#    end; the transformer from step 1 is the same module the pipeline uses,
#    so the pipeline call exercises the attention layers being calibrated.
def forward_loop(model):
    pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)

# 4. Sparsify + calibrate
mtsa.sparsify(transformer, config, forward_loop=forward_loop)

# 5. Generate as usual — sparsity is applied automatically
output = pipeline(prompt="a dog on the beach", ...)
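
Calibration cost scales with the forward loop, so it can run a shorter schedule than production inference (the prompts, frame counts, and step counts below are illustrative; shorter runs trade calibration fidelity for speed):

def quick_forward_loop(model):
    # Short calibration pass: fewer frames and denoising steps than production.
    for prompt in ["a cat", "a city street at night"]:
        pipeline(prompt=prompt, num_frames=17, num_inference_steps=8)

mtsa.sparsify(transformer, config, forward_loop=quick_forward_loop)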

Example scripts

# LTX-2 with 25% sparsity, skip first/last 3 layers
python examples/diffusers/sparsity/ltx2_skip_softmax.py \
    --prompt "A cat playing piano" --output out.mp4 \
    --calibrate --target-sparsity 0.25 --skip-first-last 3

# Wan 2.2 with 25% sparsity
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --prompt "A sunset over mountains" --output out.mp4 \
    --calibrate --target-sparsity 0.25

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 2, 2026 06:02
@jingyu-ml jingyu-ml requested a review from kaix-nv April 2, 2026 06:02
@jingyu-ml jingyu-ml marked this pull request as draft April 2, 2026 06:02
@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7a7e236c-3085-42d5-95a1-f02bb4764e21

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Comment @coderabbitai help to get the list of available commands and usage tips.

@copy-pr-bot

copy-pr-bot bot commented Apr 2, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@jingyu-ml jingyu-ml self-assigned this Apr 2, 2026
@github-actions
Contributor

github-actions bot commented Apr 2, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1166/

Built to branch gh-pages at 2026-04-02 22:52 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 50.69444% with 142 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.11%. Comparing base (87ea8ba) to head (2c323df).

Files with missing lines                                Patch %   Lines
...attention_sparsity/kernels/ltx_triton_attention.py    7.84%   47 Missing ⚠️
.../attention_sparsity/kernels/ltx_eager_attention.py   11.11%   32 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py   35.41%   31 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py   53.84%   24 Missing ⚠️
...tion_sparsity/kernels/diffusers_eager_attention.py   93.61%    3 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py   60.00%    2 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py   88.88%    1 Missing ⚠️
...sparsity/attention_sparsity/plugins/huggingface.py   80.00%    1 Missing ⚠️
...torch/sparsity/attention_sparsity/stats_manager.py   75.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1166      +/-   ##
==========================================
+ Coverage   74.28%   75.11%   +0.83%     
==========================================
  Files         349      353       +4     
  Lines       39846    40122     +276     
==========================================
+ Hits        29599    30139     +540     
+ Misses      10247     9983     -264     
Flag   Coverage Δ
unit   54.50% <45.13%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown.


@jingyu-ml jingyu-ml force-pushed the jingyux/diffusion-skip-softmax branch from 8151232 to 5873652 on April 2, 2026 08:38
jingyu-ml and others added 2 commits April 2, 2026 21:29