Skip to content

Commit af63eac

Browse files
committed
Support variable sequence length after cuDNN 8.9.6
Signed-off-by: Reese Wang <[email protected]>
1 parent 7b3c057 commit af63eac

File tree

4 files changed

+165
-106
lines changed

4 files changed

+165
-106
lines changed

tests/jax/test_fused_attn.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# See LICENSE for license information.
44
"""Tests for fused attention"""
55

6+
import ctypes
67
import os
78
from enum import Enum
89
from math import sqrt
@@ -159,9 +160,13 @@ class TestSelfFusedAttn():
159160

160161
@staticmethod
161162
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
162-
head_dim):
163+
head_dim, pad_ratio):
163164

164-
assert isinstance(backend, Backend)
165+
lib = ctypes.CDLL('libcudnn.so')
166+
is_varlen_supported = lib.cudnnGetVersion() >= 8906
167+
168+
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0 and not is_varlen_supported:
169+
pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
165170

166171
if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
167172
dropout_probability, s, s, head_dim):
@@ -182,7 +187,8 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
182187
backend=backend,
183188
dropout_probability=dropout_probability,
184189
dtype=dtype,
185-
head_dim=d)
190+
head_dim=d,
191+
pad_ratio=pad_ratio)
186192
key = jax.random.PRNGKey(0)
187193
subkeys = jax.random.split(key, 2)
188194

0 commit comments

Comments
 (0)