Skip to content

Commit c49dead

Browse files
authored
[FEATURE SUPPORT] Move scaling out of streaming loops, bias-initialized acc_s, and fix dQ double-scaling
2 parents d883adc + 2c35c89 commit c49dead

File tree

4 files changed

+152
-317
lines changed

4 files changed

+152
-317
lines changed

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 31 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
644644
clear(acc_dv);
645645
if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }
646646

647+
cute::cp_async_wait<0>();
648+
__syncthreads();
649+
650+
// Scale K once before streaming loop Q
651+
#pragma unroll
652+
for (int k = 0; k < size(tKsK); ++k) {
653+
tKsK(k) = static_cast<Element>(tKsK(k) * params.scale_softmax);
654+
}
655+
647656
for (; m_block >= m_block_min; --m_block) {
648657
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
649-
cute::cp_async_wait<0>();
650-
__syncthreads();
651-
652658
Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
653659
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA=4, MMA_M, MMA_K)
660+
cute::cp_async_wait<0>();
661+
__syncthreads();
654662

655663
if (any_active) {
656-
clear(acc_s);
664+
if constexpr (Has_bias) {
665+
// Copy bias from smem to registers
666+
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
667+
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
668+
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
669+
#pragma unroll
670+
for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
671+
} else {
672+
clear(acc_s);
673+
}
674+
}
675+
657676

677+
if (any_active) {
658678
Tensor dP_sum = make_fragment_like(lse);
659679
#pragma unroll
660680
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
@@ -686,71 +706,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
686706
FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
687707
}
688708

689-
if constexpr (Has_mask && Has_bias) {
690-
// Copy mask and bias from smem to registers
691-
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
692-
Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
693-
cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
694-
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
695-
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
696-
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
697-
698-
// Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
699-
Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
700-
Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
701-
702-
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
703-
// actual_seqlen_k, because acc_s would be some finite value for those indices.
704-
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
705-
// so the result would still be correct.
706-
// However, it's possible that the values in acc_s are so large that they overflow
707-
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
708-
// So we need to mask out the elements beyond actual_seqlen_k.
709-
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
710-
scores, mask, bias, params.scale_softmax,
711-
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
712-
binfo.actual_seqlen_k,
713-
m_block * kBlockM + get<0>(taccScS_row(0)),
714-
binfo.actual_seqlen_q,
715-
AtomLayoutMS * 16
716-
);
717-
} else if constexpr (Has_mask && !Has_bias) {
709+
if constexpr (Has_mask) {
718710
// Copy mask from smem to registers
719711
Tensor tSrMask = make_tensor<Element>(shape(acc_s));
720712
Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
721713
cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
722714

723715
// Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
724716
Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
725-
726-
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
727-
scores, mask, /*bias=*/nullptr, params.scale_softmax,
728-
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
729-
binfo.actual_seqlen_k,
730-
m_block * kBlockM + get<0>(taccScS_row(0)),
731-
binfo.actual_seqlen_q,
732-
AtomLayoutMS * 16
733-
);
734-
} else if constexpr (!Has_mask && Has_bias) {
735-
// Copy bias from smem to registers
736-
Tensor tSrBias = make_tensor<Element>(shape(acc_s));
737-
Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
738-
cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
739717

740-
// Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
741-
Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
742-
743-
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
744-
scores, /*mask=*/nullptr, bias, params.scale_softmax,
718+
FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
719+
scores, mask,
745720
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
746721
binfo.actual_seqlen_k,
747722
m_block * kBlockM + get<0>(taccScS_row(0)),
748723
binfo.actual_seqlen_q,
749724
AtomLayoutMS * 16
750725
);
751726
} else {
752-
FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
753-
scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
727+
FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
728+
scores, /*mask=*/nullptr,
754729
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
755730
binfo.actual_seqlen_k,
756731
m_block * kBlockM + get<0>(taccScS_row(0)),
@@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
965940
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
966941
}
967942
} else {
968-
#pragma unroll
969-
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
943+
// #pragma unroll
944+
// for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
970945
// Convert acc_dq from fp32 to fp16
971946
Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
972947
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K)

csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ inline __device__ void convert_dQ(
279279
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
280280
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
281281
}
282-
#pragma unroll
283-
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
282+
// #pragma unroll
283+
// for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
284284
// Convert acc_dq from fp32 to fp16
285285
Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
286286
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_N, MMA_N)

0 commit comments

Comments
 (0)