Revert "Allow running without pytest"
This reverts commit 1a12817.
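
The revert restores the pytest decorators and the `request` fixture, so the tests once again read benchmark settings through `request.config.getoption("--runperf")` and `request.config.getoption("--dump-perf-files-path")`. Those flags must be registered by the suite's pytest configuration; a minimal hypothetical sketch of how such options are typically added (not the repository's actual conftest.py):

```python
# conftest.py (hypothetical sketch, not the repository's actual file)
def pytest_addoption(parser):
    # Register the custom CLI flags that the restored tests read via
    # request.config.getoption(...).
    parser.addoption(
        "--runperf",
        action="store_true",
        default=False,
        help="Run attention tests in benchmark mode.",
    )
    parser.addoption(
        "--dump-perf-files-path",
        action="store",
        default=None,
        help="Directory where benchmark result files are written.",
    )
```

With these options registered, `run_bench` and `dump_perf` in the tests below resolve exactly as before the reverted change.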
nicolasvasilache committed Feb 14, 2025
1 parent c217848 commit c08692a
Showing 1 changed file with 64 additions and 83 deletions.
147 changes: 64 additions & 83 deletions tests/kernel/wave/attention/extend_attention_test.py
@@ -36,16 +36,15 @@
import os
from enum import Enum
from torch.testing import assert_allclose

# from ..common.utils import (
# require_e2e,
# require_cdna3,
# enable_scheduling_barriers,
# dump_generated_mlir,
# )
# from ..common.shapes import get_test_shapes, construct_test_name
from ..common.utils import (
require_e2e,
require_cdna3,
enable_scheduling_barriers,
dump_generated_mlir,
)
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.attention.flex_attention import create_block_mask
from ..common.shapes import get_test_shapes, construct_test_name

# Reference paged attention implementation from vLLM and sglang.

@@ -274,27 +273,27 @@ def create_inputs(
)


# # TODO: Investigate errors on MI250.
# @require_e2e
# # @require_cdna3
# @pytest.mark.parametrize("shape", get_test_shapes("extend"))
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("enable_scheduling", [False])
# @pytest.mark.parametrize("is_causal", [False])
# @pytest.mark.parametrize(
# "mfma_variant",
# [
# (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
# (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16),
# ],
# )
# TODO: Investigate errors on MI250.
@require_e2e
# @require_cdna3
@pytest.mark.parametrize("shape", get_test_shapes("extend"))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_scheduling", [False])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"mfma_variant",
[
(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
(MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16),
],
)
def testExtendAttention(
shape: list[AttentionShape],
dtype: torch.dtype,
enable_scheduling: bool,
is_causal: bool,
mfma_variant: MMAType,
# request,
request,
):

torch.manual_seed(0)
@@ -351,25 +350,25 @@ def testExtendAttention(
)
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
# run_bench = request.config.getoption("--runperf")
# dump_perf = request.config.getoption("--dump-perf-files-path")
# if run_bench:
# config["benchmark_batch_size"] = 1000
# config["benchmark_repetitions"] = 3
# if dump_perf is not None:
# perf_filename = construct_test_name(
# "wave_extend_attention", mfma_variant, is_causal, shape
# )
# config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
if run_bench:
config["benchmark_batch_size"] = 1000
config["benchmark_repetitions"] = 3
if dump_perf is not None:
perf_filename = construct_test_name(
"wave_extend_attention", mfma_variant, is_causal, shape
)
config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
run=True,
# run_bench=run_bench,
run_bench=run_bench,
run_config=config,
schedule=enable_scheduling,
# use_scheduling_barriers=enable_scheduling_barriers,
use_scheduling_barriers=enable_scheduling_barriers,
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
@@ -387,10 +386,10 @@ def testExtendAttention(
output,
)

# if dump_generated_mlir:
# filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
# with open(filename, "w") as f:
# f.write(mb_qk.module_op.get_asm())
if dump_generated_mlir:
filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb_qk.module_op.get_asm())

# Run the reference implementation.
ref_output = ref_extend_attn(
@@ -411,26 +410,26 @@ def testExtendAttention(
assert_allclose(output, ref_output, rtol=1e-3, atol=1e-3)


# # TODO: Investigate errors on MI250.
# @require_e2e
# # @require_cdna3
# @pytest.mark.parametrize("shape", get_test_shapes("extend"))
# @pytest.mark.parametrize("dtype", [torch.float16])
# @pytest.mark.parametrize("enable_scheduling", [False])
# @pytest.mark.parametrize("is_causal", [False])
# @pytest.mark.parametrize(
# "mfma_variant",
# [
# (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
# ],
# )
# TODO: Investigate errors on MI250.
@require_e2e
# @require_cdna3
@pytest.mark.parametrize("shape", get_test_shapes("extend"))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_scheduling", [False])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"mfma_variant",
[
(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
],
)
def testExtendRpeAttention(
shape: list[AttentionShape],
dtype: torch.dtype,
enable_scheduling: bool,
is_causal: bool,
mfma_variant: MMAType,
# request,
request,
):

torch.manual_seed(0)
@@ -487,27 +486,27 @@ def testExtendRpeAttention(
)
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
# run_bench = request.config.getoption("--runperf")
# dump_perf = request.config.getoption("--dump-perf-files-path")
# if run_bench:
# config["benchmark_batch_size"] = 1000
# config["benchmark_repetitions"] = 3
# if dump_perf is not None:
# perf_filename = construct_test_name(
# "wave_extend_attention", mfma_variant, is_causal, shape
# )
# config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
if run_bench:
config["benchmark_batch_size"] = 1000
config["benchmark_repetitions"] = 3
if dump_perf is not None:
perf_filename = construct_test_name(
"wave_extend_attention", mfma_variant, is_causal, shape
)
config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)

rpe_debug = torch.zeros((16, 864, 864), dtype=torch.float32, device="cuda")
with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
run=True,
# run_bench=run_bench,
run_bench=run_bench,
compile_config={"print_ir_after": "all"},
run_config=config,
schedule=enable_scheduling,
# use_scheduling_barriers=enable_scheduling_barriers,
use_scheduling_barriers=enable_scheduling_barriers,
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
@@ -558,21 +557,3 @@ def testExtendRpeAttention(
torch.testing.assert_close(
output, ref_output, rtol=2e-3, atol=2e-3, check_dtype=False
)


shape = AttentionShape(
num_seqs=2,
context_len=1024,
num_query_heads=16,
num_kv_heads=1,
head_size=128,
head_size_kv=128,
block_size=64,
)
testExtendRpeAttention(
shape,
torch.float16,
enable_scheduling=False,
is_causal=False,
mfma_variant=(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
)
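
With the module-level call to `testExtendRpeAttention` deleted in the last hunk above, the file can no longer be executed as a plain script; the tests are driven by pytest again. A minimal sketch of launching them programmatically, with the path and flag assumed rather than verified against the repository's CI:

```python
# run_extend_attention_tests.py (hypothetical helper, not part of the repository)
import sys

import pytest

if __name__ == "__main__":
    # Drive the restored tests through pytest; --runperf enables benchmark mode.
    sys.exit(
        pytest.main(
            ["tests/kernel/wave/attention/extend_attention_test.py", "--runperf"]
        )
    )
```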
