[FEATURE SUPPORT] Move scaling out of streaming loops, bias-initialized acc_s, and fix dQ double-scaling #203
Conversation
Removes bias and scaling handling from the mask helper to reduce specialization paths and rely solely on masking behavior.
Rationalizes accumulator setup so bias kernels reuse shared-memory bias instead of clearing registers, trimming sync overhead. Simplifies mask application templates to drop unused bias handling, tightening specialization footprint.
Pre-scales query tiles before streaming to cut redundant softmax multiplications. Initializes accumulators from shared bias when active so mask paths can skip extra clears. Simplifies mask application by dropping per-iteration bias scaling logic.
Provides an inverse scaling factor so kernels can reuse precomputed softmax adjustments instead of recomputing them.
Ensures the reciprocal scale is always populated so downstream kernels can undo the softmax amplification without branching.
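A minimal host-side sketch of that pairing, assuming illustrative names: `params.scale_softmax` appears in the diff below, but the struct and the reciprocal field here are placeholders rather than the repository's actual parameter layout.

```cpp
#include <cmath>

// Hypothetical parameter holder: scale_softmax is known from the PR diff;
// the reciprocal field name is an assumption for illustration only.
struct SoftmaxScales {
    float scale_softmax;      // e.g. 1/sqrt(head_dim), folded into Q before streaming
    float inv_scale_softmax;  // reciprocal, so kernels can undo the amplification
};

inline SoftmaxScales make_softmax_scales(int head_dim) {
    SoftmaxScales s;
    s.scale_softmax = 1.0f / std::sqrt(static_cast<float>(head_dim));
    s.inv_scale_softmax = 1.0f / s.scale_softmax;  // always populated, no branching downstream
    return s;
}
```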
Drops bias and scale handling from the masking helper so upstream code owns those adjustments, preventing duplicated math.
Waits for outstanding async loads and syncs threads so Q scaling never races ahead of shared-memory tiles.
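A minimal runnable CUDA sketch of that ordering constraint, under simplifying assumptions: the real kernels stage tiles with `cp.async` and therefore also wait via `cp_async_wait`, whereas this toy kernel uses plain shared-memory stores. The point it preserves is that the thread which scales an element is generally not the thread that loaded it, so the barrier must come before the in-place pre-scale.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel: stage a tile into shared memory, then fold the softmax scale in place.
__global__ void prescale_tile(float* q, int n, float scale_softmax) {
    extern __shared__ float s_q[];
    int tid = threadIdx.x;

    if (tid < n) s_q[tid] = q[tid];      // stand-in for the async smem preload

    __syncthreads();                     // every load must land before anyone scales

    // Deliberately scale a different element than the one this thread loaded,
    // mirroring the fact that copy and compute partitionings differ in the kernels.
    if (tid < n) s_q[n - 1 - tid] *= scale_softmax;

    __syncthreads();
    if (tid < n) q[tid] = s_q[tid];
}

int main() {
    const int n = 64;
    float h_q[n];
    for (int i = 0; i < n; ++i) h_q[i] = 1.0f;

    float* d_q = nullptr;
    cudaMalloc(&d_q, n * sizeof(float));
    cudaMemcpy(d_q, h_q, n * sizeof(float), cudaMemcpyHostToDevice);
    prescale_tile<<<1, n, n * sizeof(float)>>>(d_q, n, 0.125f);
    cudaMemcpy(h_q, d_q, n * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_q);

    std::printf("q[0] after pre-scale: %f\n", h_q[0]);  // expect 0.125
    return 0;
}
```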
Prevents applying the softmax factor twice in the backward preprocessing so downstream gradients stay correctly scaled.
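A standalone numeric sketch of that double-scaling hazard, using scalar stand-ins rather than the kernel's tensors (an assumption for illustration): once K is pre-scaled by `scale_softmax` in the backward pass, the accumulated dQ already carries the factor, so the old write-out multiply would apply it twice.

```cpp
#include <cstdio>

int main() {
    const float scale_softmax = 0.125f;  // e.g. 1/sqrt(64)
    const float dS = 2.0f;               // scalar stand-in for one dSoftmax element
    const float K  = 3.0f;               // scalar stand-in for one K element

    float dq_ref    = scale_softmax * dS * K;    // the factor applied exactly once
    float dq_new    = dS * (scale_softmax * K);  // K pre-scaled, no write-out scaling
    float dq_double = dq_new * scale_softmax;    // old write-out scaling on top: wrong

    std::printf("ref=%.6f new=%.6f double-scaled=%.6f\n", dq_ref, dq_new, dq_double);
    return 0;
}
```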
Pre-scales the keys right after synchronization so later matmul steps reuse the scaled values and hide latency. Unifies the mask and bias hydration before streaming to keep accumulators coherent and drops the now redundant gradient scaling.
Removes the unused reverse scaling parameter from the forward configuration to avoid stale values when softcap toggles.
Pull Request Overview
This PR refactors the flash attention implementation to move scaling operations earlier in the computation pipeline and removes bias handling from the masking logic. The key changes include:
- Moving Q and K scaling from the masking phase to before the streaming loops
- Simplifying `apply_mask` by removing bias and `scale_softmax` parameters
- Initializing attention score accumulators with bias values when present, instead of applying bias during masking
- Commenting out the final dQ scaling in backward passes
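A standalone sketch of why the first point is safe, with plain vectors standing in for the kernel's tiles (an assumption for illustration): since (s * Q) * K^T = s * (Q * K^T), the softmax factor can be folded into Q once before streaming instead of multiplying every score tile.

```cpp
#include <cstdio>
#include <vector>

int main() {
    const float scale_softmax = 0.125f;
    std::vector<float> q = {1.0f, 2.0f, 3.0f};
    std::vector<float> k = {0.5f, 0.25f, 0.125f};

    // Old placement: raw dot product per key block, then scale the score in-loop.
    float score_late = 0.0f;
    for (size_t i = 0; i < q.size(); ++i) score_late += q[i] * k[i];
    score_late *= scale_softmax;

    // New placement: fold the factor into Q once, before any streaming.
    std::vector<float> q_scaled = q;
    for (float& x : q_scaled) x *= scale_softmax;
    float score_early = 0.0f;
    for (size_t i = 0; i < q.size(); ++i) score_early += q_scaled[i] * k[i];

    std::printf("in-loop=%.6f pre-scaled=%.6f\n", score_late, score_early);
    return 0;
}
```

The backward kernel applies the same identity by pre-scaling K before its streaming loop.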
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/flash_dmattn/src/mask.h | Removed bias and scale_softmax parameters from apply_mask functions, simplified masking logic by removing redundant conditional branches |
| csrc/flash_dmattn/src/flash_fwd_kernel.h | Added Q scaling before streaming loops, moved bias application to accumulator initialization, updated apply_mask calls |
| csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h | Commented out final dQ scaling operation |
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Added K scaling before streaming loops, moved bias application to accumulator initialization, commented out final dQ scaling, updated apply_mask calls |
```cpp
for (int i = 0; i < size(tSrQ); ++i) {
    tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);
```
Copilot AI commented on Nov 4, 2025
The loop iterates over size(tSrQ) but modifies tSsQ(i). This is inconsistent and likely incorrect. Based on the non-splitkv version (lines 429-439), there should be a conditional check for Kernel_traits::Is_Q_in_regs to determine whether to use tSrQ or tSsQ. Either add the conditional check or ensure both the loop bound and the tensor being modified are consistent.
Suggested change (replacing the two lines quoted above):

```cpp
if (Kernel_traits::Is_Q_in_regs) {
    for (int i = 0; i < size(tSrQ); ++i) {
        tSrQ(i) = static_cast<Element>(tSrQ(i) * params.scale_softmax);
    }
} else {
    for (int i = 0; i < size(tSsQ); ++i) {
        tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);
    }
}
```
Copilot AI commented on Nov 4, 2025
Empty line with trailing whitespace. Remove the trailing whitespace for consistency.
Summary
Motivation: reduce math inside hot loops for better perf, fix backward dQ correctness, and simplify masking semantics.
Design
Forward kernels:
- Pre-scale Q before the streaming loops and initialize `acc_s` from bias when present to avoid +bias in-loop (sketched below).

Backward kernels:
- Drop the final `acc_dq *= params.scale_softmax` at dQ write-out (main kernel) and in the seq-k path `convert_dQ` so dQ is not double-scaled.

Masking:
- `apply_mask` performs masking only; bias and scaling are owned by the calling kernels.
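A standalone sketch of the bias-initialized accumulator referenced above, with plain arrays standing in for the kernel's register and shared-memory tiles (an assumption for illustration): starting `acc_s` from the bias tile makes S = Q * K^T + bias fall out of the accumulation without a separate in-loop add.

```cpp
#include <array>
#include <cstdio>

int main() {
    constexpr int kTile = 4;
    std::array<float, kTile> bias = {0.1f, -0.2f, 0.0f, 0.3f};
    std::array<float, kTile> qk   = {1.0f,  2.0f, 3.0f, 4.0f};  // stand-in for a Q*K^T tile

    // Before: clear acc_s, accumulate, then add bias as a separate pass during masking.
    std::array<float, kTile> acc_before{};
    for (int i = 0; i < kTile; ++i) acc_before[i] += qk[i];
    for (int i = 0; i < kTile; ++i) acc_before[i] += bias[i];

    // After: initialize acc_s from the bias tile and accumulate into it directly.
    std::array<float, kTile> acc_after = bias;
    for (int i = 0; i < kTile; ++i) acc_after[i] += qk[i];

    for (int i = 0; i < kTile; ++i)
        std::printf("%d: before=%.2f after=%.2f\n", i, acc_before[i], acc_after[i]);
    return 0;
}
```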
Alternatives considered: leaving scaling inside loops (more math per iteration), or keeping bias/scaling in the mask helper (more template paths, duplicated responsibilities). The chosen approach centralizes responsibilities and reduces inner-loop work.
Changes
Internal kernels:
- Updated `flash_fwd_kernel.h`, `flash_bwd_kernel.h`, and `flash_bwd_preprocess_kernel.h`.

Mask API (internal-only):
- Changed the `apply_mask` template signature to accept only the mask (no bias/scale); masking is performed only in `mask.h` (see the sketch below).

Public Python API remains unchanged.
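A simplified host-side sketch of the narrowed helper; the real `apply_mask` in `mask.h` is a CUTe/CUDA template with different parameters, so the signature below only mirrors the described change (mask in, bias and scale out) and is not the repository's code.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Before (conceptually): apply_mask(scores, mask, bias, scale_softmax) also added bias
// and applied the softmax scale. After: masking is the helper's only responsibility.
void apply_mask(std::vector<float>& scores, const std::vector<bool>& mask) {
    for (size_t i = 0; i < scores.size(); ++i) {
        if (!mask[i]) scores[i] = -INFINITY;  // masked positions never survive softmax
    }
}

int main() {
    std::vector<float> scores = {0.5f, 1.5f, -0.3f};
    std::vector<bool> mask = {true, false, true};
    apply_mask(scores, mask);
    std::printf("%.1f %.1f %.1f\n", scores[0], scores[1], scores[2]);
    return 0;
}
```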
Implementation Notes
- Uses `cp_async_wait` and `__syncthreads()` before pre-scaling Q/K when needed to avoid races with async smem preloads.
- Initializes `acc_s` once per tile from the bias; removes the repeated +bias in the streaming loop.

Tests
Functional equivalence:
- Covers the dQ path with the extra scaling removed from `convert_dQ`.

Benchmarks (provided):
Forward (before vs. after):
Backward (before vs. after):
Notes:
Docs
- `flash_fwd_kernel.h`
- `flash_bwd_kernel.h`
- `flash_bwd_preprocess_kernel.h`
- `mask.h`

Checklist