@@ -110,8 +110,10 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                              .setComputeType(CUDNN_DATA_FLOAT)
                              .build();

-  auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
-  auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto seqlenQTensor = tensor_create(
+      CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto seqlenKTensor = tensor_create(
+      CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);

  // Create a matmul 1 node
  auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
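
For readers outside this file: tensor_create is a local helper whose last two flags mark a tensor as virtual (graph-internal) and as passed by value, matching the "is virtual" / "is by value" comments removed later in this diff. Below is a minimal sketch of what such a helper plausibly wraps, assuming the cudnn_frontend v0.x TensorBuilder setters; the exact setter names and the 16-byte alignment are assumptions, not this file's actual implementation.

#include <cudnn_frontend.h>

// Sketch only: assumed shape of the tensor_create helper used at the call sites above.
static cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id,
                                            int64_t const *dim, int64_t const *stride,
                                            bool is_virtual, bool is_by_value) {
  constexpr int nbDims = 4;
  return cudnn_frontend::TensorBuilder()
      .setDim(nbDims, dim)
      .setStride(nbDims, stride)
      .setId(id)
      .setAlignment(16)        // assumption: 16B alignment for tensor-core engines
      .setDataType(type)
      .setVirtual(is_virtual)  // virtual tensors exist only inside the fused graph
      .setByValue(is_by_value) // by-value tensors carry a host scalar in the variant pack
      .build();
}
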
@@ -136,14 +138,13 @@ createPaddingMask(int64_t b,
                              int64_t d,
                              NVTE_QKV_Layout layout,
                              cudnnDataType_t tensorType,
-                             std::vector<cudnn_frontend::Operation>& ops,
-                             cudnn_frontend::Tensor& prevBlockOutputTensor) {
+                             std::vector<cudnn_frontend::Operation>* ops,
+                             const cudnn_frontend::Tensor& prevBlockOutputTensor) {
  CUDNN_FRONTEND_UNUSED(d);
  CUDNN_FRONTEND_UNUSED(layout);
  CUDNN_FRONTEND_UNUSED(tensorType);

-  cudnn_frontend::throw_if(
-      ops.size() == 0, "Padding Mask constructed incorrectly as the first one", CUDNN_STATUS_BAD_PARAM);
+  NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");

  // subtraction output
  int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
@@ -156,30 +157,32 @@ createPaddingMask(int64_t b,
  int64_t seqlen_stride[4] = {1, 1, 1, 1};

  // mask value to put in the masked pixels
-  auto maskValTensor =
-      tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);  // is by value
-  auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
-  auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto maskValTensor = tensor_create(
+      CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);
+  auto seqlenQTensor = tensor_create(
+      CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto seqlenKTensor = tensor_create(
+      CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);

  // gen index row output
-  auto rowIndexTensor =
-      tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+  auto rowIndexTensor = tensor_create(
+      CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);
  // gen index column output
-  auto columnIndexTensor =
-      tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+  auto columnIndexTensor = tensor_create(
+      CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);
  // less than row output
  auto lessThanRowTensor = tensor_create(
-      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);
  // less than column output
  auto lessThanColTensor = tensor_create(
-      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);
  // padding mask (lessthanRow && lessthanCol)
  auto paddingMaskTensor = tensor_create(
-      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+      CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);

  // output after masking
-  auto maskOutputTensor =
-      tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);  // is virtual
+  auto maskOutputTensor = tensor_create(
+      CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);

  // Define the gen index for row descriptor
  auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
@@ -189,7 +192,8 @@ createPaddingMask(int64_t b,
                             .build();

  // Create a gen index Node.
-  auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
+  auto genIndexRow_op = unary_pw_op_create(
+      prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);

  // Define the gen index for row descriptor
  auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
@@ -199,42 +203,45 @@ createPaddingMask(int64_t b,
                                .build();

  // Create a gen index Node.
-  auto genIndexColumn_op = unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
+  auto genIndexColumn_op = unary_pw_op_create(
+      prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);

  // Define the less than comparison for row descriptor
  auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);

  // Create a less than comparison for row Node.
-  auto lessThanRow_op = binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
+  auto lessThanRow_op = binary_pw_op_create(
+      rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);

  // Define the less than comparison for column descriptor
  auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);

  // Create a less than comparison for col Node.
-  auto lessThanCol_op = binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
+  auto lessThanCol_op = binary_pw_op_create(
+      columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);

  // Define the less than comparison for column descriptor
  auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);

  // Create a and node for combining lessThanRow and lessThanCol
-  auto paddingMaskAnd_op =
-      binary_pw_op_create(lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
+  auto paddingMaskAnd_op = binary_pw_op_create(
+      lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);

  /////////////////// Apply the mask //////////////////////////

  // Define the binary select to perform masking descriptor
  auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);

  // Create a binary select Node.
-  auto mask_op =
-      ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
+  auto mask_op = ternary_pw_op_create(
+      prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);

-  ops.push_back(std::move(genIndexRow_op));
-  ops.push_back(std::move(genIndexColumn_op));
-  ops.push_back(std::move(lessThanRow_op));
-  ops.push_back(std::move(lessThanCol_op));
-  ops.push_back(std::move(paddingMaskAnd_op));
-  ops.push_back(std::move(mask_op));
+  ops->push_back(std::move(genIndexRow_op));
+  ops->push_back(std::move(genIndexColumn_op));
+  ops->push_back(std::move(lessThanRow_op));
+  ops->push_back(std::move(lessThanCol_op));
+  ops->push_back(std::move(paddingMaskAnd_op));
+  ops->push_back(std::move(mask_op));

  return maskOutputTensor;
}
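
For intuition, here is a scalar sketch of what the pointwise graph assembled in createPaddingMask computes per score element. The function and variable names below are illustrative only; they mirror the genIndexRow/Col -> CMP_LT -> LOGICAL_AND -> BINARY_SELECT chain above, not the library API.

#include <cstdint>
#include <vector>

// Illustrative only: scores is a row-major [s_q, s_kv] slice for one (batch, head);
// positions at row >= seqlen_q or col >= seqlen_kv are padding and receive mask_val.
std::vector<float> apply_padding_mask(const std::vector<float> &scores,
                                      int64_t s_q, int64_t s_kv,
                                      int32_t seqlen_q, int32_t seqlen_kv,
                                      float mask_val) {
  std::vector<float> out(scores.size());
  for (int64_t row = 0; row < s_q; ++row) {
    for (int64_t col = 0; col < s_kv; ++col) {
      const bool less_than_row = row < seqlen_q;          // rowIndexTensor < seqlenQTensor (CMP_LT)
      const bool less_than_col = col < seqlen_kv;         // columnIndexTensor < seqlenKTensor (CMP_LT)
      const bool keep = less_than_row && less_than_col;   // LOGICAL_AND -> paddingMaskTensor
      // BINARY_SELECT: pass the score through where keep is true, otherwise substitute mask_val.
      out[row * s_kv + col] = keep ? scores[row * s_kv + col] : mask_val;
    }
  }
  return out;
}
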
@@ -637,8 +644,10 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
  int64_t seqlen_dim[4] = {b, 1, 1, 1};
  int64_t seqlen_stride[4] = {1, 1, 1, 1};

-  auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
-  auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto seqlenQTensor = tensor_create(
+      CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
+  auto seqlenKTensor = tensor_create(
+      CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);

  auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
  // second GEMM output
@@ -716,8 +725,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
  auto sAfterMaskTensor = createCausalMask(
      b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);

-  auto sAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, sAfterMaskTensor);
-  auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(sAfterPaddingMaskTensor));
+  auto sAfterPaddingMaskTensor = createPaddingMask(
+      b, h, s_q, s_kv, d, layout, tensorType, &ops, sAfterMaskTensor);
+  auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(
+      std::move(sAfterPaddingMaskTensor));

  NVTE_CHECK(dropout_probability != 1.0f,
             "Dropout probability cannot be 1.0");
@@ -1013,8 +1024,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
  auto pAfterMaskTensor = createCausalMask(
      b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);

-  auto pAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, pAfterMaskTensor);
-  auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(pAfterPaddingMaskTensor));
+  auto pAfterPaddingMaskTensor = createPaddingMask(
+      b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterMaskTensor);
+  auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(
+      std::move(pAfterPaddingMaskTensor));

  /*******************************************************************************
   * pAfterMaskTensor - softmaxStats -> pAfterSubtract */
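
Both the forward and backward graphs now chain the causal mask and the new padding mask before the softmax step. As a scalar sketch of the combined keep-condition, assuming createCausalMask applies the usual lower-triangular rule (that helper is not shown in this diff); the function name below is illustrative.

#include <cstdint>

// Illustrative predicate: a score at (row, col) survives both masks when it lies on or
// below the diagonal (causal) and inside the actual Q/KV lengths for this batch entry.
inline bool score_is_kept(int64_t row, int64_t col, int32_t seqlen_q, int32_t seqlen_kv) {
  const bool causal_ok  = col <= row;                         // assumed lower-triangular causal mask
  const bool padding_ok = row < seqlen_q && col < seqlen_kv;  // padding mask built above
  return causal_ok && padding_ok;
}
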
@@ -1348,7 +1361,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

  constexpr size_t nthreads_per_block = 128;
  const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
-  void *devActualSeqlenQ = static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
+  void *devActualSeqlenQ =
+      static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
  void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
  cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
      b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
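
The two actual-seqlen buffers are carved out of the workspace directly after the dQ-accumulation region (Q first, then K, each b int32 values), and the kernel launched above converts cumulative offsets into per-sequence lengths. A minimal sketch of that conversion, assuming the cu_seqlens buffers hold b + 1 prefix sums; the kernel body and parameter names here are illustrative, and only the launch shape (one thread per batch entry) follows the call site.

#include <cstdint>

// Illustrative CUDA kernel: the actual length of sequence i is the difference of adjacent
// cumulative offsets. One thread handles one batch entry, matching the grid computed above.
__global__ void cu_seqlens_to_actual_seqlens_sketch(int64_t b,
                                                    const int32_t *cu_seqlens_q,
                                                    const int32_t *cu_seqlens_kv,
                                                    int32_t *actual_seqlens_q,
                                                    int32_t *actual_seqlens_kv) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < b) {
    actual_seqlens_q[i]  = cu_seqlens_q[i + 1]  - cu_seqlens_q[i];
    actual_seqlens_kv[i] = cu_seqlens_kv[i + 1] - cu_seqlens_kv[i];
  }
}
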