Skip to content

Commit

Permalink
Fix cuDNN sliding window size (#1212)
Browse files Browse the repository at this point in the history
* adjust window size to (i-window_size_left,i] for cuDNN

Signed-off-by: Charlene Yang <[email protected]>

* reduce the window to make any errors more pronounced

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
cyanguwa and ksivaman authored Oct 7, 2024
1 parent c24a4c4 commit c3b3cd2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_dot_product_attention(
# Test backend availability
window_size = (-1, -1)
if swa:
window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist())
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();

try {
Expand Down Expand Up @@ -221,8 +218,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);

if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_options.set_sliding_window_length(window_size_left);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_sliding_window_length(window_size_left + 1);
}

sdpa_options.set_alibi_mask(is_alibi);
Expand Down Expand Up @@ -407,9 +404,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
Expand Down Expand Up @@ -584,8 +578,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);

if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_backward_options.set_sliding_window_length(window_size_left);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left + 1);
}

if (cudnn_runtime_version >= 90000) {
Expand Down

0 comments on commit c3b3cd2

Please sign in to comment.