Description
Is your feature request related to a problem? Please describe.
Batch inference with mha_varlen_fwd currently assumes per-query mask/bias layouts. This blocks use cases where precomputed key-side gating is naturally expressed as broadcastable tensors of shape (total_k, 1), (total_k, num_heads_k), or (total_k, num_heads).
Describe the solution you'd like
Allow the varlen forward path to accept mask and bias shaped (total_k, H), where H is 1, num_heads_k, or num_heads, and broadcast them across query timesteps during attention scoring.
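A minimal PyTorch sketch of the intended broadcast semantics (reference-only; the function and argument names here are illustrative, not the actual mha_varlen_fwd signature):

```python
import torch

def varlen_attn_reference(q, k, v, cu_seqlens_q, cu_seqlens_k,
                          key_mask=None, key_bias=None, softmax_scale=None):
    """Reference semantics only (illustrative, not the real API).

    q:        (total_q, num_heads, head_dim)
    k, v:     (total_k, num_heads, head_dim)   # MQA/GQA head mapping omitted for brevity
    key_mask: (total_k, 1) or (total_k, num_heads) bool, True = attend
    key_bias: (total_k, 1) or (total_k, num_heads) float, added to scores
    """
    softmax_scale = softmax_scale or q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for b in range(len(cu_seqlens_q) - 1):
        qs, qe = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        ks, ke = cu_seqlens_k[b], cu_seqlens_k[b + 1]
        # scores: (num_heads, seqlen_q, seqlen_k)
        scores = torch.einsum("qhd,khd->hqk", q[qs:qe], k[ks:ke]) * softmax_scale
        if key_bias is not None:
            # (seqlen_k, H) -> (H, 1, seqlen_k): one value per key, broadcast over queries
            scores = scores + key_bias[ks:ke].transpose(0, 1).unsqueeze(1)
        if key_mask is not None:
            scores = scores.masked_fill(~key_mask[ks:ke].transpose(0, 1).unsqueeze(1),
                                        float("-inf"))
        out[qs:qe] = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v[ks:ke])
    return out
```

The key property is that the mask/bias read depends only on the key index (and optionally the head), never on the query row, so one (total_k, H) tensor serves every query in the batch.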
Describe alternatives you've considered
Reshaping into (total_q, ...) and backfilling per-query copies increases memory by O(total_q * num_heads) and breaks streaming workloads.
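For contrast, a rough sketch of the workaround this proposal avoids (shapes and numbers are illustrative; the exact per-query layout expected today may differ):

```python
import torch

total_q, total_k, num_heads = 32, 8192, 16
key_mask = torch.rand(total_k, num_heads) > 0.5            # compact key-side gating

# Workaround: replicate the key-keyed mask once per query row so it fits a
# per-query layout, then rebuild it every decode step as total_q changes.
expanded = key_mask.unsqueeze(0).expand(total_q, -1, -1).contiguous()
print(key_mask.numel(), expanded.numel())                   # 131072 vs 4194304
```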
Implementation details
- Would this require CUDA kernel changes? Likely yes: adjust the mask/bias reads in flash_fwd_kernel when params.has_mask / params.has_bias are set, so they index by key position rather than by query row (see the sketch after this list).
- Does this affect the Python API? Minor: extend argument validation to accept the new layout flag.
- Are there performance implications? Positive for batch decode; avoids redundant materialization.
- Any compatibility concerns with different GPU architectures? None beyond existing Ampere+ requirement.
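A minimal sketch of the addressing change referenced above, written as Python pseudocode over flattened tensors (names are illustrative; the real kernel works on tiles, and the GQA head mapping shown is an assumption):

```python
import torch

def bias_head_index(head, num_heads, num_bias_heads):
    """Map a query head to a column of a (total_k, num_bias_heads) mask/bias.

    num_bias_heads may be 1 (shared), num_heads_k (per KV head, assuming grouped
    query heads map via integer division), or num_heads (per query head).
    """
    if num_bias_heads == 1:
        return 0
    if num_bias_heads == num_heads:
        return head
    return head // (num_heads // num_bias_heads)   # num_bias_heads == num_heads_k

def add_key_side_bias(scores, key_bias, k_start, head, num_heads):
    """scores: (seqlen_q, seqlen_k) tile for one head; key_bias: (total_k, num_bias_heads).

    The bias read depends only on the key column (and head), so the same row of
    key_bias is reused for every query row in the tile -- no per-query storage.
    """
    h = bias_head_index(head, num_heads, key_bias.shape[1])
    return scores + key_bias[k_start:k_start + scores.shape[1], h]   # broadcast over queries

scores = torch.zeros(4, 128)                       # 4 queries, 128 keys in this tile
key_bias = torch.randn(8192, 1)                    # shared across all heads
out = add_key_side_bias(scores, key_bias, k_start=1024, head=5, num_heads=16)
```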
Use case
- Sequence lengths: mixed, typically 1–8 queries per batch with thousands of cached keys (see the metadata sketch after this list).
- Target application: autoregressive decoding with dynamic sparsity.
- Benefit: removes host-side duplication and keeps KV cache compact.
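For concreteness, a sketch of the varlen metadata and the compact key-side tensor for this scenario (numbers and names are illustrative):

```python
import torch

# Illustrative decode batch: three requests with 1, 2, and 8 new query tokens,
# each attending to a long, already-cached key/value stream.
seqlens_q = torch.tensor([1, 2, 8], dtype=torch.int32)
seqlens_k = torch.tensor([4096, 2048, 8192], dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(seqlens_q.cumsum(0), (1, 0)).to(torch.int32)
cu_seqlens_k = torch.nn.functional.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)

total_k, num_heads = int(seqlens_k.sum()), 16
# Key-side gating stays compact: one row per cached key, not one per (query, key) pair.
key_bias = torch.zeros(total_k, num_heads)
```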
Additional context
Mask/bias tensors originate from dependency-aware MaskMod pipelines and are naturally keyed by total_k.
Related work
- Technique: Broadcasted key-side gating for efficient decoding.
- Value: Aligns varlen backend with established sparse attention patterns.