Skip to content

Commit 80aa2ae

Browse files
committed
Remove unnecessary code
Signed-off-by: Reese Wang <[email protected]>
1 parent 8c0e80b commit 80aa2ae

File tree

1 file changed: +0 additions, −40 deletions

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@
4242
#define K_TRANSPOSE_ID 22
4343
#define dQ_ACCUM_ID 23
4444

45-
#define FWD_VAR 1
46-
#define BWD_VAR 1
47-
4845
#define VIRTUAL_ID 30
4946

5047
namespace transformer_engine {
@@ -121,10 +118,8 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
121118
.setaMatDesc(qTensor)
122119
.setbMatDesc(kTransposeTensor)
123120
.setcMatDesc(sTensor)
124-
#ifdef FWD_VAR
125121
.setmOverrideDesc(seqlenQTensor)
126122
.setnOverrideDesc(seqlenKTensor)
127-
#endif
128123
.setmatmulDesc(matmul_1_Desc)
129124
.build();
130125

@@ -192,19 +187,16 @@ createPaddingMask(int64_t b,
192187
.setAxis(2)
193188
.setComputeType(CUDNN_DATA_FLOAT)
194189
.build();
195-
// std::cout << genIndexRowDesc.describe() << std::endl;
196190

197191
// Create a gen index Node.
198192
auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
199-
// std::cout << genIndexRow_op.describe() << std::endl;
200193

201194
// Define the gen index for row descriptor
202195
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
203196
.setMode(CUDNN_POINTWISE_GEN_INDEX)
204197
.setAxis(3)
205198
.setComputeType(CUDNN_DATA_FLOAT)
206199
.build();
207-
// std::cout << genIndexColumnDesc.describe() << std::endl;
208200

209201
// Create a gen index Node.
210202
auto genIndexColumn_op = unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
@@ -663,10 +655,8 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
663655
.setaMatDesc(afterScaleDropoutTensor)
664656
.setbMatDesc(vTensor)
665657
.setcMatDesc(oTensor)
666-
#ifdef FWD_VAR
667658
.setmOverrideDesc(seqlenQTensor)
668659
.setkOverrideDesc(seqlenKTensor)
669-
#endif
670660
.setmatmulDesc(matmul_2_Desc)
671661
.build();
672662

@@ -726,14 +716,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
726716
auto sAfterMaskTensor = createCausalMask(
727717
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);
728718

729-
// std::shared_ptr<cudnn_frontend::Tensor> softmaxInput;
730-
// softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(sAfterMaskTensor));
731-
732-
// constexpr bool variable_sequence_length = true;
733-
// if (variable_sequence_length) {
734719
auto sAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, sAfterMaskTensor);
735720
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(sAfterPaddingMaskTensor));
736-
// }
737721

738722
NVTE_CHECK(dropout_probability != 1.0f,
739723
"Dropout probability cannot be 1.0");
@@ -815,10 +799,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
815799
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
816800
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
817801
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));
818-
#ifdef FWD_VAR
819802
data_ptrs.insert(std::pair<uint64_t, void*>(Q_SEQLEN_ID, devActualSeqlenQ));
820803
data_ptrs.insert(std::pair<uint64_t, void*>(K_SEQLEN_ID, devActualSeqlenK));
821-
#endif
822804

823805
// If training mode, we write out softmax stats
824806
if (is_training) {
@@ -919,12 +901,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
919901
int64_t scale_dim[4] = {1, 1, 1, 1};
920902
int64_t scale_stride[4] = {1, 1, 1, 1};
921903

922-
#ifdef BWD_VAR
923904
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim,
924905
seqlen_stride, false, false);
925906
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim,
926907
seqlen_stride, false, false);
927-
#endif
928908

929909
/*******************************************************************************
930910
* Dot product dO * O */
@@ -1007,10 +987,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
1007987
.setaMatDesc(qTensor)
1008988
.setbMatDesc(kTransposeTensor)
1009989
.setcMatDesc(pTensor)
1010-
#ifdef BWD_VAR
1011990
.setmOverrideDesc(seqlenQTensor)
1012991
.setnOverrideDesc(seqlenKTensor)
1013-
#endif
1014992
.setmatmulDesc(matmul_0_Desc)
1015993
.build();
1016994

@@ -1032,17 +1010,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
10321010
/*******************************************************************************
10331011
* Causal masking -> pAfterMaskTensor */
10341012

1035-
// std::shared_ptr<cudnn_frontend::Tensor> softmaxInput;
10361013
auto pAfterMaskTensor = createCausalMask(
10371014
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
10381015

1039-
// softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(pAfterMaskTensor));
1040-
1041-
// constexpr bool variable_sequence_length = true;
1042-
// if (variable_sequence_length) {
10431016
auto pAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, pAfterMaskTensor);
10441017
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(pAfterPaddingMaskTensor));
1045-
// }
10461018

10471019
/*******************************************************************************
10481020
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
@@ -1128,10 +1100,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
11281100
.setaMatDesc(sTransposeTensor)
11291101
.setbMatDesc(dOTensor)
11301102
.setcMatDesc(dVTensor)
1131-
#ifdef BWD_VAR
11321103
.setmOverrideDesc(seqlenKTensor)
11331104
.setkOverrideDesc(seqlenQTensor)
1134-
#endif
11351105
.setmatmulDesc(matmul_1_Desc)
11361106
.build();
11371107

@@ -1157,10 +1127,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
11571127
.setaMatDesc(dOTensor)
11581128
.setbMatDesc(vTransposeTensor)
11591129
.setcMatDesc(dSTensor)
1160-
#ifdef BWD_VAR
11611130
.setmOverrideDesc(seqlenQTensor)
11621131
.setnOverrideDesc(seqlenKTensor)
1163-
#endif
11641132
.setmatmulDesc(matmul_2_Desc)
11651133
.build();
11661134

@@ -1268,10 +1236,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
12681236
.setaMatDesc(dPScaledTensor)
12691237
.setbMatDesc(kTensor)
12701238
.setcMatDesc(dqAccumTensor)
1271-
#ifdef BWD_VAR
12721239
.setmOverrideDesc(seqlenQTensor)
12731240
.setkOverrideDesc(seqlenKTensor)
1274-
#endif
12751241
.setmatmulDesc(matmul_3_Desc)
12761242
.build();
12771243

@@ -1282,10 +1248,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
12821248
.setaMatDesc(dPScaledTensor)
12831249
.setbMatDesc(kTensor)
12841250
.setcMatDesc(dQTensor)
1285-
#ifdef BWD_VAR
12861251
.setmOverrideDesc(seqlenQTensor)
12871252
.setkOverrideDesc(seqlenKTensor)
1288-
#endif
12891253
.setmatmulDesc(matmul_3_Desc)
12901254
.build();
12911255

@@ -1315,10 +1279,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
13151279
.setaMatDesc(dPTransposeTensor)
13161280
.setbMatDesc(qTensor)
13171281
.setcMatDesc(dKTensor)
1318-
#ifdef BWD_VAR
13191282
.setmOverrideDesc(seqlenKTensor)
13201283
.setkOverrideDesc(seqlenQTensor)
1321-
#endif
13221284
.setmatmulDesc(matmul_4_Desc)
13231285
.build();
13241286

@@ -1415,10 +1377,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
14151377
data_ptrs.insert(std::pair<uint64_t, void*>(D_SEED_ID, devPtrDropoutSeed));
14161378
data_ptrs.insert(std::pair<uint64_t, void*>(D_OFFSET_ID, devPtrDropoutOffset));
14171379
data_ptrs.insert(std::pair<uint64_t, void*>(MASK_VAL_ID, &negInfinity));
1418-
#ifdef BWD_VAR
14191380
data_ptrs.insert(std::pair<uint64_t, void *>(Q_SEQLEN_ID, devActualSeqlenQ));
14201381
data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
1421-
#endif
14221382

14231383
float scaleProb = 1.0f - dropout_probability;
14241384
data_ptrs.insert(std::pair<uint64_t, void*>(D_CONST_ID, &scale_dropout));

0 commit comments

Comments (0)