Skip to content

Commit 7b3c057

Browse files
committed
Fix lint
Signed-off-by: Reese Wang <[email protected]>
1 parent 80aa2ae commit 7b3c057

File tree

1 file changed

+54
-40
lines changed

1 file changed

+54
-40
lines changed

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,10 @@ createQKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
110110
.setComputeType(CUDNN_DATA_FLOAT)
111111
.build();
112112

113-
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
114-
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
113+
auto seqlenQTensor = tensor_create(
114+
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
115+
auto seqlenKTensor = tensor_create(
116+
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
115117

116118
// Create a matmul 1 node
117119
auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
@@ -136,14 +138,13 @@ createPaddingMask(int64_t b,
136138
int64_t d,
137139
NVTE_QKV_Layout layout,
138140
cudnnDataType_t tensorType,
139-
std::vector<cudnn_frontend::Operation>& ops,
140-
cudnn_frontend::Tensor& prevBlockOutputTensor) {
141+
std::vector<cudnn_frontend::Operation>* ops,
142+
const cudnn_frontend::Tensor& prevBlockOutputTensor) {
141143
CUDNN_FRONTEND_UNUSED(d);
142144
CUDNN_FRONTEND_UNUSED(layout);
143145
CUDNN_FRONTEND_UNUSED(tensorType);
144146

145-
cudnn_frontend::throw_if(
146-
ops.size() == 0, "Padding Mask constructed incorrectly as the first one", CUDNN_STATUS_BAD_PARAM);
147+
NVTE_CHECK(ops->size() != 0, "Padding Mask constructed incorrectly as the first one");
147148

148149
// subtraction output
149150
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
@@ -156,30 +157,32 @@ createPaddingMask(int64_t b,
156157
int64_t seqlen_stride[4] = {1, 1, 1, 1};
157158

158159
// mask value to put in the masked pixels
159-
auto maskValTensor =
160-
tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true); // is by value
161-
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
162-
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
160+
auto maskValTensor = tensor_create(
161+
CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, false, true);
162+
auto seqlenQTensor = tensor_create(
163+
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
164+
auto seqlenKTensor = tensor_create(
165+
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
163166

164167
// gen index row output
165-
auto rowIndexTensor =
166-
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
168+
auto rowIndexTensor = tensor_create(
169+
CUDNN_DATA_FLOAT, VIRTUAL_ID + 300, afterBMM1_dim, afterBMM1_stride, true, false);
167170
// gen index column output
168-
auto columnIndexTensor =
169-
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
171+
auto columnIndexTensor = tensor_create(
172+
CUDNN_DATA_FLOAT, VIRTUAL_ID + 301, afterBMM1_dim, afterBMM1_stride, true, false);
170173
// less than row output
171174
auto lessThanRowTensor = tensor_create(
172-
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
175+
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 302, afterBMM1_dim, afterBMM1_stride, true, false);
173176
// less than column output
174177
auto lessThanColTensor = tensor_create(
175-
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
178+
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 303, afterBMM1_dim, afterBMM1_stride, true, false);
176179
// padding mask (lessthanRow && lessthanCol)
177180
auto paddingMaskTensor = tensor_create(
178-
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
181+
CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 304, afterBMM1_dim, afterBMM1_stride, true, false);
179182

180183
// output after masking
181-
auto maskOutputTensor =
182-
tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false); // is virtual
184+
auto maskOutputTensor = tensor_create(
185+
CUDNN_DATA_FLOAT, VIRTUAL_ID + 305, afterBMM1_dim, afterBMM1_stride, true, false);
183186

184187
// Define the gen index for row descriptor
185188
auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder()
@@ -189,7 +192,8 @@ createPaddingMask(int64_t b,
189192
.build();
190193

191194
// Create a gen index Node.
192-
auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
195+
auto genIndexRow_op = unary_pw_op_create(
196+
prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc);
193197

194198
// Define the gen index for row descriptor
195199
auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder()
@@ -199,42 +203,45 @@ createPaddingMask(int64_t b,
199203
.build();
200204

201205
// Create a gen index Node.
202-
auto genIndexColumn_op = unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
206+
auto genIndexColumn_op = unary_pw_op_create(
207+
prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc);
203208

204209
// Define the less than comparison for row descriptor
205210
auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
206211

207212
// Create a less than comparison for row Node.
208-
auto lessThanRow_op = binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
213+
auto lessThanRow_op = binary_pw_op_create(
214+
rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc);
209215

210216
// Define the less than comparison for column descriptor
211217
auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT);
212218

213219
// Create a less than comparison for col Node.
214-
auto lessThanCol_op = binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
220+
auto lessThanCol_op = binary_pw_op_create(
221+
columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc);
215222

216223
// Define the less than comparison for column descriptor
217224
auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND);
218225

219226
// Create a and node for combining lessThanRow and lessThanCol
220-
auto paddingMaskAnd_op =
221-
binary_pw_op_create(lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
227+
auto paddingMaskAnd_op = binary_pw_op_create(
228+
lessThanRowTensor, lessThanColTensor, paddingMaskTensor, paddingMaskAndDesc);
222229

223230
/////////////////// Apply the mask //////////////////////////
224231

225232
// Define the binary select to perform masking descriptor
226233
auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT);
227234

228235
// Create a binary select Node.
229-
auto mask_op =
230-
ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
236+
auto mask_op = ternary_pw_op_create(
237+
prevBlockOutputTensor, maskValTensor, paddingMaskTensor, maskOutputTensor, maskDesc);
231238

232-
ops.push_back(std::move(genIndexRow_op));
233-
ops.push_back(std::move(genIndexColumn_op));
234-
ops.push_back(std::move(lessThanRow_op));
235-
ops.push_back(std::move(lessThanCol_op));
236-
ops.push_back(std::move(paddingMaskAnd_op));
237-
ops.push_back(std::move(mask_op));
239+
ops->push_back(std::move(genIndexRow_op));
240+
ops->push_back(std::move(genIndexColumn_op));
241+
ops->push_back(std::move(lessThanRow_op));
242+
ops->push_back(std::move(lessThanCol_op));
243+
ops->push_back(std::move(paddingMaskAnd_op));
244+
ops->push_back(std::move(mask_op));
238245

239246
return maskOutputTensor;
240247
}
@@ -637,8 +644,10 @@ createSVBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
637644
int64_t seqlen_dim[4] = {b, 1, 1, 1};
638645
int64_t seqlen_stride[4] = {1, 1, 1, 1};
639646

640-
auto seqlenQTensor = tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
641-
auto seqlenKTensor = tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
647+
auto seqlenQTensor = tensor_create(
648+
CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
649+
auto seqlenKTensor = tensor_create(
650+
CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false);
642651

643652
auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false);
644653
// second GEMM output
@@ -716,8 +725,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
716725
auto sAfterMaskTensor = createCausalMask(
717726
b, h, s_q, s_kv, d, layout, tensorType, &ops, sScaleTensor);
718727

719-
auto sAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, sAfterMaskTensor);
720-
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(sAfterPaddingMaskTensor));
728+
auto sAfterPaddingMaskTensor = createPaddingMask(
729+
b, h, s_q, s_kv, d, layout, tensorType, &ops, sAfterMaskTensor);
730+
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(
731+
std::move(sAfterPaddingMaskTensor));
721732

722733
NVTE_CHECK(dropout_probability != 1.0f,
723734
"Dropout probability cannot be 1.0");
@@ -1013,8 +1024,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
10131024
auto pAfterMaskTensor = createCausalMask(
10141025
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterScaleTensor);
10151026

1016-
auto pAfterPaddingMaskTensor = createPaddingMask(b, h, s_q, s_kv, d, layout, tensorType, ops, pAfterMaskTensor);
1017-
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(std::move(pAfterPaddingMaskTensor));
1027+
auto pAfterPaddingMaskTensor = createPaddingMask(
1028+
b, h, s_q, s_kv, d, layout, tensorType, &ops, pAfterMaskTensor);
1029+
auto softmaxInput = std::make_shared<cudnn_frontend::Tensor>(
1030+
std::move(pAfterPaddingMaskTensor));
10181031

10191032
/*******************************************************************************
10201033
* pAfterMaskTensor - softmaxStats -> pAfterSubtract */
@@ -1348,7 +1361,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
13481361

13491362
constexpr size_t nthreads_per_block = 128;
13501363
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
1351-
void *devActualSeqlenQ = static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
1364+
void *devActualSeqlenQ =
1365+
static_cast<int8_t *>(devPtrdQAccumulator) + dqAccum_workspace_size;
13521366
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
13531367
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
13541368
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),

0 commit comments

Comments
 (0)