Skip to content

Commit

Permalink
add left_bound/right_bound
Browse files Browse the repository at this point in the history
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Dec 16, 2024
1 parent 8d17e10 commit 9a09edb
Show file tree
Hide file tree
Showing 14 changed files with 319 additions and 192 deletions.
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ def test_dot_product_attention(
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.save(fused_attn_fwd, 'fused_attn_fwd.pt')
torch.save(flash_attn_fwd, 'flash_attn_fwd.pt')
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
Expand Down
68 changes: 34 additions & 34 deletions transformer_engine/common/fused_attn/fused_attn.cpp

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Expand All @@ -40,7 +40,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
Expand All @@ -50,7 +50,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
Expand All @@ -63,7 +63,7 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Expand All @@ -74,7 +74,7 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
Expand Down
36 changes: 16 additions & 20 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1652,7 +1652,7 @@ void fused_attn_fp8_bwd_impl(
void fused_attn_fp8_fwd_impl_v1(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK,
void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
Expand Down Expand Up @@ -1688,9 +1688,7 @@ void fused_attn_fp8_fwd_impl_v1(
dropout_probability,
layout,
bias_type,
mask_type,
0,
0,
mask_type, window_size_left, window_size_right, bottom_right_diagonal,
true,
fwd_tensor_type,
fwd_tensor_type};
Expand Down Expand Up @@ -1952,7 +1950,7 @@ void fused_attn_fp8_fwd_impl_v1(
void fused_attn_fp8_bwd_impl_v1(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, float scaling_factor,
float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM,
void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, void* devPtrdV,
void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO,
void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS,
Expand Down Expand Up @@ -1992,9 +1990,7 @@ void fused_attn_fp8_bwd_impl_v1(
dropout_probability,
layout,
bias_type,
mask_type,
0,
0,
mask_type, window_size_left, window_size_right, bottom_right_diagonal,
false,
fwd_tensor_type,
bwd_tensor_type};
Expand Down Expand Up @@ -2352,7 +2348,7 @@ void fused_attn_fp8_bwd_impl_v1(
void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen,
size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O,
NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens,
const Tensor* rng_state, Tensor* workspace, cudaStream_t stream,
Expand Down Expand Up @@ -2422,7 +2418,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
Expand Down Expand Up @@ -2453,7 +2449,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
// fused attention BWD FP8 with packed QKV
void fused_attn_fp8_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M,
const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace,
Expand Down Expand Up @@ -2514,7 +2510,7 @@ void fused_attn_fp8_bwd_qkvpacked(
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
Expand Down Expand Up @@ -2551,7 +2547,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor* input_Q,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q,
const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O,
NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q,
const Tensor* cu_seqlens_kv, const Tensor* rng_state,
Expand Down Expand Up @@ -2623,7 +2619,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
Expand Down Expand Up @@ -2656,7 +2652,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
void fused_attn_fp8_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP,
const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q,
Expand Down Expand Up @@ -2720,7 +2716,7 @@ void fused_attn_fp8_bwd_kvpacked(
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
Expand Down Expand Up @@ -2757,7 +2753,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor* input_Q, const Tensor* input_K,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K,
const Tensor* input_V, Tensor* input_output_S, Tensor* output_O,
NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q,
const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace,
Expand Down Expand Up @@ -2821,7 +2817,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_fwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
Expand Down Expand Up @@ -2854,7 +2850,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor* input_Q,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor* input_Q,
const Tensor* input_K, const Tensor* input_V, const Tensor* input_O,
const Tensor* input_dO, const Tensor* input_M, const Tensor* input_ZInv,
const Tensor* input_S, Tensor* input_output_dP, const Tensor* output_dQ,
Expand Down Expand Up @@ -2911,7 +2907,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) {
fused_attn::fused_attn_fp8_bwd_impl_v1(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv,
devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK,
devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP,
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
Expand Down
Loading

0 comments on commit 9a09edb

Please sign in to comment.