
Conversation

@LoserCheems
Collaborator

Summary

  • Implements pre-scaling of operands to remove per-iteration scaling in streaming attention:
    • Forward: scale Q once before the KV streaming loop.
    • Backward: scale K once before the Q streaming loop.
  • Initializes attention score accumulators from bias, removing the per-iteration “+ bias”.
  • Fixes dQ double-scaling in backward (both main kernel and seq-k preprocess).
  • Simplifies masking utilities to only perform masking (no scale/bias), reducing specialization and duplicated math.

Motivation: reduce math inside hot loops for better performance, fix backward dQ correctness, and simplify masking semantics. A minimal sketch of the pre-scaling idea follows.
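The sketch below is a plain C++ illustration of the hoisting, not the actual CUTE-based kernel code; the names (q, k, scores, scale_softmax) are illustrative stand-ins for the real tile tensors and params fields.

    // Plain C++ sketch of the hoisting, assuming row-major Q [seqlen_q x d] and
    // K [seqlen_k x d]; names are illustrative, not the repo's CUTE tensors.
    #include <cstddef>
    #include <vector>

    void forward_scores_prescaled(std::vector<float>& q, const std::vector<float>& k,
                                  std::vector<float>& scores,
                                  std::size_t seqlen_q, std::size_t seqlen_k, std::size_t d,
                                  float scale_softmax) {
        // Before: every score was multiplied by scale_softmax inside the loop.
        // After: fold the factor into Q once, so the streaming KV loop is a plain matmul.
        for (float& x : q) x *= scale_softmax;

        for (std::size_t i = 0; i < seqlen_q; ++i) {
            for (std::size_t j = 0; j < seqlen_k; ++j) {      // streaming over KV
                float acc = 0.0f;
                for (std::size_t c = 0; c < d; ++c) acc += q[i * d + c] * k[j * d + c];
                scores[i * seqlen_k + j] = acc;               // no per-iteration scaling
            }
        }
    }

The backward pass applies the same idea with the roles swapped: K is scaled once before the Q streaming loop, so the score and gradient products pick up the factor for free.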

Design

  • Forward kernels:

    • Pre-scale Q by params.scale_softmax before the KV loop in flash_fwd_kernel.h.
    • Initialize acc_s from bias when present to avoid +bias in-loop (see the sketch after this list).
    • Add cp.async waits and barriers before touching tiles to prevent races when pre-scaling Q.
  • Backward kernels:

    • Pre-scale K by params.scale_softmax once before streaming Q in flash_bwd_kernel.h.
    • Remove the extra acc_dq *= params.scale_softmax at dQ write-out (main kernel) and in seq-k path convert_dQ so dQ is not double-scaled.
    • Keep the single scaling for dK as intended.
  • Masking:

    • Collapse mask helpers to only apply masking; no longer mix in bias or scaling inside mask.h.
    • Call sites pass only mask tensors; bias is handled via acc_s initialization, and scaling is handled by pre-scaling Q/K.
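A minimal sketch of the accumulator setup, assuming a flat per-tile bias buffer of the same shape as acc_s; the helper name init_acc_s and the bias_tile parameter are illustrative, not the repo's identifiers.

    #include <cstddef>

    // Hypothetical helper: initialize the score accumulator from bias (when present)
    // instead of clearing it and adding bias on every streaming iteration.
    template <typename T, std::size_t N>
    void init_acc_s(T (&acc_s)[N], const T* bias_tile /* may be nullptr */) {
        if (bias_tile != nullptr) {
            for (std::size_t i = 0; i < N; ++i) acc_s[i] = bias_tile[i];  // start at bias
        } else {
            for (std::size_t i = 0; i < N; ++i) acc_s[i] = T(0);          // plain clear
        }
        // The Q*K^T matmul then accumulates into acc_s, so no "+ bias" is needed in-loop.
    }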

Alternatives considered: leaving scaling inside loops (more math per iteration), or keeping bias/scaling in the mask helper (more template paths, duplicated responsibilities). The chosen approach centralizes responsibilities and reduces inner-loop work.

Changes

  • Internal kernels:

    • Forward: Q pre-scaling, acc_s bias initialization, and syncs to guard smem vs. regs paths in flash_fwd_kernel.h.
    • Backward: K pre-scaling; remove dQ post-scale in both the main kernel and seq-k preprocess in flash_bwd_kernel.h and flash_bwd_preprocess_kernel.h.
  • Mask API (internal-only):

    • Simplify the apply_mask template signature to accept only the mask (no bias/scale) and perform masking only in mask.h (a simplified stand-in is sketched below).
    • Update all call sites accordingly.

Public Python API remains unchanged.
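As a rough illustration of the new division of labor (the real apply_mask is a CUTE-based template in mask.h; the function below is a simplified stand-in, not the actual signature):

    #include <cstddef>
    #include <limits>

    // Simplified stand-in for the masking-only helper: it only writes -inf where the
    // mask is inactive. Scaling already lives in the pre-scaled Q/K, and bias lives
    // in the accumulator initialization, so neither appears here.
    template <typename T>
    void apply_mask_only(T* scores, const bool* mask, std::size_t n) {
        for (std::size_t i = 0; i < n; ++i) {
            if (!mask[i]) scores[i] = -std::numeric_limits<T>::infinity();
        }
    }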

Implementation Notes

  • Correctness: fixes dQ being scaled twice when K is pre-scaled in the backward pass (see the sketch after this list).
  • Synchronization: Adds cp_async_wait and __syncthreads() before pre-scaling Q/K when needed to avoid races with async smem preloads.
  • Bias handling: Uses bias to initialize acc_s once per tile; removes repeated +bias in the streaming loop.
  • dK scaling: Kept as a single multiply outside inner loops; not double-applied.
  • Edge cases: Causal masking still honored via simplified mask helper; behavior is unchanged aside from moving scale/bias out of the helper.
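To make the dQ fix concrete, here is a minimal sketch with illustrative names (not the actual kernel code): once K is pre-scaled, dQ = dS · K_scaled already carries the softmax factor, so the old post-multiply at write-out would apply it twice.

    #include <cstddef>

    // One query row's dQ accumulation, assuming k_scaled = scale_softmax * K was
    // computed once before streaming Q tiles (names are illustrative).
    void accumulate_dq_row(float* dq_row, const float* dS_row, const float* k_scaled,
                           std::size_t seqlen_k, std::size_t d) {
        for (std::size_t c = 0; c < d; ++c) {
            float acc = 0.0f;
            for (std::size_t j = 0; j < seqlen_k; ++j) acc += dS_row[j] * k_scaled[j * d + c];
            dq_row[c] = acc;                 // already correctly scaled via k_scaled
            // dq_row[c] *= scale_softmax;   // removed: this was the double-scaling
        }
    }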

Tests

  • Functional equivalence:

    • Forward and backward equivalence verified across typical shapes; the dQ mismatch caused by duplicate scaling is resolved by the removals in the main backward kernel and the seq-k convert_dQ path.
  • Benchmarks (provided):

    • Forward (before vs. after):

      📊 Benchmark Results (averaged over 3 runs):
      🔧 Configuration                                                ⚡ SDPA       🚀 CUDA       🌟 Triton     🌟 Flex            📈 Speedup
      🔄------------------------------------------------------------------------------------------------------------------------------------------------------🔄
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q256 K256 D64 W1024 C               0.45         0.23         N/A          N/A                CUDA:1.92x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q512 K512 D64 W1024 C               0.43         0.22         N/A          N/A                CUDA:1.99x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q1024 K1024 D64 W1024 C             0.53         0.13         N/A          N/A                CUDA:4.00x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q2048 K2048 D64 W1024 C             0.62         0.14         N/A          N/A                CUDA:4.30x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q4096 K4096 D64 W1024 C             2.58         0.20         N/A          N/A                CUDA:12.87x    
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q8192 K8192 D64 W1024 C             8.39         0.43         N/A          N/A                CUDA:19.30x    
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q16384 K16384 D64 W1024 C           32.19        1.66         N/A          N/A                CUDA:19.35x   
      🔄------------------------------------------------------------------------------------------------------------------------------------------------------🔄
      
      📊 Benchmark Results (averaged over 3 runs):
      🔧 Configuration                                                ⚡ SDPA       🚀 CUDA       🌟 Triton     🌟 Flex            📈 Speedup
      🔄------------------------------------------------------------------------------------------------------------------------------------------------------🔄
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q256 K256 D64 W1024 C               0.79         0.26         N/A          N/A                CUDA:3.07x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q512 K512 D64 W1024 C               0.44         0.12         N/A          N/A                CUDA:3.55x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q1024 K1024 D64 W1024 C             0.79         0.15         N/A          N/A                CUDA:5.40x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q2048 K2048 D64 W1024 C             0.60         0.20         N/A          N/A                CUDA:3.08x     
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q4096 K4096 D64 W1024 C             2.54         0.24         N/A          N/A                CUDA:10.43x    
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q8192 K8192 D64 W1024 C             8.56         0.45         N/A          N/A                CUDA:19.03x    
       ✅  ✅  ⚠️  ⚠️   B1 Hq2 Hkv1 Q16384 K16384 D64 W1024 C           32.20        1.43         N/A          N/A                CUDA:22.59x    
      🔄------------------------------------------------------------------------------------------------------------------------------------------------------🔄
      
    • Backward (before vs. after):

      📊 Backward Pass Benchmark Results (averaged over 3 runs):
      🔧 Configuration                                                ⚡ SDPA-BWD     🚀 CUDA-BWD     🌟 Triton-BWD   ✨ Flex-BWD        📈 Speedup
      🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄 
      📊 B1 Hq2 Hkv1 Q256 K256 D128 W1024 C                           ⚡ 0.92ms       🚀 0.64ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.4x     
      📊 B1 Hq2 Hkv1 Q512 K512 D128 W1024 C                           ⚡ 1.13ms       🚀 0.76ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.5x     
      📊 B1 Hq2 Hkv1 Q1024 K1024 D128 W1024 C                         ⚡ 1.24ms       🚀 0.90ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.4x     
      📊 B1 Hq2 Hkv1 Q2048 K2048 D128 W1024 C                         ⚡ 1.40ms       🚀 1.07ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.3x     
      📊 B1 Hq2 Hkv1 Q4096 K4096 D128 W1024 C                         ⚡ 3.51ms       🚀 1.40ms       🌟 N/A          ✨ N/A             📈 CUDA: 2.5x     
      📊 B1 Hq2 Hkv1 Q8192 K8192 D128 W1024 C                         ⚡ 11.16ms      🚀 2.63ms       🌟 N/A          ✨ N/A             📈 CUDA: 4.2x     
      📊 B1 Hq2 Hkv1 Q16384 K16384 D128 W1024 C                       ⚡ 42.26ms      🚀 15.13ms      🌟 N/A          ✨ N/A             📈 CUDA: 2.8x     
      🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄 
      
      📊 Backward Pass Benchmark Results (averaged over 3 runs):
      🔧 Configuration                                                ⚡ SDPA-BWD     🚀 CUDA-BWD     🌟 Triton-BWD   ✨ Flex-BWD        📈 Speedup
      🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄 
      📊 B1 Hq2 Hkv1 Q256 K256 D128 W1024 C                           ⚡ 0.77ms       🚀 0.60ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.3x     
      📊 B1 Hq2 Hkv1 Q512 K512 D128 W1024 C                           ⚡ 0.79ms       🚀 0.65ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.2x     
      📊 B1 Hq2 Hkv1 Q1024 K1024 D128 W1024 C                         ⚡ 0.81ms       🚀 0.69ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.2x     
      📊 B1 Hq2 Hkv1 Q2048 K2048 D128 W1024 C                         ⚡ 1.14ms       🚀 1.11ms       🌟 N/A          ✨ N/A             📈 CUDA: 1.0x     
      📊 B1 Hq2 Hkv1 Q4096 K4096 D128 W1024 C                         ⚡ 3.26ms       🚀 1.42ms       🌟 N/A          ✨ N/A             📈 CUDA: 2.3x     
      📊 B1 Hq2 Hkv1 Q8192 K8192 D128 W1024 C                         ⚡ 10.90ms      🚀 2.39ms       🌟 N/A          ✨ N/A             📈 CUDA: 4.6x     
      📊 B1 Hq2 Hkv1 Q16384 K16384 D128 W1024 C                       ⚡ 42.36ms      🚀 15.02ms      🌟 N/A          ✨ N/A             📈 CUDA: 2.8x     
      🔄----------------------------------------------------------------------------------------------------------------------------------------------------------------🔄 
      

Notes:

  • Forward perf is broadly comparable, with clear wins on some shapes (e.g., Q512, Q16384) and small regressions on a few small shapes due to added synchronization and bias-init movement. Backward sees consistent small improvements at most sizes.
  • Gradient correctness: dQ scaling now matches the pre-scaled path; prior double-scaling is resolved.

Docs

Checklist

  • Removes bias and scaling handling from the mask helper to reduce specialization paths and rely solely on masking behavior.
  • Rationalizes accumulator setup so bias kernels reuse shared-memory bias instead of clearing registers, trimming sync overhead.
  • Simplifies mask application templates to drop unused bias handling, tightening the specialization footprint.
  • Pre-scales query tiles before streaming to cut redundant softmax multiplications.
  • Initializes accumulators from shared bias when active so mask paths can skip extra clears.
  • Simplifies mask application by dropping per-iteration bias scaling logic.
  • Provides an inverse scaling factor so kernels can reuse precomputed softmax adjustments instead of recomputing them.
  • Ensures the reciprocal scale is always populated so downstream kernels can undo the softmax amplification without branching.
  • Drops bias and scale handling from the masking helper so upstream code owns those adjustments, preventing duplicated math.
  • Waits for outstanding async loads and syncs threads so Q scaling never races ahead of shared-memory tiles.
  • Prevents applying the softmax factor twice in the backward preprocessing so downstream gradients stay correctly scaled.
  • Pre-scales the keys right after synchronization so later matmul steps reuse the scaled values and hide latency; unifies the mask and bias hydration before streaming to keep accumulators coherent and drops the now-redundant gradient scaling.
  • Removes the unused reverse scaling parameter from the forward configuration to avoid stale values when softcap toggles.
Contributor

Copilot AI left a comment

Pull Request Overview

This PR refactors the flash attention implementation to move scaling operations earlier in the computation pipeline and removes bias handling from the masking logic. The key changes include:

  • Moving Q and K scaling from the masking phase to before the streaming loops
  • Simplifying apply_mask by removing bias and scale_softmax parameters
  • Initializing attention score accumulators with bias values when present, instead of applying bias during masking
  • Commenting out the final dQ scaling in backward passes

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File — Description

csrc/flash_dmattn/src/mask.h — Removed bias and scale_softmax parameters from apply_mask functions; simplified masking logic by removing redundant conditional branches.
csrc/flash_dmattn/src/flash_fwd_kernel.h — Added Q scaling before streaming loops, moved bias application to accumulator initialization, updated apply_mask calls.
csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h — Commented out the final dQ scaling operation.
csrc/flash_dmattn/src/flash_bwd_kernel.h — Added K scaling before streaming loops, moved bias application to accumulator initialization, commented out the final dQ scaling, updated apply_mask calls.


Comment on lines +1151 to +1152
    for (int i = 0; i < size(tSrQ); ++i) {
        tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);

Copilot AI Nov 4, 2025


The loop iterates over size(tSrQ) but modifies tSsQ(i). This is inconsistent and likely incorrect. Based on the non-splitkv version (lines 429-439), there should be a conditional check for Kernel_traits::Is_Q_in_regs to determine whether to use tSrQ or tSsQ. Either add the conditional check or ensure both the loop bound and the tensor being modified are consistent.

Suggested change (replacing the two lines above):

    if (Kernel_traits::Is_Q_in_regs) {
        for (int i = 0; i < size(tSrQ); ++i) {
            tSrQ(i) = static_cast<Element>(tSrQ(i) * params.scale_softmax);
        }
    } else {
        for (int i = 0; i < size(tSsQ); ++i) {
            tSsQ(i) = static_cast<Element>(tSsQ(i) * params.scale_softmax);
        }
    }

Comment on lines +675 to 676



Copilot AI Nov 4, 2025


Empty line with trailing whitespace. Remove the trailing whitespace for consistency.

Suggested change

@LoserCheems merged commit c49dead into main on Nov 4, 2025
7 checks passed
@LoserCheems deleted the fix-202 branch on November 4, 2025 at 08:27


Development

Successfully merging this pull request may close these issues.

[FEATURE REQUEST] fuse softmax scale into operands and bias-first accumulation

9 participants