diff --git a/benchmark/test_attention_perf.py b/benchmark/test_attention_perf.py index f13c9901e..bb25c60fa 100644 --- a/benchmark/test_attention_perf.py +++ b/benchmark/test_attention_perf.py @@ -27,16 +27,18 @@ def set_more_shapes(self): flag_gems.device == "musa" or vendor_name == "hygon", reason="RuntimeError" ) @pytest.mark.scaled_dot_product_attention -@pytest.mark.parametrize("dropout_p", [0.0, 0.25]) +@pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("is_causal", [True, False]) def test_perf_scaled_dot_product_attention(dropout_p, is_causal): def scaled_dot_product_attention_kwargs(shape, dtype, device): query = torch.randn(shape, device=device, dtype=dtype) key = torch.randn(shape, device=device, dtype=dtype) value = torch.randn(shape, device=device, dtype=dtype) - yield query, key, value, dropout_p, is_causal + yield query, key, value, None, dropout_p, is_causal - def sdpa_flash(query, key, value, dropout_p=dropout_p, is_causal=is_causal): + def sdpa_flash( + query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=is_causal + ): from torch.nn.attention import SDPBackend, sdpa_kernel with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): @@ -44,7 +46,7 @@ def sdpa_flash(query, key, value, dropout_p=dropout_p, is_causal=is_causal): query, key, value, - attn_mask=None, + attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, ) @@ -59,6 +61,7 @@ def sdpa_flash(query, key, value, dropout_p=dropout_p, is_causal=is_causal): torch.bfloat16, ], ) + bench.set_gems(flag_gems.scaled_dot_product_attention) bench.run() diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 4abd2152e..6b7dd8df5 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -9,9 +9,11 @@ from flag_gems.ops.argmax import argmax from flag_gems.ops.argmin import argmin from flag_gems.ops.attention import ( + ScaleDotProductAttention, flash_attention_forward, flash_attn_varlen_func, scaled_dot_product_attention, + scaled_dot_product_attention_backward, ) from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward from flag_gems.ops.bitwise_and import ( @@ -391,6 +393,8 @@ "rsqrt", "rsqrt_", "scaled_dot_product_attention", + "scaled_dot_product_attention_backward", + "ScaleDotProductAttention", "scatter", "scatter_", "select_scatter", diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py index bed82835b..032fe1a1f 100644 --- a/src/flag_gems/ops/attention.py +++ b/src/flag_gems/ops/attention.py @@ -1,4 +1,5 @@ import logging +import math from functools import partial import torch @@ -320,10 +321,430 @@ def _attn_fwd( m_i += tl.math.log2(l_i) acc = acc / l_i[:, None] m_ptrs = M + off_hz * Q_CTX + offs_m - tl.store(m_ptrs, m_i) + tl.store(m_ptrs, m_i, mask=q_load_mask) tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None]) +@triton.jit +def _attn_bwd_preprocess( + O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + mask = off_m < Q_CTX + + off_hz = tl.program_id(1) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load( + O + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :], + mask=mask[:, None], + other=0.0, + ) + do = tl.load( + DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :], + mask=mask[:, None], + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask) + + +# The main 
inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv( + dk, + dv, # + Q, + key, + value, + sm_scale, # + DO, # + M, + D, # + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + Q_CTX, + KV_CTX, + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, # + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # + MASK: tl.constexpr, +): + # BLOCK_M1: 32 + # BLOCK_N1: 128 + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_n_mask = offs_n < KV_CTX # (BLOCK_N1, ) + + offs_k = tl.arange(0, BLOCK_DMODEL) # (BLOCK_DMODEL, ) + + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M1) # (BLOCK_M1, ) + offs_m_mask = offs_m < Q_CTX # (BLOCK_M1, ) + + qT_ptrs = ( + Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + ) # (BLOCK_DMODEL, BLOCK_M1) + do_ptrs = ( + DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + ) # (BLOCK_M1, BLOCK_DMODEL) + + qT = tl.load( + qT_ptrs, mask=offs_m_mask[None, :], other=0.0 + ) # (BLOCK_DMODEL, BLOCK_M1) + + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf")) # (BLOCK_M1, ) + + # key: (BLOCK_N1, BLOCK_DMODEL) + qkT = tl.dot(key, qT) # (BLOCK_N1, BLOCK_M1) + m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1)) # (BLOCK_N1, BLOCK_M1) + m = tl.where(offs_n_mask[:, None], m, float("inf")) # (BLOCK_N1, BLOCK_M1) + pT = tl.math.exp2(qkT - m) + # pT = tl.math.exp2(qkT - m[None, :]) + + mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[ + :, None + ] # (BLOCK_N1, BLOCK_M1) + # Autoregressive masking. + if MASK: + mask &= offs_m[None, :] >= offs_n[:, None] + pT = tl.where(mask, pT, 0.0) # (BLOCK_N1, BLOCK_M1) + + do = tl.load(do_ptrs) + # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0) # (BLOCK_M1, BLOCK_DMODEL) + + # Compute dV. + dv += tl.dot(pT, do.to(tl.float32)) # (BLOCK_N1, BLOCK_DMODEL) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0) # (BLOCK_M1, ) + + # Compute dP and dS. + dpT = tl.dot(value, tl.trans(do)).to( + tl.float32 + ) # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1) + dsT = pT * (dpT - Di[None, :]) # (BLOCK_N1, BLOCK_M1) + dsT = dsT.to(qT.dtype) + qT = tl.where(offs_m_mask[None, :], qT, 0.0) # (BLOCK_DMODEL, BLOCK_M1) + dsT = tl.where( + offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0 + ) # (BLOCK_N1, BLOCK_M1) + dk += tl.dot( + dsT, tl.trans(qT) + ) # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL) + # Increment pointers. + curr_m += step_m + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq( + dq, + query, + K, + V, # + do, + m, + D, + # shared by Q/K/V/DO. + stride_tok, + stride_d, # + H, + Q_CTX, # + KV_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, + start_n, + num_steps, # + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_m_mask = offs_m < Q_CTX + + offs_k = tl.arange(0, BLOCK_DMODEL) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
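+    # (Each BLOCK_M2 row block of dQ is accumulated by sweeping the key/value
+    #  sequence in BLOCK_N2-sized steps in the loop below.)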
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N2) + offs_n_mask = offs_n < KV_CTX + + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + + kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0) + vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0) + qk = tl.dot(query, kT) + p = tl.math.exp2(qk - m) + mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :] + # Autoregressive masking. + if MASK: + # mask = (offs_m[:, None] >= offs_n[None, :]) + # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :] + mask &= offs_m[:, None] >= offs_n[None, :] + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = tl.where(mask, ds, 0.0).to(kT.dtype) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + return dq + + +@triton.jit +def _attn_bwd( + Q, + K, + V, + sm_scale, # + DO, # + DQ, + DK, + DV, # + M, + D, + # shared by Q/K/V/DO. + stride_z, + stride_h, + stride_tok, + stride_d, # + kv_stride_z, + kv_stride_h, # + H, # query head num + Q_CTX, # + KV_CTX, # + kv_head_num, # + GROUP_HEAD: tl.constexpr, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, +): + tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.") + + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * Q_CTX).to(tl.int64) + batch_id = bhid // H + q_head_id = bhid % H + kv_head_id = q_head_id // GROUP_HEAD + adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64) + kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64) + + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += kv_adj + V += kv_adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_n_mask = offs_n < KV_CTX + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + key = tl.load( + K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, + mask=offs_n_mask[:, None], + other=0.0, + ) + value = tl.load( + V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, + mask=offs_n_mask[:, None], + other=0.0, + ) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv( + dk, + dv, # + Q, + key, + value, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + Q_CTX, # + KV_CTX, # + MASK_BLOCK_M1, + BLOCK_N1, + BLOCK_DMODEL, # + start_n, + start_m, + num_steps, # + MASK=True, # + ) + + # Compute dK and dV for non-masked blocks. 
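+    # The masked pass above walked the diagonal region in MASK_BLOCK_M1-sized
+    # steps; the remaining query rows need no causal mask, so they are swept
+    # with full BLOCK_M1 tiles.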
+ start_m += num_steps * MASK_BLOCK_M1 + remaining_m = Q_CTX - start_m + num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1 + + if num_steps > 0 and start_m < Q_CTX: + dk, dv = _attn_bwd_dkdv( # + dk, + dv, # + Q, + key, + value, + sm_scale, # + DO, # + M, + D, # + stride_tok, + stride_d, # + H, + Q_CTX, # + KV_CTX, # + BLOCK_M1, + BLOCK_N1, + BLOCK_DMODEL, # + start_n, + start_m, + num_steps, # + MASK=False, # + ) + # tl.device_print("dv: ", dv) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None]) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None]) + + # THIS BLOCK DOES DQ: + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + start_m = pid * BLOCK_M2 + end_n = min(start_m + BLOCK_M2, KV_CTX) # Ensure end_n does not exceed N_CTX + num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2 + + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_m_mask = offs_m < Q_CTX + + query = tl.load( + Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, + mask=offs_m_mask[:, None], + other=0.0, + ) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load( + DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, + mask=offs_m_mask[:, None], + other=0.0, + ) + + m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf")) + m = m[:, None] + + # Stage 1 - Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + + if num_steps > 0: + dq = _attn_bwd_dq( + dq, + query, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + Q_CTX, # + KV_CTX, # + BLOCK_M2, + MASK_BLOCK_N2, + BLOCK_DMODEL, # + start_m, + start_n, + num_steps, # + MASK=True, # + ) + + # Stage 2 - non-masked blocks + stage2_end_n = start_n + stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2 + + if stage2_num_steps > 0: + dq = _attn_bwd_dq( + dq, + query, + K, + V, # + do, + m, + D, # + stride_tok, + stride_d, # + H, + Q_CTX, # + KV_CTX, # + BLOCK_M2, + BLOCK_N2, + BLOCK_DMODEL, # + start_m, + stage2_end_n - stage2_num_steps * BLOCK_N2, + stage2_num_steps, # + MASK=False, # + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + # tl.store(dq_ptrs, dq) + + tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None]) + + def scaled_dot_product_attention( query, key, @@ -334,7 +755,32 @@ def scaled_dot_product_attention( scale=None, enable_gqa=False, ): - logger.debug("GEMS SCALED DOT PRODUCT ATTENTION") + return ScaleDotProductAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + ) + + +def scaled_dot_product_attention_backward( + do, + query, + key, + value, + o, + M, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +): + logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD") # shape constraints HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] # when v is in float8_e5m2 it is transposed. 
@@ -343,89 +789,239 @@ def scaled_dot_product_attention( assert HEAD_DIM_K in {16, 32, 64, 128, 256} assert dropout_p == 0.0, "Currenty only support dropout_p=0.0" - o = torch.empty_like(query, dtype=value.dtype) - - stage = 3 if is_causal else 1 - if scale is None: sm_scale = 1.0 / (HEAD_DIM_K**0.5) else: sm_scale = scale - q_head_num = query.shape[1] - kv_head_num = key.shape[1] - assert enable_gqa or q_head_num == kv_head_num, ( - f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, " - "enable_gqa must be True to support different head numbers." + assert do.is_contiguous() + assert ( + query.is_contiguous() + and key.is_contiguous() + and value.is_contiguous() + and o.is_contiguous() ) + assert query.stride() == o.stride() == do.stride() + assert key.stride() == value.stride() + + BLOCK_DMODEL = HEAD_DIM_K + BATCH, Q_HEAD, Q_CTX = query.shape[:3] + _, KV_HEAD, KV_CTX = key.shape[:3] + group_head = Q_HEAD // KV_HEAD + + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + # RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + + RCP_LN2 = 1.0 / math.log(2) + + arg_k = key + arg_k = arg_k * (sm_scale * RCP_LN2) + # PRE_BLOCK = 128 + PRE_BLOCK = 256 + + # PRE_BLOCK = 32 + # assert N_CTX % PRE_BLOCK == 0 + # pre_grid = (N_CTX // PRE_BLOCK, BATCH * Q_HEAD) + pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD) + + delta = torch.empty_like(M) - grid = lambda args: ( - triton.cdiv(query.shape[2], args["BLOCK_M"]), - query.shape[0] * query.shape[1], - 1, + # NOTE that dk & dv always have the same number of heads as q + dq = torch.empty_like(query).contiguous() + dk = torch.empty((BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K)).to(key.device).contiguous() + dv = torch.empty((BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V)).to(value.device).contiguous() + + _attn_bwd_preprocess[pre_grid]( + o, + do, # + delta, # + BATCH, + Q_HEAD, + Q_CTX, # + BLOCK_M=PRE_BLOCK, + D_HEAD=BLOCK_DMODEL, # ) - if attn_mask is not None: - HAS_ATTN_MASK = True - if attn_mask.dtype == torch.bool: - attn_mask = attn_mask.to(query.dtype) * -1.0e6 - stride_attn_mask_batch = attn_mask.stride(0) - stride_attn_mask_head = attn_mask.stride(1) - stride_attn_mask_q_seqlen = attn_mask.stride(2) - stride_attn_mask_kv_seqlen = attn_mask.stride(3) - else: - HAS_ATTN_MASK = False - stride_attn_mask_batch = 1 - stride_attn_mask_head = 1 - stride_attn_mask_q_seqlen = 1 - stride_attn_mask_kv_seqlen = 1 - - M = torch.empty( - (query.shape[0], query.shape[1], query.shape[2]), - device=query.device, - dtype=torch.float32, + grid = (triton.cdiv(Q_CTX, BLOCK_N1), 1, BATCH * Q_HEAD) + logger.info(f"{triton.cdiv(Q_CTX, BLOCK_N1)=}") + logger.info(f"{M.shape=}") + + _attn_bwd[grid]( + query, + arg_k, + value, + sm_scale, + do, + dq, + dk, + dv, # + M, + delta, # + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), # + key.stride(0), + key.stride(1), # + Q_HEAD, + Q_CTX, # + KV_CTX, # + KV_HEAD, # + GROUP_HEAD=group_head, # + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + BLOCK_DMODEL=BLOCK_DMODEL, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES, # ) - with torch_device_fn.device(query.device): - _attn_fwd[grid]( + if group_head > 1: + dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K) + dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V) + dk = dk.sum(dim=2) + dv = dv.sum(dim=2) + + return dq, dk, dv + + +class 
ScaleDotProductAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ): + logger.debug("GEMS SCALED DOT PRODUCT ATTENTION") + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = value.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + assert dropout_p == 0.0, "Currenty only support dropout_p=0.0" + + o = torch.empty_like(query, dtype=value.dtype) + + stage = 3 if is_causal else 1 + + if scale is None: + sm_scale = 1.0 / (HEAD_DIM_K**0.5) + else: + sm_scale = scale + + q_head_num = query.shape[1] + kv_head_num = key.shape[1] + assert enable_gqa or q_head_num == kv_head_num, ( + f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, " + "enable_gqa must be True to support different head numbers." + ) + + grid = lambda args: ( + triton.cdiv(query.shape[2], args["BLOCK_M"]), + query.shape[0] * query.shape[1], + 1, + ) + + if attn_mask is not None: + HAS_ATTN_MASK = True + if attn_mask.dtype == torch.bool: + attn_mask = attn_mask.to(query.dtype) * -1.0e6 + stride_attn_mask_batch = attn_mask.stride(0) + stride_attn_mask_head = attn_mask.stride(1) + stride_attn_mask_q_seqlen = attn_mask.stride(2) + stride_attn_mask_kv_seqlen = attn_mask.stride(3) + else: + HAS_ATTN_MASK = False + stride_attn_mask_batch = 1 + stride_attn_mask_head = 1 + stride_attn_mask_q_seqlen = 1 + stride_attn_mask_kv_seqlen = 1 + + M = torch.empty( + (query.shape[0], query.shape[1], query.shape[2]), + device=query.device, + dtype=torch.float32, + ) + + with torch_device_fn.device(query.device): + _attn_fwd[grid]( + query, + key, + value, + attn_mask, + sm_scale, + M, + o, # + query.stride(0), + query.stride(1), + query.stride(2), + query.stride(3), # + key.stride(0), + key.stride(1), + key.stride(2), + key.stride(3), # + value.stride(0), + value.stride(1), + value.stride(2), + value.stride(3), # + stride_attn_mask_batch, + stride_attn_mask_head, + stride_attn_mask_q_seqlen, + stride_attn_mask_kv_seqlen, # + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), # + query.shape[0], + q_head_num, + kv_head_num, # + q_head_num // kv_head_num, # group_head + query.shape[2], # + key.shape[2], # + HEAD_DIM_K, # + STAGE=stage, # + HAS_ATTN_MASK=HAS_ATTN_MASK, # + ) + + ctx.save_for_backward(query, key, value, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = HEAD_DIM_K + ctx.causal = is_causal + ctx.enable_gqa = enable_gqa + return o + + @staticmethod + def backward(ctx, do): + query, key, value, o, M = ctx.saved_tensors + is_causal = ctx.causal + enable_gqa = ctx.enable_gqa + sm_scale = ctx.sm_scale + dq, dk, dv = scaled_dot_product_attention_backward( + do, query, key, value, - attn_mask, - sm_scale, + o, M, - o, # - query.stride(0), - query.stride(1), - query.stride(2), - query.stride(3), # - key.stride(0), - key.stride(1), - key.stride(2), - key.stride(3), # - value.stride(0), - value.stride(1), - value.stride(2), - value.stride(3), # - stride_attn_mask_batch, - stride_attn_mask_head, - stride_attn_mask_q_seqlen, - stride_attn_mask_kv_seqlen, # - o.stride(0), - o.stride(1), - o.stride(2), - o.stride(3), # - query.shape[0], - q_head_num, - kv_head_num, # - q_head_num // kv_head_num, # group_head - query.shape[2], # - key.shape[2], # - HEAD_DIM_K, # - STAGE=stage, # - HAS_ATTN_MASK=HAS_ATTN_MASK, # + attn_mask=None, + 
dropout_p=0.0, + is_causal=is_causal, + scale=sm_scale, + enable_gqa=enable_gqa, ) - return o + return dq, dk, dv, None, None, None, None, None def flash_attention_forward( diff --git a/tests/test_attention_ops.py b/tests/test_attention_ops.py index 7b9c4226b..d68a2a244 100644 --- a/tests/test_attention_ops.py +++ b/tests/test_attention_ops.py @@ -18,7 +18,15 @@ def make_input( - batch, num_head, num_head_k, q_seq_len, kv_seq_len, head_size, dtype, device + batch, + num_head, + num_head_k, + q_seq_len, + kv_seq_len, + head_size, + dtype, + device, + requires_grad=False, ): set_philox_state(1234567890, 0, device) q_shape = (batch, num_head, q_seq_len, head_size) @@ -26,6 +34,10 @@ def make_input( q = torch.empty(q_shape, dtype=dtype, device=device).uniform_(-0.05, 0.05) k = torch.empty(kv_shape, dtype=dtype, device=device).uniform_(-0.05, 0.05) v = torch.empty(kv_shape, dtype=dtype, device=device).uniform_(-0.05, 0.05) + if requires_grad: + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() return q, k, v @@ -314,12 +326,102 @@ def test_sdpa_legacy( ): device = torch_device_fn.current_device() q, k, v = make_input( - batch, num_q_head, num_kv_head, q_seq_len, kv_seq_len, head_size, dtype, device + batch, + num_q_head, + num_kv_head, + q_seq_len, + kv_seq_len, + head_size, + dtype, + device, + requires_grad=True, + ) + ref_q = to_reference(q, False) + ref_k = to_reference(k, False) + ref_v = to_reference(v, False) + scale = float(1.0 / np.sqrt(head_size)) + + # forward + torch_result = torch_sdpa( + ref_q, ref_k, ref_v, scale, is_causal, enable_gqa=enable_gqa + ) + + gems_result = flag_gems.ops.scaled_dot_product_attention( + q, k, v, attn_mask=None, scale=scale, is_causal=is_causal, enable_gqa=enable_gqa + ) + + gems_assert_close(gems_result, torch_result, dtype) + + +@pytest.mark.skipif(flag_gems.vendor_name == "metax", reason="TODOFIX") +@pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError") +@pytest.mark.skipif(flag_gems.vendor_name == "kunlunxin", reason="RESULT TODOFIX") +@pytest.mark.skipif( + torch.__version__ < "2.5", reason="Low Pytorch Version: enable_gqa not supported" +) +@pytest.mark.scaled_dot_product_attention_backward +@pytest.mark.parametrize( + "batch, num_q_head, num_kv_head, q_seq_len, kv_seq_len, head_size, enable_gqa", + [ + (4, 8, 8, 1024, 1024, 64, False), + (4, 8, 8, 1024, 1024, 128, False), + (4, 8, 8, 2048, 256, 64, False), + (4, 8, 8, 2048, 256, 128, False), + (4, 8, 8, 17, 1030, 64, False), + (4, 8, 8, 17, 1030, 128, False), + # adopted from FlagAttention `test_attention_fwd`: + (2, 4, 4, 512, 612, 128, False), + (2, 4, 4, 1024, 1034, 64, False), + (2, 4, 4, 2048, 2048, 32, False), + (2, 4, 4, 4096, 4096, 16, False), + (2, 4, 4, 4001, 4001, 32, False), + (2, 4, 4, 4001, 4096, 64, False), + (2, 4, 4, 4096, 4000, 128, False), + (1, 2, 2, 8192, 8202, 16, False), + (1, 2, 2, 8192, 8192, 32, False), + # test for mqa/gqa + (2, 4, 2, 512, 612, 128, True), + (2, 4, 1, 1024, 1034, 64, True), + (2, 4, 2, 2048, 2048, 32, True), + (2, 4, 1, 4096, 4096, 16, True), + (2, 4, 2, 4001, 4001, 32, True), + (2, 4, 1, 4001, 4096, 64, True), + (2, 4, 2, 4096, 4000, 128, True), + (1, 2, 1, 8192, 8202, 16, True), + (1, 2, 1, 8192, 8192, 32, True), + ], +) +@pytest.mark.parametrize("is_causal", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sdpa_legacy_backward( + batch, + num_q_head, + num_kv_head, + q_seq_len, + kv_seq_len, + head_size, + is_causal, + dtype, + enable_gqa, +): + device = 
torch_device_fn.current_device() + q, k, v = make_input( + batch, + num_q_head, + num_kv_head, + q_seq_len, + kv_seq_len, + head_size, + dtype, + device, + requires_grad=True, ) ref_q = to_reference(q, False) ref_k = to_reference(k, False) ref_v = to_reference(v, False) scale = float(1.0 / np.sqrt(head_size)) + + # forward torch_result = torch_sdpa( ref_q, ref_k, ref_v, scale, is_causal, enable_gqa=enable_gqa ) @@ -330,6 +432,22 @@ def test_sdpa_legacy( gems_assert_close(gems_result, torch_result, dtype) + # backward + dout = torch.randn_like(ref_q) + torch_result.backward(dout) + gems_result.backward(dout) + torch_q_grad = ref_q.grad.clone() if ref_q.grad is not None else None + torch_k_grad = ref_k.grad.clone() if ref_k.grad is not None else None + torch_v_grad = ref_v.grad.clone() if ref_v.grad is not None else None + gems_q_grad = q.grad.clone() if q.grad is not None else None + gems_k_grad = k.grad.clone() if k.grad is not None else None + gems_v_grad = v.grad.clone() if v.grad is not None else None + + # NOTE: NaN may arise in the gradients, this behavior aligns with PyTorch's SDPA + gems_assert_close(gems_q_grad, torch_q_grad, dtype, equal_nan=True) + gems_assert_close(gems_k_grad, torch_k_grad, dtype, equal_nan=True) + gems_assert_close(gems_v_grad, torch_v_grad, dtype, equal_nan=True) + @pytest.mark.skipif(flag_gems.vendor_name == "metax", reason="TODOFIX") @pytest.mark.skipif(flag_gems.vendor_name == "hygon", reason="RuntimeError")
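
Note for reviewers: below is a minimal sketch of how the new autograd path is exercised end to end, mirroring the call pattern used in test_sdpa_legacy_backward above. The shapes, dtype, and the "cuda" device are illustrative assumptions, not values taken from this patch; the function and class names are the ones added here.

import torch
import flag_gems

# Illustrative shapes: (batch, num_heads, seq_len, head_dim); head_dim must be
# one of {16, 32, 64, 128, 256}, and dropout_p must stay 0.0 for this path.
q = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Forward now routes through ScaleDotProductAttention.apply, so autograd
# records the op and saves (query, key, value, o, M) for the backward kernel.
out = flag_gems.ops.scaled_dot_product_attention(q, k, v, is_causal=True)

# Backward dispatches to scaled_dot_product_attention_backward and fills
# q.grad / k.grad / v.grad.
out.backward(torch.randn_like(out))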