@@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }
 
+    cute::cp_async_wait<0>();
+    __syncthreads();
+
+    // Scale K once before the streaming loop over Q blocks
+    #pragma unroll
+    for (int k = 0; k < size(tKsK); ++k) {
+        tKsK(k) = static_cast<Element>(tKsK(k) * params.scale_softmax);
+    }
+
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
-        cute::cp_async_wait<0>();
-        __syncthreads();
-
         Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (MMA=4, MMA_M, MMA_K)
+        cute::cp_async_wait<0>();
+        __syncthreads();
 
         if (any_active) {
-            clear(acc_s);
+            if constexpr (Has_bias) {
+                // Copy bias from smem to registers
+                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
+                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
+                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
+                #pragma unroll
+                for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); }
+            } else {
+                clear(acc_s);
+            }
+        }
+
 
+        if (any_active) {
             Tensor dP_sum = make_fragment_like(lse);
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
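The bias-seeding above relies on the MMA accumulating on top of whatever is already in acc_s. Here is a minimal host-side sketch (plain C++, toy sizes; gemm_accumulate and all names are illustrative, not kernel code) of why initializing the accumulator with the bias is equivalent to adding it after the GEMM:

// Hypothetical toy demo: seeding the accumulator with B before a
// multiply-accumulate yields S = Q*K^T + B, the same result as clearing,
// multiplying, and adding B afterwards (up to FP rounding order).
#include <cassert>
#include <cmath>

constexpr int M = 2, N = 3, K = 4;

// Accumulates Q*K^T onto C without clearing it, mirroring how the MMA atom
// accumulates onto acc_s.
void gemm_accumulate(const float Q[M][K], const float Kt[N][K], float C[M][N]) {
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k)
                C[m][n] += Q[m][k] * Kt[n][k];
}

int main() {
    float Q[M][K], Kt[N][K], B[M][N];
    for (int m = 0; m < M; ++m) for (int k = 0; k < K; ++k) Q[m][k] = 0.1f * (m + k);
    for (int n = 0; n < N; ++n) for (int k = 0; k < K; ++k) Kt[n][k] = 0.2f * (n - k);
    for (int m = 0; m < M; ++m) for (int n = 0; n < N; ++n) B[m][n] = 1.0f + m - n;

    // Old path: clear(acc_s), GEMM, then add bias afterwards.
    float S1[M][N] = {};
    gemm_accumulate(Q, Kt, S1);
    for (int m = 0; m < M; ++m) for (int n = 0; n < N; ++n) S1[m][n] += B[m][n];

    // New path: acc_s(i) = tSrBias(i), then the GEMM accumulates on top.
    float S2[M][N];
    for (int m = 0; m < M; ++m) for (int n = 0; n < N; ++n) S2[m][n] = B[m][n];
    gemm_accumulate(Q, Kt, S2);

    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            assert(std::fabs(S1[m][n] - S2[m][n]) < 1e-5f);
    return 0;
}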
@@ -686,71 +706,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
             }
 
-            if constexpr (Has_mask && Has_bias) {
-                // Copy mask and bias from smem to registers
-                Tensor tSrMask = make_tensor<Element>(shape(acc_s));
-                Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
-                cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
-
-                // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
-                // actual_seqlen_k, because acc_s would be some finite value for those indices.
-                // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
-                // so the result would still be correct.
-                // However, it's possible that the values in acc_s are so large that they overflow
-                // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
-                // So we need to mask out the elements beyond actual_seqlen_k.
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, mask, bias, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (Has_mask && !Has_bias) {
+            if constexpr (Has_mask) {
                 // Copy mask from smem to registers
                 Tensor tSrMask = make_tensor<Element>(shape(acc_s));
                 Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask);
                 cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view);
 
                 // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
                 Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, mask, /*bias=*/nullptr, params.scale_softmax,
-                    n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                    binfo.actual_seqlen_k,
-                    m_block * kBlockM + get<0>(taccScS_row(0)),
-                    binfo.actual_seqlen_q,
-                    AtomLayoutMS * 16
-                );
-            } else if constexpr (!Has_mask && Has_bias) {
-                // Copy bias from smem to registers
-                Tensor tSrBias = make_tensor<Element>(shape(acc_s));
-                Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias);
-                cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view);
 
-                // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
-                Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
-
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, /*mask=*/nullptr, bias, params.scale_softmax,
+                FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
+                    scores, mask,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
                     binfo.actual_seqlen_q,
                     AtomLayoutMS * 16
                 );
             } else {
-                FLASH_NAMESPACE::apply_mask</*Causal_mask=*/Is_causal, /*Has_mask=*/Has_mask, /*Has_bias=*/Has_bias>(
-                    scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax,
+                FLASH_NAMESPACE::apply_mask<Is_causal, Has_mask>(
+                    scores, /*mask=*/nullptr,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
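With the bias folded into acc_s and the softmax scale folded into K, apply_mask no longer needs bias or scale arguments, so the four Has_mask x Has_bias branches collapse to two. For intuition, here is a scalar sketch (hypothetical, simplified signature; not the real FLASH_NAMESPACE::apply_mask, which operates on CuTe register fragments) of the per-element rules: columns past actual_seqlen_k are always forced to -inf so stale scores cannot overflow in the later fp16 conversion, the causal rule applies when Is_causal, and the user mask only when Has_mask:

// Hypothetical scalar model of the masking rules applied to one row of scores.
#include <limits>

template <bool Is_causal, bool Has_mask>
void apply_mask_sketch(float *scores, const bool *mask,
                       int row, int col0, int ncols,
                       int seqlen_q, int seqlen_k) {
    for (int j = 0; j < ncols; ++j) {
        const int col = col0 + j;
        // Always mask columns past the actual key length.
        bool keep = col < seqlen_k;
        if constexpr (Is_causal) {
            // Causal rule, bottom-right aligned as in flash-attention:
            // query `row` may only attend to keys with col <= row + (seqlen_k - seqlen_q).
            keep = keep && (col <= row + seqlen_k - seqlen_q);
        }
        if constexpr (Has_mask) keep = keep && mask[j];
        if (!keep) scores[j] = -std::numeric_limits<float>::infinity();
    }
}

int main() {
    float scores[4] = {1.f, 2.f, 3.f, 4.f};
    const bool mask[4] = {true, false, true, true};
    // Row 0 of a tile starting at column 0, with seqlen_q == seqlen_k == 3:
    // result is {1, -inf, -inf, -inf} (causal cut, then the seqlen_k bound).
    apply_mask_sketch</*Is_causal=*/true, /*Has_mask=*/true>(scores, mask, 0, 0, 4, 3, 3);
    return 0;
}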
@@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
             }
         } else {
-            #pragma unroll
-            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
+            // #pragma unroll
+            // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; }
            // Convert acc_dq from fp32 to fp16
             Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
             Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom, AtomNum), MMA_M, MMA_K)
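The commented-out epilogue scaling is the counterpart of the K pre-scaling added at the top of the diff. A numeric sketch (plain C++, toy sizes; matmul and all names are illustrative, not kernel code) of the equivalence: once K is scaled by scale_softmax in smem, dQ = dS * (scale*K) already equals scale * (dS * K), which is exactly what the removed per-element multiply on acc_dq used to produce:

// Hypothetical toy demo: folding the softmax scale into K makes the dQ
// epilogue scaling redundant.
#include <cassert>
#include <cmath>

constexpr int M = 2, N = 3, D = 4;
constexpr float scale = 0.125f;   // stand-in for params.scale_softmax

void matmul(const float A[M][N], const float B[N][D], float C[M][D]) {
    for (int m = 0; m < M; ++m)
        for (int d = 0; d < D; ++d) {
            C[m][d] = 0.f;
            for (int n = 0; n < N; ++n) C[m][d] += A[m][n] * B[n][d];
        }
}

int main() {
    float dS[M][N], K[N][D], Ks[N][D];
    for (int m = 0; m < M; ++m) for (int n = 0; n < N; ++n) dS[m][n] = 0.3f * (m + n);
    for (int n = 0; n < N; ++n) for (int d = 0; d < D; ++d) K[n][d] = 0.7f * (n - d);
    for (int n = 0; n < N; ++n) for (int d = 0; d < D; ++d) Ks[n][d] = scale * K[n][d];

    // Old path: dQ = dS * K, then scale elementwise in the epilogue.
    float dQ_old[M][D];
    matmul(dS, K, dQ_old);
    for (int m = 0; m < M; ++m) for (int d = 0; d < D; ++d) dQ_old[m][d] *= scale;

    // New path: K was scaled once up front; no epilogue multiply needed.
    float dQ_new[M][D];
    matmul(dS, Ks, dQ_new);

    for (int m = 0; m < M; ++m)
        for (int d = 0; d < D; ++d)
            assert(std::fabs(dQ_old[m][d] - dQ_new[m][d]) < 1e-5f);
    return 0;
}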