
[FEATURE REQUEST] Support total_k-head broadcastable mask/bias in varlen batch inference #183

@LoserCheems

Description

Is your feature request related to a problem? Please describe.
Batch inference with mha_varlen_fwd currently assumes per-query mask/bias layouts. This blocks use cases where precomputed key-side gating is naturally expressed as broadcastable tensors of shape (total_k, {1 | num_heads_k | num_heads}).

Describe the solution you'd like
Allow the varlen forward path to accept mask and bias shaped (total_k, {1 | num_heads_k | num_heads}) and broadcast them across query timesteps during attention scoring.
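
A rough sketch of the requested call shape (tensor names and the commented-out entry point are illustrative, not the actual API):

```python
import torch

# Packed varlen batch: a handful of query tokens against a large KV cache.
total_q, total_k = 8, 4096
num_heads, num_heads_k, head_dim = 32, 8, 128

q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(total_k, num_heads_k, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(total_k, num_heads_k, head_dim, device="cuda", dtype=torch.bfloat16)

# Key-side gating, computed once per cached key token and broadcast over queries:
attn_mask = torch.ones(total_k, 1, dtype=torch.bool, device="cuda")                  # (total_k, 1)
attn_bias = torch.zeros(total_k, num_heads_k, device="cuda", dtype=torch.bfloat16)   # (total_k, num_heads_k)

# Desired (hypothetical) call, with cu_seqlens_q / cu_seqlens_k as usual:
# out = mha_varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k,
#                      mask=attn_mask, bias=attn_bias, ...)
```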

Describe alternatives you've considered
Reshaping into (total_q, ...) and backfilling per-query copies increases memory by O(total_q * num_heads) and breaks streaming workloads.
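
For contrast, a minimal sketch of that workaround (shapes illustrative), showing how the per-query copies blow up memory relative to the single key-side tensor:

```python
import torch

total_q, total_k, num_heads = 8, 4096, 32

# Key-side gate computed once: one value per cached key token.
key_gate = torch.rand(total_k, 1, device="cuda") > 0.1                  # (total_k, 1)

# Workaround: backfill a copy for every packed query token and head,
# i.e. materialize (total_q, num_heads, total_k) before calling the kernel.
expanded = key_gate.squeeze(-1).expand(total_q, num_heads, total_k).contiguous()

print(key_gate.numel(), expanded.numel())   # 4096 vs. 1,048,576 elements
```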

Implementation details

  • Would this require CUDA kernel changes? Likely yes: adjust mask/bias reads in flash_fwd_kernel when params.has_mask/has_bias.
  • Does this affect the Python API? Minor: extend argument validation to accept the new layout flag (see the sketch after this list).
  • Are there performance implications? Positive for batch decode; avoids redundant materialization.
  • Any compatibility concerns with different GPU architectures? None beyond existing Ampere+ requirement.
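
A rough sketch, under assumed names (is_key_side_layout and reference_scores are illustrative, not existing functions), of what the Python-side change could look like: recognize the key-length layout during validation, with the broadcast semantics the kernel would implement shown as an eager reference (varlen segmentation and GQA head mapping omitted for brevity):

```python
import torch

def is_key_side_layout(t, total_k, num_heads, num_heads_k):
    """True if t follows the proposed (total_k, 1 | num_heads_k | num_heads) layout."""
    return (t is not None and t.dim() == 2 and t.shape[0] == total_k
            and t.shape[1] in (1, num_heads_k, num_heads))

def reference_scores(q, k, mask=None, bias=None):
    """Eager reference for the broadcast the kernel would perform.
    q: (total_q, H, D), k: (total_k, H, D), mask/bias: (total_k, H) or (total_k, 1)."""
    scores = torch.einsum("qhd,khd->hqk", q, k)                 # (H, total_q, total_k)
    if bias is not None:
        scores = scores + bias.transpose(0, 1)[:, None, :]      # broadcast over query steps
    if mask is not None:
        keep = mask.transpose(0, 1)[:, None, :].bool()
        scores = scores.masked_fill(~keep, float("-inf"))
    return scores
```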

Use case

  • Sequence lengths: mixed, typically 1–8 queries per batch with thousands of cached keys.
  • Target application: autoregressive decoding with dynamic sparsity.
  • Benefit: removes host-side duplication and keeps KV cache compact.

Additional context
Mask/bias tensors originate from dependency-aware MaskMod pipelines and are naturally keyed by total_k.
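
Illustrative only (the per-key importance score below stands in for whatever the gating stage produces): such a pipeline scores each cached key once, so its natural output shape is already (total_k, 1), with no per-query dimension to fill.

```python
import torch

total_k = 4096
# Hypothetical per-key importance score from a dependency-aware gating stage.
key_importance = torch.rand(total_k, device="cuda")

attn_mask = (key_importance > 0.05).unsqueeze(-1)                     # (total_k, 1)
attn_bias = torch.log(key_importance.clamp_min(1e-6)).unsqueeze(-1)   # (total_k, 1)
```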

Related work

  • Technique: Broadcasted key-side gating for efficient decoding.
  • Value: Aligns varlen backend with established sparse attention patterns.
