
Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Nov 11, 2025

What does this PR do?

Type of change: New feature

Overview:

  • Adds the sparse attention calibration algorithm
  • Adds chunked prefill to support long context lengths (ctx_len)
  • Separates calibration for prefill and decode

Usage

import modelopt.torch.sparsity.attention_sparsity as mtsa

# Apply sparse attention with calibration
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)

# Print summary - now shows actual thresholds
mtsa.print_sparse_attention_summary(model)
# Output:
# Method: flash_skip_softmax, Threshold: Dynamic (λ=437.395926)

# Or use the llm_eval integration (HuggingFace sparse attention example):
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-4B \
    --sparse_attn skip_softmax_calib

Testing

The calibration results for Qwen/Qwen3-30B-A3B-Thinking-2507 are shown below and are consistent with our offline calibration results.

Prefill Calibration Results:
  Threshold scale factor: 898.742997 (std: 354.176531)
  R-squared: 0.8080
  Average achieved sparsity: 49.45% (target: 50.00%)

Decode Calibration Results:
  Threshold scale factor: 6.918761 (std: 4.768286)
  R-squared: 0.3797
  Average achieved sparsity: 50.78% (target: 50.00%)
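
For intuition, here is a minimal sketch (not the PR's implementation) of how a dynamic-threshold scale factor of the form λ = a / seq_len could be fit from per-sample measurements; the inputs seq_lens and thresholds are illustrative names, not this PR's API.

import numpy as np

def fit_threshold_scale(seq_lens, thresholds):
    """Fit a in lambda = a / seq_len by least squares and report R^2 (sketch).

    seq_lens, thresholds: per-sample sequence lengths and the per-sample
    thresholds that hit the target sparsity.
    """
    x = 1.0 / np.asarray(seq_lens, dtype=np.float64)  # regressor: 1 / length
    y = np.asarray(thresholds, dtype=np.float64)      # response: threshold
    a = float((x * y).sum() / (x * x).sum())          # no-intercept least squares
    ss_res = ((y - a * x) ** 2).sum()
    ss_tot = ((y - y.mean()) ** 2).sum()
    return a, 1.0 - ss_res / ss_tot                   # (scale factor, R-squared)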

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@kaix-nv kaix-nv requested review from a team as code owners November 11, 2025 22:38
@kaix-nv kaix-nv requested review from RalphMao and removed request for RalphMao November 11, 2025 22:38
@codecov

codecov bot commented Nov 11, 2025

Codecov Report

❌ Patch coverage is 73.62924% with 202 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.70%. Comparing base (b655321) to head (ed213d9).

Files with missing lines                               | Patch %  | Missing lines
...arsity/attention_sparsity/calibration/calibrate.py  | 30.53%   | 91
...rsity/attention_sparsity/calibration/calibrator.py  | 63.86%   | 43
...sity/attention_sparsity/calibration/ruler_utils.py  | 73.24%   | 42
...pt/torch/sparsity/attention_sparsity/conversion.py  | 72.97%   | 10
...ch/sparsity/attention_sparsity/sparse_attention.py  | 63.15%   | 7
...delopt/torch/sparsity/attention_sparsity/config.py  | 89.28%   | 6
...sparsity/attention_sparsity/calibration/dataset.py  | 99.38%   | 1
...y/attention_sparsity/methods/flash_skip_softmax.py  | 97.29%   | 1
...ch/sparsity/attention_sparsity/methods/registry.py  | 80.00%   | 1
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #538      +/-   ##
==========================================
+ Coverage   74.69%   74.70%   +0.01%     
==========================================
  Files         192      198       +6     
  Lines       18948    19691     +743     
==========================================
+ Hits        14153    14711     +558     
- Misses       4795     4980     +185     


@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 8c7ee86 to da6f627 Compare November 12, 2025 00:17
@kaix-nv kaix-nv changed the title [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850][3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850][3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples Nov 12, 2025
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 4 times, most recently from 525a119 to c9d7008 Compare November 13, 2025 07:40
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention integration to the llm_eval examples [OMNIML-2850] [3/n] Adds sparse attention calibration; Adds llm_eval support Nov 14, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2850] [3/n] Adds sparse attention calibration; Adds llm_eval support [OMNIML-2850] [3/n] Adds sparse attention calibration Nov 14, 2025
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 5 times, most recently from 7727793 to 2864629 Compare December 1, 2025 11:35
@kaix-nv kaix-nv requested a review from a team as a code owner December 1, 2025 11:35
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 2864629 to ca7e24e Compare December 1, 2025 15:19
@kaix-nv kaix-nv removed the request for review from kevalmorabia97 December 1, 2025 15:25
@kevalmorabia97
Collaborator

kevalmorabia97 commented Dec 1, 2025

@kaix-nv GitHub is showing 7,000+ lines of code as part of this PR. Is that accurate?
It shouldn’t be that much. Less than half of the code should remain after rebasing on the preceding PR.

@kaix-nv kaix-nv requested a review from jy-yuan December 8, 2025 21:52
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 4 times, most recently from 3474b6f to 74a29ea Compare December 13, 2025 21:00
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 0aa51a5 to 2101e99 Compare December 15, 2025 19:44
@kaix-nv kaix-nv requested a review from jy-yuan December 15, 2025 19:44
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 2101e99 to 80e0196 Compare December 15, 2025 19:50
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 80e0196 to 010e86a Compare December 16, 2025 00:46
@shengliangxu
Contributor

I suggest keeping just the calibration core logic plus the unit tests for calibration in this PR.

Make a 4/n PR for the rest.

)

max_seqlen: int = ModeloptField(
    default=32768,
Contributor
@realAsma realAsma Dec 24, 2025

Can we support 32K sequences with eager attention?

Contributor Author

I’ve added chunked prefill to support long sequence lengths. The default chunk_size is set to 2048, and users can reduce chunk_size if they encounter OOM issues.
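
As a rough illustration only (not the code in this PR), chunked prefill processes the query sequence in blocks so the attention-score matrix for a long context never materializes all at once; the function below and its eager-attention math are assumptions.

import torch
import torch.nn.functional as F

def chunked_prefill_attention(q, k, v, chunk_size: int = 2048):
    """Eager causal attention computed chunk-by-chunk over the query axis (sketch).

    q, k, v: [batch, heads, seq_len, head_dim]. Processing queries in chunks of
    `chunk_size` keeps the peak score matrix at [chunk_size, seq_len] instead of
    [seq_len, seq_len].
    """
    seq_len = q.shape[2]
    outputs = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        q_chunk = q[:, :, start:end]
        # Causal mask restricted to this chunk's query rows.
        mask = torch.ones(end - start, seq_len, dtype=torch.bool, device=q.device)
        mask = torch.tril(mask, diagonal=start)
        scores = q_chunk @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
        scores = scores.masked_fill(~mask, float("-inf"))
        outputs.append(F.softmax(scores, dim=-1) @ v)
    return torch.cat(outputs, dim=2)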

Contributor Author

For decode calibration, FlashAttention is used during the prefill phase, which likewise supports long sequence lengths.

config: Sparse attention configuration with calibration settings.
forward_loop: Optional callable that forwards calibration data through the model.
    If provided, uses this for calibration data.
    If None, will auto-generate a RULER dataset for calibration.
Contributor

Does the auto-generation work in general, e.g., for Megatron Core models?

samples: int = ModeloptField(
    default=24,
    title="Calibration samples",
    description="Total number of RULER samples for calibration (distributed across length bins).",
Contributor

Why are we hard-coding the dataset size here?

Contributor Author

This is a setup-dependent magic number. The default calibration settings are num_length_bins = 4 and num_tasks = 6, which results in a minimum of 4 × 6 = 24 data samples. I've updated the description to explain the details.
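
For illustration (num_length_bins, num_tasks, and samples are taken from this thread; the helper itself is hypothetical), the minimum sample count and an even split across (length bin, task) pairs could look like:

def distribute_samples(samples: int = 24, num_length_bins: int = 4, num_tasks: int = 6):
    """Split the calibration budget evenly across (length bin, task) pairs (sketch)."""
    min_samples = num_length_bins * num_tasks  # 4 * 6 = 24, hence the default
    if samples < min_samples:
        raise ValueError(f"Need at least {min_samples} samples, got {samples}")
    per_pair = samples // min_samples
    return {(b, t): per_pair for b in range(num_length_bins) for t in range(num_tasks)}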

# Configuration with RULER calibration
# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length
# The calibrated threshold adapts to sequence length for optimal sparsity
SKIP_SOFTMAX_CALIB = {
Contributor

Will this config work for Megatron? Otherwise I suggest removing the dataset-specific config; the dataset and num_samples should be handled by example files.

Contributor Author

My understanding is that the dataset-specific configs are framework-agnostic: they describe what calibration data to generate, not how to generate it. Both HF and Megatron helpers can consume these parameters to generate appropriate calibration data.

Contributor Author

Megatron support (e.g., adding forward_loop) will be added in future PRs.
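
As a rough sketch of what a config of this shape might look like (the key names below are assumptions for illustration, not the PR's exact schema):

# Hypothetical layout of a skip-softmax calibration config; the real keys in
# SKIP_SOFTMAX_CALIB may differ. `threshold` is deliberately absent because
# calibration fits the dynamic threshold λ = a / length instead.
SKIP_SOFTMAX_CALIB_SKETCH = {
    "method": "flash_skip_softmax",
    "calibration": {
        "target_sparsity": 0.5,  # matches the 50% target in the Testing section
        "samples": 24,           # RULER samples, spread over length bins × tasks
        "max_seqlen": 32768,
    },
}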

Comment on lines 272 to 279
# Force eager attention if sparse attention is requested
if sparse_cfg:
    kwargs["attn_implementation"] = "eager"
    warnings.warn(
        "Sparse attention requires attn_implementation='eager'. "
        "Forcing eager attention implementation."
    )

Contributor

Can we move this to sparsity/plugins/huggingface? We should detect whether this is an HF model and, if so, apply this (see

AutoQuantizeGradientSearcher.register_custom_support(
)

Contributor Author

I've moved this check to the huggingface plugin.

Contributor

Same as https://github.com/NVIDIA/Model-Optimizer/pull/538/files#r2646356349; avoid repeating the attention modification.

Contributor Author

I've removed this check.
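
As a loose sketch of the plugin-side check being discussed (the helper name and placement are assumptions, not the PR's code), detecting a HuggingFace model without a hard transformers dependency could look like:

def _is_huggingface_model(model) -> bool:
    """Sketch: detect a HuggingFace transformers model via a local import,
    so transformers stays an optional dependency."""
    try:
        from transformers import PreTrainedModel
    except ImportError:
        return False
    return isinstance(model, PreTrainedModel)

The eager-attention forcing shown earlier in this thread would then apply only when such a check returns True.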

from .dataset import RulerDatasetBuilder


def _extract_tokenizer_from_model(model: nn.Module) -> str:
Contributor

This only works with HuggingFace transformers, doesn't it?

Contributor Author

Yes, Megatron support (e.g., adding forward_loop) will be added in future PRs.

Contributor
@realAsma realAsma Dec 24, 2025

Can we move this to sparsity/plugins?

Otherwise this change will make the transformers library a required dependency of ModelOpt.

Contributor

Please use local imports of third-party libraries wherever necessary.

Contributor Author

I’ve switched to local imports. Since most of the dataset construction logic can be reused by Megatron, I’d prefer to keep this file under the calibration folder.
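
For illustration only (the helper's actual body isn't shown in this thread), a local import keeps transformers optional; the config attribute used for the lookup is an assumption:

def _extract_tokenizer_from_model(model) -> str:
    """Sketch: resolve a tokenizer name/path from an HF model without a
    module-level transformers import, keeping the dependency optional."""
    from transformers import PreTrainedModel  # local import: fails only on this code path

    if not isinstance(model, PreTrainedModel):
        raise ValueError("Automatic tokenizer extraction currently supports HF models only.")
    name_or_path = getattr(model.config, "_name_or_path", None)
    if not name_or_path:
        raise ValueError("Could not infer the tokenizer; please provide a forward_loop instead.")
    return name_or_path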

# For causal attention, only count lower triangle blocks (including diagonal)
num_causal_blocks = num_block_rows * (2 * num_block_cols - num_block_rows + 1) // 2
total_valid_blocks = batch_size * num_heads * num_causal_blocks
density = float(block_mask.sum()) / total_valid_blocks
Contributor

Can we keep this as a torch tensor? Calling float() on a GPU tensor causes an unnecessary CPU-GPU sync.

Contributor Author

Good catch. It’s been fixed.

    density = float(block_mask.sum()) / total_valid_blocks
    total_blocks = num_causal_blocks
else:
    density = float(block_mask.sum() / block_mask.numel())

Contributor Author

It’s been fixed.
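
A minimal sketch of the kind of fix being discussed (variable names follow the snippet above; the surrounding function is assumed): keep the density on-device and synchronize only once if a Python float is ever needed.

import torch

def block_mask_density(block_mask: torch.Tensor, total_valid_blocks: int) -> torch.Tensor:
    """Return the density as a tensor on the mask's device, avoiding a per-call
    CPU-GPU sync; the caller can .item() once after all accumulation is done."""
    return block_mask.sum(dtype=torch.float32) / total_valid_blocks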

Collaborator
@jy-yuan jy-yuan left a comment

Reviewed the PR, and the calibration logic looks correct (dynamic thresholding via regression and phase separation seems solid). I tested the code and verified that the unit tests pass locally.

One small observation regarding dependencies: I noticed I had to manually install nltk and wonderwords to run the tests/calibration. It seems they are currently added to dev-test in setup.py, so they aren't included in pip install nvidia-modelopt[all]. If the RULER-based calibration is intended to be a supported feature for users (i.e., when they don't provide a custom forward loop), we might want to consider moving these to a user-facing optional dependency group (like calibration or hf) or catching the ModuleNotFoundError to suggest installation.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 2 times, most recently from c1692d4 to e65d3e1 Compare December 30, 2025 22:32
@kaix-nv
Contributor Author

kaix-nv commented Dec 30, 2025

> One small observation regarding dependencies: nltk and wonderwords are currently in dev-test in setup.py, so they aren't included in pip install nvidia-modelopt[all] … we might want to consider moving these to a user-facing optional dependency group (like calibration or hf).

Thanks Jiayi. Good catch. I've moved the dependencies to the hf group, so they can now be installed via pip install -U nvidia-modelopt[hf]. cc @kevalmorabia97

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch 2 times, most recently from 7fc092b to 30a9794 Compare December 31, 2025 03:54
Signed-off-by: Kai Xu <[email protected]>
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_calibration branch from 30a9794 to ed213d9 Compare December 31, 2025 08:02