Fused attention fixes for cuDNN 8.9.3 (#311)
* Fix bprop for cuDNN 8.9.3

Signed-off-by: Charlene Yang <[email protected]>

* Update cuDNN version requirement to 8.9.3

Signed-off-by: Charlene Yang <[email protected]>

* Debug Paddle CI

Signed-off-by: Charlene Yang <[email protected]>

* Debug Paddle CI; force LD_LIBRARY_PATH

Signed-off-by: Charlene Yang <[email protected]>

* Debug Paddle CI; force LD_LIBRARY_PATH to /opt

Signed-off-by: Charlene Yang <[email protected]>

* Remove debug info for Paddle

Signed-off-by: Charlene Yang <[email protected]>

* Change cuDNN requirement to 8.9.1 for v1 and 8.9.0 for v2; add batch size 32 to the unit tests; temporarily add LD_LIBRARY_PATH for Paddle tests

Signed-off-by: Charlene Yang <[email protected]>

* Remove printf line in fused_attn.cpp

Signed-off-by: Charlene Yang <[email protected]>

* Add batch size 32 to the unit tests

Signed-off-by: Charlene Yang <[email protected]>

* Update cudnn-frontend to 0.9.2

Signed-off-by: Charlene Yang <[email protected]>

* Remove the temporary LD_LIBRARY_PATH used for testing pre-release cuDNN 8.9.3

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa authored Jul 14, 2023
1 parent 58d2eba · commit 0707552
Showing 5 changed files with 57 additions and 34 deletions.
README.rst: 1 addition, 1 deletion

@@ -184,7 +184,7 @@ Pre-requisites
* CUDA 11.8 or later
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
- * For FP8 fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
+ * For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.

From source
^^^^^^^^^^^
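A quick way to confirm what the installed stack provides before relying on fused attention (a convenience sketch assuming a CUDA-enabled PyTorch build; not part of the README):

```python
import torch

# None on CPU-only builds; otherwise e.g. '12.1'
print(torch.version.cuda)
# cuDNN version as an integer, e.g. 8902 for cuDNN 8.9.2
print(torch.backends.cudnn.version())
```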
tests/pytorch/test_fused_attn.py: 8 additions, 5 deletions

@@ -30,17 +30,20 @@ def __init__(

model_configs = {
"test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
"test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
"test2": ModelConfig(1, 1024, 16, 64, 512, 0.0, "causal"),
"test3": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
"test4": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
"test5": ModelConfig(1, 2048, 16, 128, 512, 0.0, "causal"),
"test6": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
"test7": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
"test8": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
}

param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
param_types.append(torch.bfloat16)

- batch_sizes = [1, 2]
+ batch_sizes = [1, 2, 32]

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
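For orientation, the `ModelConfig` arguments above appear to read as (layers, hidden size, heads, head_dim, max sequence length, dropout, mask type); treating that as an assumption, the new 512-token configs at the new batch size 32 correspond to tensors like the following unfused PyTorch sketch (illustration only, not the test's own code):

```python
import torch

bs, num_heads, head_dim, seqlen = 32, 16, 64, 512   # "test2"-style config at bs=32
q = torch.randn(bs, num_heads, seqlen, head_dim, dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# plain (unfused) attention over the same shapes, computed in FP32
scores = q.float() @ k.float().transpose(-1, -2) / head_dim ** 0.5
out = (torch.softmax(scores, dim=-1) @ v.float()).to(torch.float16)
print(out.shape)  # torch.Size([32, 16, 512, 64])
```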
@@ -373,7 +373,7 @@ createDropoutForward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
- tensorType, D_CONST_ID, scale_dim,
+ CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
@@ -454,7 +454,7 @@ createDropoutBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(
- tensorType, D_CONST_ID, scale_dim,
+ CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(
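The retyped constant is the dropout scale consumed by the graph as a by-value FP32 tensor; for inverted dropout it is 1/(1 − p). A plain NumPy sketch of the arithmetic it encodes (illustration only, not the cuDNN graph):

```python
import numpy as np

p = 0.1                                  # dropout probability
scale = np.float32(1.0 / (1.0 - p))      # the FP32 by-value scale
rng = np.random.default_rng(0)
x = rng.standard_normal(1 << 20).astype(np.float32)
keep = rng.random(x.shape) > p
y = np.where(keep, x * scale, np.float32(0.0))
print(x.mean(), y.mean())                # expectations agree up to sampling noise
```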
@@ -738,6 +738,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
b, h, s_q, s_kv, d, o_stride,
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);

+ int64_t dqAccum_dim[4] = {b, h, s_q, d};
+ int64_t dqAccum_stride[4];
+ generateMatrixStrides(b, h, s_q, s_kv, d, dqAccum_stride,
+ layout, NVTE_QKV_Matrix::NVTE_O_Matrix);

int64_t scale_dim[4] = {1, 1, 1, 1};
int64_t scale_stride[4] = {1, 1, 1, 1};
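The added `dqAccum_dim`/`dqAccum_stride` give dQ its own accumulation buffer shaped like Q, [b, h, s_q, d], in FP32. Conceptually, partial dQ contributions from each key/value block are summed in float32 and cast back only at the end; a rough PyTorch sketch of that idea (assumed shapes, not the cuDNN graph):

```python
import torch

b, h, s_q, s_kv, d, blk = 2, 4, 128, 256, 64, 64
ds = torch.randn(b, h, s_q, s_kv, dtype=torch.float16)     # softmax-input gradient
k = torch.randn(b, h, s_kv, d, dtype=torch.float16)

dq_accum = torch.zeros(b, h, s_q, d, dtype=torch.float32)  # FP32 accumulator
for start in range(0, s_kv, blk):                           # loop over KV blocks
    sl = slice(start, start + blk)
    dq_accum += ds[..., sl].float() @ k[:, :, sl].float()
dq = dq_accum.to(torch.float16)
print(dq.shape)  # torch.Size([2, 4, 128, 64])
```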

@@ -770,19 +775,19 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto afterReductionTensor = tensor_create(
CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, reduction_dim,
reduction_stride, true, false); // is virtual
- auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder()
+ auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder()
.setComputeType(CUDNN_DATA_FLOAT)
- .setReductionOp(CUDNN_REDUCE_TENSOR_MAX)
+ .setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();

- // Create a reduction max node
- auto reductionMax_op = cudnn_frontend::OperationBuilder(
+ // Create a reduction add node
+ auto reductionAdd_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(dotProductTensor)
.setyDesc(afterReductionTensor)
- .setreductionDesc(reductionMaxDesc)
+ .setreductionDesc(reductionAddDesc)
.build();
- ops.push_back(std::move(reductionMax_op));
+ ops.push_back(std::move(reductionAdd_op));


/*******************************************************************************
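The reduction corrected here feeds the softmax backward, which needs a row-wise SUM (Σ_j dP_ij · P_ij, equivalently Σ_k dO_ik · O_ik), not a max. A small PyTorch check of the identity dS = P ∘ (dP − rowsum(dP ∘ P)) (illustration only):

```python
import torch

torch.manual_seed(0)
s = torch.randn(2, 5, dtype=torch.double, requires_grad=True)
p = s.softmax(dim=-1)
dp = torch.randn_like(p)
p.backward(dp)

ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))   # ADD reduction over the row
print(torch.allclose(ds, s.grad))                    # True
```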
@@ -895,16 +900,25 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
ops.push_back(std::move(reshape_op));

// Outputs of bprop
- int64_t dqkv_dim[4] = {b, h, s_kv, d};
- int64_t dqkv_stride[4];
- generateMatrixStrides(
- b, h, s_q, s_kv, d, dqkv_stride,
+ int64_t dq_dim[4] = {b, h, s_q, d};
+ int64_t dq_stride[4];
+ generateMatrixStrides(b, h, s_q, s_kv, d, dq_stride,
layout, NVTE_QKV_Matrix::NVTE_Q_Matrix);

+ int64_t dk_dim[4] = {b, h, s_kv, d};
+ int64_t dk_stride[4];
+ generateMatrixStrides(b, h, s_q, s_kv, d, dk_stride,
+ layout, NVTE_QKV_Matrix::NVTE_K_Matrix);

+ int64_t dv_dim[4] = {b, h, s_kv, d};
+ int64_t dv_stride[4];
+ generateMatrixStrides(b, h, s_q, s_kv, d, dv_stride,
+ layout, NVTE_QKV_Matrix::NVTE_V_Matrix);

+ // Outputs of backprop
- auto dQTensor = tensor_create(tensorType, dQ_ID, dqkv_dim, dqkv_stride, false, false);
- auto dKTensor = tensor_create(tensorType, dK_ID, dqkv_dim, dqkv_stride, false, false);
- auto dVTensor = tensor_create(tensorType, dV_ID, dqkv_dim, dqkv_stride, false, false);
+ auto dQTensor = tensor_create(tensorType, dQ_ID, dq_dim, dq_stride, false, false);
+ auto dKTensor = tensor_create(tensorType, dK_ID, dk_dim, dk_stride, false, false);
+ auto dVTensor = tensor_create(tensorType, dV_ID, dv_dim, dv_stride, false, false);
// not virtual

/*******************************************************************************
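With separate dq/dk/dv dims, dQ follows Q's sequence length while dK and dV follow the key/value length, which matters whenever s_q ≠ s_kv (cross-attention, for instance). A quick autograd check of those shapes (plain PyTorch, illustration only):

```python
import torch

b, h, s_q, s_kv, d = 2, 4, 128, 256, 64
q = torch.randn(b, h, s_q, d, requires_grad=True)
k = torch.randn(b, h, s_kv, d, requires_grad=True)
v = torch.randn(b, h, s_kv, d, requires_grad=True)

o = torch.softmax(q @ k.transpose(-1, -2) / d ** 0.5, dim=-1) @ v
o.sum().backward()
print(q.grad.shape)                # torch.Size([2, 4, 128, 64])
print(k.grad.shape, v.grad.shape)  # both torch.Size([2, 4, 256, 64])
```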
@@ -1028,8 +1042,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
* dP @ K -> dqAccumTensor */

auto dqAccumTensor = cudnn_frontend::TensorBuilder()
- .setDim(4, dqkv_dim)
- .setStride(4, dqkv_stride)
+ .setDim(4, dqAccum_dim)
+ .setStride(4, dqAccum_stride)
.setId(dQ_ACCUM_ID)
.setAlignment(16) // 16B alignment is needed to run a tensor core engine
.setDataType(CUDNN_DATA_FLOAT)
@@ -1044,7 +1058,7 @@
.build();
auto matmul_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)
- .setaMatDesc(dPTensor)
+ .setaMatDesc(dPScaledTensor)
.setbMatDesc(kTensor)
.setcMatDesc(dqAccumTensor)
.setmatmulDesc(matmul_3_Desc)
@@ -1060,7 +1074,7 @@
p_transpose_stride, true, false); // is virtual
auto reshape_op3 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)
- .setxDesc(dPTensor)
+ .setxDesc(dPScaledTensor)
.setyDesc(dPTransposeTensor)
.build();
ops.push_back(std::move(reshape_op3));
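Both the dQ matmul and the reshape feeding the dK path now consume `dPScaledTensor`, i.e., the softmax-output gradient after the dropout mask and 1/(1 − p) scaling have been applied. A plain PyTorch sketch of why the scaled tensor is the right operand: a manual backward matches autograd only when the dropout-scaled gradient flows into the dQ/dK products (illustration of the math, not the cuDNN graph):

```python
import torch

torch.manual_seed(0)
b, h, s_q, s_kv, d, p = 1, 2, 4, 6, 8, 0.25
q = torch.randn(b, h, s_q, d, dtype=torch.double, requires_grad=True)
k = torch.randn(b, h, s_kv, d, dtype=torch.double, requires_grad=True)
v = torch.randn(b, h, s_kv, d, dtype=torch.double, requires_grad=True)
keep = (torch.rand(b, h, s_q, s_kv, dtype=torch.double) > p).double()

prob = (q @ k.transpose(-1, -2) / d ** 0.5).softmax(dim=-1)
o = (prob * keep / (1 - p)) @ v                      # inverted dropout, fixed mask
do = torch.randn_like(o)
o.backward(do)

dp = (do @ v.transpose(-1, -2)) * keep / (1 - p)     # dropout-scaled dP
ds = prob * (dp - (dp * prob).sum(dim=-1, keepdim=True))
dq = ds @ k / d ** 0.5
dk = ds.transpose(-1, -2) @ q / d ** 0.5
print(torch.allclose(dq, q.grad), torch.allclose(dk, k.grad))  # True True
```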
@@ -1185,7 +1199,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(

// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
- const auto stride = num_head * head_dim;
+ const auto stride = 2 * num_head * head_dim;

void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
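The factor of 2 in the stride is the element size: in the packed [b, s, 3, h, d] buffer, Q, K and V start `num_head * head_dim` elements apart, which is `2 * num_head * head_dim` bytes for FP16/BF16, and the offsets above are added to a byte (`int8_t`) pointer. A NumPy sketch of that arithmetic (illustration only):

```python
import numpy as np

b, s, h, d = 2, 8, 16, 64
qkv = np.random.randn(b, s, 3, h, d).astype(np.float16)   # packed QKV buffer
elem_stride = h * d                                        # elements between Q, K, V
byte_stride = qkv.dtype.itemsize * elem_stride             # 2 * h * d bytes for FP16
print(byte_stride)                                         # 2048

# the buffer really is laid out that way: K starts byte_stride bytes after Q
assert qkv[0, 0, 1].ctypes.data - qkv[0, 0, 0].ctypes.data == byte_stride
```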
@@ -1256,7 +1270,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen,
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;

- auto stride = num_head * head_dim;
+ auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
@@ -680,17 +680,23 @@ void fused_attn_max_512_fwd_impl(
// inference mode doesn't need the S auxiliary
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training;
+ std::shared_ptr<cudnn_frontend::Tensor> maskInput;
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);

NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
"NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented.");

if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) {
- createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output);
+ auto bias_output = createBias(b, h, s_q, s_kv, d, layout,
+ tensorType, ops, bmm1_output);
+ maskInput = std::make_shared<cudnn_frontend::Tensor>(std::move(bias_output));
}
+ if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) {
+ maskInput = std::make_shared<cudnn_frontend::Tensor>(std::move(bmm1_output));
+ }

auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops,
- bmm1_output, false);
+ *maskInput.get(), false);

NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0.");
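The fix routes the mask input: with a post-scale bias the mask consumes the bias output, otherwise it consumes the BMM1 output directly. A rough PyTorch analogue of that routing (illustration only, not the cuDNN graph):

```python
import torch

def masked_scores(q, k, bias=None, causal=True):
    s = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5     # BMM1 + scale
    mask_in = s if bias is None else s + bias            # post-scale bias feeds the mask
    if causal:
        s_q, s_kv = mask_in.shape[-2], mask_in.shape[-1]
        allowed = torch.ones(s_q, s_kv).tril().bool()
        mask_in = mask_in.masked_fill(~allowed, float("-inf"))
    return torch.softmax(mask_in, dim=-1)

q, k = torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)
bias = torch.randn(1, 4, 8, 8)
print(masked_scores(q, k).shape, masked_scores(q, k, bias).shape)
```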

@@ -1248,7 +1254,7 @@ void fused_attn_max_512_fwd_qkvpacked(

// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;
- const auto stride = num_head * head_dim;
+ const auto stride = 2 * num_head * head_dim;

void *devPtrQ = static_cast<void *>(devPtrQKV);
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -1322,7 +1328,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrQ = input_Q->data.dptr;

// KV shape is [b, s, 2, h, d]
- const auto stride = num_head * head_dim;
+ const auto stride = 2 * num_head * head_dim;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);

@@ -1393,7 +1399,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
// QKV shape is [b, s, 3, h, d]
void *devPtrQKV = input_QKV->data.dptr;

- auto stride = num_head * head_dim;
+ auto stride = 2 * num_head * head_dim;
void *devPtrQ = devPtrQKV;
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
@@ -1453,7 +1459,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k

// Q shape is [b, s, h, d]
// KV shape is [b, s, 2, h, d]
- auto stride = num_head * head_dim;
+ auto stride = 2 * num_head * head_dim;
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_KV->data.dptr;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);