42
42
#define K_TRANSPOSE_ID 22
43
43
#define dQ_ACCUM_ID 23
44
44
45
- #define FWD_VAR 1
46
- #define BWD_VAR 1
47
-
48
45
#define VIRTUAL_ID 30
49
46
50
47
namespace transformer_engine {
@@ -121,10 +118,8 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
121
118
.setaMatDesc (qTensor)
122
119
.setbMatDesc (kTransposeTensor )
123
120
.setcMatDesc (sTensor )
124
- #ifdef FWD_VAR
125
121
.setmOverrideDesc (seqlenQTensor)
126
122
.setnOverrideDesc (seqlenKTensor)
127
- #endif
128
123
.setmatmulDesc (matmul_1_Desc)
129
124
.build ();
130
125
@@ -192,19 +187,16 @@ createPaddingMask(int64_t b,
192
187
.setAxis (2 )
193
188
.setComputeType (CUDNN_DATA_FLOAT)
194
189
.build ();
195
- // std::cout << genIndexRowDesc.describe() << std::endl;
196
190
197
191
// Create a gen index Node.
198
192
auto genIndexRow_op = unary_pw_op_create (prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
199
- // std::cout << genIndexRow_op.describe() << std::endl;
200
193
201
194
// Define the gen index for column descriptor
202
195
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder ()
203
196
.setMode (CUDNN_POINTWISE_GEN_INDEX)
204
197
.setAxis (3 )
205
198
.setComputeType (CUDNN_DATA_FLOAT)
206
199
.build ();
207
- // std::cout << genIndexColumnDesc.describe() << std::endl;
208
200
209
201
// Create a gen index Node.
210
202
auto genIndexColumn_op = unary_pw_op_create (prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
@@ -663,10 +655,8 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
663
655
.setaMatDesc (afterScaleDropoutTensor)
664
656
.setbMatDesc (vTensor)
665
657
.setcMatDesc (oTensor)
666
- #ifdef FWD_VAR
667
658
.setmOverrideDesc (seqlenQTensor)
668
659
.setkOverrideDesc (seqlenKTensor)
669
- #endif
670
660
.setmatmulDesc (matmul_2_Desc)
671
661
.build ();
672
662
@@ -726,14 +716,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
726
716
auto sAfterMaskTensor = createCausalMask (
727
717
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor );
728
718
729
- // std::shared_ptr<cudnn_frontend::Tensor> softmaxInput;
730
- // softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(sAfterMaskTensor));
731
-
732
- // constexpr bool variable_sequence_length = true;
733
- // if (variable_sequence_length) {
734
719
auto sAfterPaddingMaskTensor = createPaddingMask (b, h, s_q, s_kv, d, layout, tensorType, ops, sAfterMaskTensor );
735
720
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move (sAfterPaddingMaskTensor ));
736
- // }
737
721
738
722
NVTE_CHECK (dropout_probability != 1 .0f ,
739
723
" Dropout probability cannot be 1.0" );
@@ -815,10 +799,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
815
799
data_ptrs.insert (std::pair<uint64_t , void *>(D_SEED_ID, devPtrDropoutSeed));
816
800
data_ptrs.insert (std::pair<uint64_t , void *>(D_OFFSET_ID, devPtrDropoutOffset));
817
801
data_ptrs.insert (std::pair<uint64_t , void *>(D_CONST_ID, &scale_dropout));
818
- #ifdef FWD_VAR
819
802
data_ptrs.insert (std::pair<uint64_t , void *>(Q_SEQLEN_ID, devActualSeqlenQ));
820
803
data_ptrs.insert (std::pair<uint64_t , void *>(K_SEQLEN_ID, devActualSeqlenK));
821
- #endif
822
804
823
805
// If training mode, we write out softmax stats
824
806
if (is_training) {
@@ -919,12 +901,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
919
901
int64_t scale_dim[4 ] = {1 , 1 , 1 , 1 };
920
902
int64_t scale_stride[4 ] = {1 , 1 , 1 , 1 };
921
903
922
- #ifdef BWD_VAR
923
904
auto seqlenQTensor = tensor_create (CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim,
924
905
seqlen_stride, false , false );
925
906
auto seqlenKTensor = tensor_create (CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim,
926
907
seqlen_stride, false , false );
927
- #endif
928
908
929
909
/* ******************************************************************************
930
910
* Dot product dO * O */
@@ -1007,10 +987,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1007
987
.setaMatDesc (qTensor)
1008
988
.setbMatDesc (kTransposeTensor )
1009
989
.setcMatDesc (pTensor)
1010
- #ifdef BWD_VAR
1011
990
.setmOverrideDesc (seqlenQTensor)
1012
991
.setnOverrideDesc (seqlenKTensor)
1013
- #endif
1014
992
.setmatmulDesc (matmul_0_Desc)
1015
993
.build ();
1016
994
@@ -1032,17 +1010,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1032
1010
/* ******************************************************************************
1033
1011
* Causal masking -> pAfterMaskTensor */
1034
1012
1035
- // std::shared_ptr<cudnn_frontend::Tensor> softmaxInput;
1036
1013
auto pAfterMaskTensor = createCausalMask (
1037
1014
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
1038
1015
1039
- // softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(pAfterMaskTensor));
1040
-
1041
- // constexpr bool variable_sequence_length = true;
1042
- // if (variable_sequence_length) {
1043
1016
auto pAfterPaddingMaskTensor = createPaddingMask (b, h, s_q, s_kv, d, layout, tensorType, ops, pAfterMaskTensor);
1044
1017
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move (pAfterPaddingMaskTensor));
1045
- // }
1046
1018
1047
1019
/* ******************************************************************************
1048
1020
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
@@ -1128,10 +1100,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1128
1100
.setaMatDesc (sTransposeTensor )
1129
1101
.setbMatDesc (dOTensor)
1130
1102
.setcMatDesc (dVTensor)
1131
- #ifdef BWD_VAR
1132
1103
.setmOverrideDesc (seqlenKTensor)
1133
1104
.setkOverrideDesc (seqlenQTensor)
1134
- #endif
1135
1105
.setmatmulDesc (matmul_1_Desc)
1136
1106
.build ();
1137
1107
@@ -1157,10 +1127,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1157
1127
.setaMatDesc (dOTensor)
1158
1128
.setbMatDesc (vTransposeTensor)
1159
1129
.setcMatDesc (dSTensor)
1160
- #ifdef BWD_VAR
1161
1130
.setmOverrideDesc (seqlenQTensor)
1162
1131
.setnOverrideDesc (seqlenKTensor)
1163
- #endif
1164
1132
.setmatmulDesc (matmul_2_Desc)
1165
1133
.build ();
1166
1134
@@ -1268,10 +1236,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1268
1236
.setaMatDesc (dPScaledTensor)
1269
1237
.setbMatDesc (kTensor )
1270
1238
.setcMatDesc (dqAccumTensor)
1271
- #ifdef BWD_VAR
1272
1239
.setmOverrideDesc (seqlenQTensor)
1273
1240
.setkOverrideDesc (seqlenKTensor)
1274
- #endif
1275
1241
.setmatmulDesc (matmul_3_Desc)
1276
1242
.build ();
1277
1243
@@ -1282,10 +1248,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1282
1248
.setaMatDesc (dPScaledTensor)
1283
1249
.setbMatDesc (kTensor )
1284
1250
.setcMatDesc (dQTensor)
1285
- #ifdef BWD_VAR
1286
1251
.setmOverrideDesc (seqlenQTensor)
1287
1252
.setkOverrideDesc (seqlenKTensor)
1288
- #endif
1289
1253
.setmatmulDesc (matmul_3_Desc)
1290
1254
.build ();
1291
1255
@@ -1315,10 +1279,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1315
1279
.setaMatDesc (dPTransposeTensor)
1316
1280
.setbMatDesc (qTensor)
1317
1281
.setcMatDesc (dKTensor)
1318
- #ifdef BWD_VAR
1319
1282
.setmOverrideDesc (seqlenKTensor)
1320
1283
.setkOverrideDesc (seqlenQTensor)
1321
- #endif
1322
1284
.setmatmulDesc (matmul_4_Desc)
1323
1285
.build ();
1324
1286
@@ -1415,10 +1377,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1415
1377
data_ptrs.insert (std::pair<uint64_t , void *>(D_SEED_ID, devPtrDropoutSeed));
1416
1378
data_ptrs.insert (std::pair<uint64_t , void *>(D_OFFSET_ID, devPtrDropoutOffset));
1417
1379
data_ptrs.insert (std::pair<uint64_t , void *>(MASK_VAL_ID, &negInfinity));
1418
- #ifdef BWD_VAR
1419
1380
data_ptrs.insert (std::pair<uint64_t , void *>(Q_SEQLEN_ID, devActualSeqlenQ));
1420
1381
data_ptrs.insert (std::pair<uint64_t , void *>(K_SEQLEN_ID, devActualSeqlenK));
1421
- #endif
1422
1382
1423
1383
float scaleProb = 1 .0f - dropout_probability;
1424
1384
data_ptrs.insert (std::pair<uint64_t , void *>(D_CONST_ID, &scale_dropout));
0 commit comments