From c08692a7c41e6432d98d31064389f73406962c47 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Fri, 14 Feb 2025 11:28:25 -0800
Subject: [PATCH] Revert "Allow running without pytest"

This reverts commit 1a12817b546c1190d8c67079427b4a34bbd2586a.
---
 .../wave/attention/extend_attention_test.py | 147 ++++++++----------
 1 file changed, 64 insertions(+), 83 deletions(-)

diff --git a/tests/kernel/wave/attention/extend_attention_test.py b/tests/kernel/wave/attention/extend_attention_test.py
index b49eddbd9..25c48261e 100644
--- a/tests/kernel/wave/attention/extend_attention_test.py
+++ b/tests/kernel/wave/attention/extend_attention_test.py
@@ -36,16 +36,15 @@ import os
 
 from enum import Enum
 from torch.testing import assert_allclose
-
-# from ..common.utils import (
-#     require_e2e,
-#     require_cdna3,
-#     enable_scheduling_barriers,
-#     dump_generated_mlir,
-# )
-# from ..common.shapes import get_test_shapes, construct_test_name
+from ..common.utils import (
+    require_e2e,
+    require_cdna3,
+    enable_scheduling_barriers,
+    dump_generated_mlir,
+)
 from torch.nn.attention.flex_attention import flex_attention
 from torch.nn.attention.flex_attention import create_block_mask
+from ..common.shapes import get_test_shapes, construct_test_name
 
 
 # Reference paged attention implementation from vLLM and sglang.
@@ -274,27 +273,27 @@ def create_inputs(
     )
 
 
-# # TODO: Investigate errors on MI250.
-# @require_e2e
-# # @require_cdna3
-# @pytest.mark.parametrize("shape", get_test_shapes("extend"))
-# @pytest.mark.parametrize("dtype", [torch.float16])
-# @pytest.mark.parametrize("enable_scheduling", [False])
-# @pytest.mark.parametrize("is_causal", [False])
-# @pytest.mark.parametrize(
-#     "mfma_variant",
-#     [
-#         (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
-#         (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16),
-#     ],
-# )
+# TODO: Investigate errors on MI250.
+@require_e2e
+# @require_cdna3
+@pytest.mark.parametrize("shape", get_test_shapes("extend"))
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("enable_scheduling", [False])
+@pytest.mark.parametrize("is_causal", [False])
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
+        (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16),
+    ],
+)
 def testExtendAttention(
     shape: list[AttentionShape],
     dtype: torch.dtype,
     enable_scheduling: bool,
     is_causal: bool,
     mfma_variant: MMAType,
-    # request,
+    request,
 ):
     torch.manual_seed(0)
 
@@ -351,25 +350,25 @@ def testExtendAttention(
     )
     hyperparams.update(get_default_scheduling_params())
     config = get_default_run_config()
-    # run_bench = request.config.getoption("--runperf")
-    # dump_perf = request.config.getoption("--dump-perf-files-path")
-    # if run_bench:
-    #     config["benchmark_batch_size"] = 1000
-    #     config["benchmark_repetitions"] = 3
-    # if dump_perf is not None:
-    #     perf_filename = construct_test_name(
-    #         "wave_extend_attention", mfma_variant, is_causal, shape
-    #     )
-    #     config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    if run_bench:
+        config["benchmark_batch_size"] = 1000
+        config["benchmark_repetitions"] = 3
+    if dump_perf is not None:
+        perf_filename = construct_test_name(
+            "wave_extend_attention", mfma_variant, is_causal, shape
+        )
+        config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
 
     with tk.gen.TestLaunchContext(
         hyperparams,
         canonicalize=True,
         run=True,
-        # run_bench=run_bench,
+        run_bench=run_bench,
         run_config=config,
         schedule=enable_scheduling,
-        # use_scheduling_barriers=enable_scheduling_barriers,
+        use_scheduling_barriers=enable_scheduling_barriers,
         dynamic_symbols=dynamic_symbols,
         dynamic_symbols_map=dynamic_symbols_map,
     ):
@@ -387,10 +386,10 @@ def testExtendAttention(
         output,
     )
 
-    # if dump_generated_mlir:
-    #     filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
-    #     with open(filename, "w") as f:
-    #         f.write(mb_qk.module_op.get_asm())
+    if dump_generated_mlir:
+        filename = f"wave_extend_attention_kernel_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
 
     # Run the reference implementation.
     ref_output = ref_extend_attn(
@@ -411,26 +410,26 @@ def testExtendAttention(
     assert_allclose(output, ref_output, rtol=1e-3, atol=1e-3)
 
 
-# # TODO: Investigate errors on MI250.
-# @require_e2e
-# # @require_cdna3
-# @pytest.mark.parametrize("shape", get_test_shapes("extend"))
-# @pytest.mark.parametrize("dtype", [torch.float16])
-# @pytest.mark.parametrize("enable_scheduling", [False])
-# @pytest.mark.parametrize("is_causal", [False])
-# @pytest.mark.parametrize(
-#     "mfma_variant",
-#     [
-#         (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
-#     ],
-# )
+# TODO: Investigate errors on MI250.
+@require_e2e
+# @require_cdna3
+@pytest.mark.parametrize("shape", get_test_shapes("extend"))
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("enable_scheduling", [False])
+@pytest.mark.parametrize("is_causal", [False])
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
+    ],
+)
 def testExtendRpeAttention(
     shape: list[AttentionShape],
     dtype: torch.dtype,
     enable_scheduling: bool,
     is_causal: bool,
     mfma_variant: MMAType,
-    # request,
+    request,
 ):
     torch.manual_seed(0)
 
@@ -487,26 +486,26 @@ def testExtendRpeAttention(
     )
     hyperparams.update(get_default_scheduling_params())
     config = get_default_run_config()
-    # run_bench = request.config.getoption("--runperf")
-    # dump_perf = request.config.getoption("--dump-perf-files-path")
-    # if run_bench:
-    #     config["benchmark_batch_size"] = 1000
-    #     config["benchmark_repetitions"] = 3
-    # if dump_perf is not None:
-    #     perf_filename = construct_test_name(
-    #         "wave_extend_attention", mfma_variant, is_causal, shape
-    #     )
-    #     config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    if run_bench:
+        config["benchmark_batch_size"] = 1000
+        config["benchmark_repetitions"] = 3
+    if dump_perf is not None:
+        perf_filename = construct_test_name(
+            "wave_extend_attention", mfma_variant, is_causal, shape
+        )
+        config["benchmark_results_file"] = os.path.join(dump_perf, perf_filename)
     rpe_debug = torch.zeros((16, 864, 864), dtype=torch.float32, device="cuda")
     with tk.gen.TestLaunchContext(
         hyperparams,
         canonicalize=True,
         run=True,
-        # run_bench=run_bench,
+        run_bench=run_bench,
         compile_config={"print_ir_after": "all"},
         run_config=config,
         schedule=enable_scheduling,
-        # use_scheduling_barriers=enable_scheduling_barriers,
+        use_scheduling_barriers=enable_scheduling_barriers,
         dynamic_symbols=dynamic_symbols,
         dynamic_symbols_map=dynamic_symbols_map,
     ):
@@ -558,21 +557,3 @@ def testExtendRpeAttention(
     torch.testing.assert_close(
         output, ref_output, rtol=2e-3, atol=2e-3, check_dtype=False
     )
-
-
-shape = AttentionShape(
-    num_seqs=2,
-    context_len=1024,
-    num_query_heads=16,
-    num_kv_heads=1,
-    head_size=128,
-    head_size_kv=128,
-    block_size=64,
-)
-testExtendRpeAttention(
-    shape,
-    torch.float16,
-    enable_scheduling=False,
-    is_causal=False,
-    mfma_variant=(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16),
-)
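
Note: with this revert, both tests again take pytest's request fixture and read the --runperf and --dump-perf-files-path options via request.config.getoption(...), so they run only under pytest with those options registered in a conftest.py. The following is a hypothetical sketch of how such options are typically registered through pytest's pytest_addoption hook; the repository's actual conftest.py is not part of this patch and may differ.

# conftest.py -- hypothetical sketch, not the repository's real conftest.
def pytest_addoption(parser):
    # Consumed above as: run_bench = request.config.getoption("--runperf")
    parser.addoption(
        "--runperf",
        action="store_true",
        default=False,
        help="Run the attention kernels in benchmark mode.",
    )
    # Consumed above as: dump_perf = request.config.getoption("--dump-perf-files-path")
    parser.addoption(
        "--dump-perf-files-path",
        action="store",
        default=None,
        help="Directory in which to write benchmark result files.",
    )

# Example invocation under these assumptions:
#   pytest tests/kernel/wave/attention/extend_attention_test.py \
#       --runperf --dump-perf-files-path /tmp/wave_perf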