Commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 16, 2024
1 parent 9a09edb commit 277dd60
Showing 13 changed files with 470 additions and 399 deletions.
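
The diffs below are mechanical formatting fixes applied by the repository's pre-commit hooks and committed automatically by pre-commit.ci: the Python change normalizes string quoting, and the C++ header change re-wraps over-long declarations. For orientation, here is a minimal sketch of the kind of .pre-commit-config.yaml that drives such fixes; the specific repos, pinned revisions, and hook ids are illustrative assumptions, not read from this commit.

```yaml
# Hypothetical sketch only: the hook ids, repos, and pinned revisions are
# assumptions; consult the repository's actual .pre-commit-config.yaml.
repos:
  - repo: https://github.com/psf/black
    rev: 24.8.0  # assumed pin
    hooks:
      - id: black  # normalizes Python string quoting, among other fixes
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v18.1.8  # assumed pin
    hooks:
      - id: clang-format  # re-wraps C/C++/CUDA sources per .clang-format
```

Running pre-commit run --all-files locally applies the same fixes that pre-commit.ci pushed here.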
4 changes: 2 additions & 2 deletions tests/pytorch/fused_attn/test_fused_attn.py
```diff
@@ -350,8 +350,8 @@ def test_dot_product_attention(
             torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
         logging.info("[test_dot_product_attention]: fused attn vs flash attn")
-        torch.save(fused_attn_fwd, 'fused_attn_fwd.pt')
-        torch.save(flash_attn_fwd, 'flash_attn_fwd.pt')
+        torch.save(fused_attn_fwd, "fused_attn_fwd.pt")
+        torch.save(flash_attn_fwd, "flash_attn_fwd.pt")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
```
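The two-line change above is black-style quote normalization (single quotes rewritten to double quotes). Assuming black is the hook responsible, the same fix can be reproduced locally with either of these commands:

```bash
# Both commands assume the formatter/hooks are installed locally.
black tests/pytorch/fused_attn/test_fused_attn.py
# or run the configured hooks on just this file:
pre-commit run --files tests/pytorch/fused_attn/test_fused_attn.py
```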
144 changes: 74 additions & 70 deletions transformer_engine/common/fused_attn/fused_attn.cpp

Large diffs are not rendered by default.

```diff
@@ -22,38 +22,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
     bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
-    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_QKV,
+    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
-    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
+    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV,
+    Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
+    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
     bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_KV,
-    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
-    bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
-    Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
+    Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
```
```diff
@@ -63,21 +63,23 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
-    const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
-    Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
+    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
+    Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
```
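
Each of the six declarations above changes only in how its parameter list wraps: the old lines carrying bottom_right_diagonal ran past 100 columns, and the new lines re-flow at or under that width, which is what clang-format does when a column limit is configured. A sketch of a .clang-format consistent with this wrapping (the repository's actual settings are an assumption here):

```yaml
# Hypothetical .clang-format consistent with the re-wrapping seen above;
# the repository's real file may set these and other options differently.
BasedOnStyle: Google
ColumnLimit: 100
```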

