Skip to content

Commit c014de1

Browse files
authored
[ROCm][CI] Fix test_cudagraph_mode.py Failure For AMD CI (#29808)
Signed-off-by: Micah Williamson <[email protected]>
1 parent 1b1e35a commit c014de1

File tree

1 file changed

+14
-26
lines changed

1 file changed

+14
-26
lines changed

tests/v1/cudagraph/test_cudagraph_mode.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
100100

101101
# test cudagraph_mode with different compilation mode.
102102
# (backend_name, cudagraph_mode, compilation_mode, supported)
103-
if current_platform.is_rocm():
104-
combo_cases_2 = [
105-
("RocmAttn", "FULL", CompilationMode.NONE, True),
106-
("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True),
107-
("RocmAttn", "PIECEWISE", CompilationMode.NONE, False),
108-
("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
109-
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
110-
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
111-
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
112-
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
113-
("RocmAttn", "NONE", CompilationMode.NONE, True),
114-
("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True),
115-
]
116-
else:
117-
combo_cases_2 = [
118-
("FA2", "FULL", CompilationMode.NONE, True),
119-
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
120-
("FA2", "PIECEWISE", CompilationMode.NONE, True),
121-
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
122-
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
123-
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
124-
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
125-
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
126-
("FA2", "NONE", CompilationMode.NONE, True),
127-
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
128-
]
103+
attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2"
104+
105+
combo_cases_2 = [
106+
(attn_backend, "FULL", CompilationMode.NONE, True),
107+
(attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True),
108+
(attn_backend, "PIECEWISE", CompilationMode.NONE, True),
109+
(attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
110+
(attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
111+
(attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
112+
(attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True),
113+
(attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
114+
(attn_backend, "NONE", CompilationMode.NONE, True),
115+
(attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True),
116+
]
129117

130118

131119
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)