From eead05633f2038a9779bbee1944a1e2c64467bff Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Thu, 13 Feb 2025 09:02:21 -0800
Subject: [PATCH] Debug RPE with MMA 32x32x8

Add debug prints along the dynamic-value plumbing
(transform_index_backwards -> _build_start_indices -> gen_sympy_index),
hardcode elements_per_thread=4 for the RPE j-index read, and invoke the
test directly as a standalone repro with the 32x32x8 MMA variant.

Repro:
```
python3 tests/kernel/wave/attention/t5_rpe_attention_test.py
```

Currently fails with:
```
iree.turbine.kernel.compiler.base.CodegenError: Unknown symbol $dynamic_val0
```
---
 iree/turbine/kernel/ops/wave_ops.py           |  5 +-
 iree/turbine/kernel/wave/codegen/emitter.py   |  2 +
 .../turbine/kernel/wave/codegen/read_write.py |  1 +
 .../kernel/wave/index_sequence_analysis.py    |  1 +
 .../kernel/wave/templates/t5_rpe_attention.py |  8 +--
 .../wave/attention/t5_rpe_attention_test.py   | 59 ++++++++++---------
 6 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index f6959d1f2..703309702 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1257,11 +1257,14 @@ def transform_index_backwards(
             subs = {
                 k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
             }
-            return {
+            print(f"subs {subs}")  # debug: trace iter -> index substitutions
+            d = {
                 k: IndexSequence.from_expr(mapping[k], subs)
                 for k in get_custom(arg).type.symbolic_shape
                 if k in mapping
             }
+            print(f"d {d}")  # debug: trace the transformed index
+            return d
 
         return index
 
diff --git a/iree/turbine/kernel/wave/codegen/emitter.py b/iree/turbine/kernel/wave/codegen/emitter.py
index 9e1c685ff..e4288309c 100644
--- a/iree/turbine/kernel/wave/codegen/emitter.py
+++ b/iree/turbine/kernel/wave/codegen/emitter.py
@@ -390,12 +390,14 @@ def _get_const(val):
     # Substitute in frozen vars to simplify expression.
     if not isinstance(expr, sympy.Expr):
         expr = sympy.sympify(expr)
+    print(f"\n\nexpr {expr} dynamics {dynamics}")
     expr = expr.subs(idxc.subs)
     # Why affine, for now simply create indexing expressions.
     # This can easily be adapted to affine expressions later.
     select_stack = []
     if isinstance(expr, sympy.Piecewise):
         assert len(expr.args) == 2 and expr.args[1][1], f"Unsupported piecewise {expr}"
+    print(f"\n\nexpr {expr} dynamics {dynamics}")
     for term in sympy.postorder_traversal(expr):
         match term:
             case sympy.Symbol():
diff --git a/iree/turbine/kernel/wave/codegen/read_write.py b/iree/turbine/kernel/wave/codegen/read_write.py
index 737807256..c5da41602 100644
--- a/iree/turbine/kernel/wave/codegen/read_write.py
+++ b/iree/turbine/kernel/wave/codegen/read_write.py
@@ -80,6 +80,7 @@ def _build_start_indices(
     src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
     dynamic_values: dict[IndexExpr, Any] = {},
 ) -> list[OpResult]:
+    print(f"dynamic_values {dynamic_values}")
     return [
         gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
         for i in _get_start_indices(src_indices)
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 63d8bf6c6..8cb25bd72 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -210,6 +210,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
             ]
             assert len(dim_with_gpr_offsets) == 1, "Expected only 1-Dim has gpr offsets"
             gpr_offset_dim, gpr_offset_expr = dim_with_gpr_offsets[0]
+            print(f"elements_per_thread {elements_per_thread}")
             gpr_offsets = [
                 gpr_offset_expr.subs({GPR_NUM: i}) for i in range(elements_per_thread)
             ]
diff --git a/iree/turbine/kernel/wave/templates/t5_rpe_attention.py b/iree/turbine/kernel/wave/templates/t5_rpe_attention.py
index 76f54bad4..9ff9e823b 100644
--- a/iree/turbine/kernel/wave/templates/t5_rpe_attention.py
+++ b/iree/turbine/kernel/wave/templates/t5_rpe_attention.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple
 
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
@@ -22,7 +22,7 @@
 
 def get_t5_rpe_attention_kernel(
     shape: AttentionShape,
-    mfma_variant: MMAType,
+    mfma_variant: Tuple[MMAType, MMAType],
     dynamic_dims: bool,
     max_context_length: int,
 ):
@@ -123,9 +123,7 @@ def repeat(
             # the partial softmax should be equivalent.
             i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
             i = tkw.broadcast(i, target_shape=[B, M, K2])
-            j = tkw.self_index(
-                K2, tkl.i64, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK
-            )
+            j = tkw.self_index(K2, tkl.i64, elements_per_thread=4)  # debug: hardcode instead of LOAD_ELEMS_PER_THREAD_QK
             rpe_reg = tkw.read(
                 rpe,
                 mapping=offset_mapping,
diff --git a/tests/kernel/wave/attention/t5_rpe_attention_test.py b/tests/kernel/wave/attention/t5_rpe_attention_test.py
index 95ed42e56..417bbb4f3 100644
--- a/tests/kernel/wave/attention/t5_rpe_attention_test.py
+++ b/tests/kernel/wave/attention/t5_rpe_attention_test.py
@@ -24,11 +24,12 @@
 from iree.turbine.kernel.wave.templates.t5_rpe_attention import (
     get_t5_rpe_attention_kernel,
 )
-from ..common.utils import (
-    require_e2e,
-    require_cdna3,
-    enable_scheduling_barriers,
-)
+
+# from ..common.utils import (
+#     require_e2e,
+#     require_cdna3,
+#     enable_scheduling_barriers,
+# )
 from typing import Tuple
 
 shapes = [(128, 128, 128, 128, 128, 128)]
@@ -83,19 +84,19 @@ def create_inputs(
 
 
 # TODO: Debug why failing numerics on MI250.
-@require_e2e
-@require_cdna3
-@pytest.mark.parametrize("shape", shapes)
-@pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize(
-    "mfma_variant",
-    [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)],
-)
+# @require_e2e
+# @require_cdna3
+# @pytest.mark.parametrize("shape", shapes)
+# @pytest.mark.parametrize("dtype", [torch.float16])
+# @pytest.mark.parametrize(
+#     "mfma_variant",
+#     [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)],
+# )
 def test_t5_rpe_attention(
     shape: Tuple[int],
     dtype: torch.dtype,
-    mfma_variant: MMAType,
-    request,
+    mfma_variant: Tuple[MMAType, MMAType],
+    # request,
 ):
     torch.manual_seed(0)
     shape = AttentionShape(
@@ -119,16 +120,15 @@ def test_t5_rpe_attention(
     hyperparams.update(get_default_scheduling_params())
     config = get_default_run_config()
 
-    run_bench = request.config.getoption("--runperf")
-    dump_perf = request.config.getoption("--dump-perf-files-path")
-    if run_bench:
-        config["benchmark_batch_size"] = 10
-        config["benchmark_repetitions"] = 3
-    if dump_perf is not None:
-        perf_filename = request.node.name + ".json"
-        config["benchmark_results_file"] = os.path.join(
-            dump_perf, "tk_" + perf_filename
-        )
+    # run_bench = request.config.getoption("--runperf")
+    # dump_perf = request.config.getoption("--dump-perf-files-path")
+    # if run_bench:
+    #     config["benchmark_batch_size"] = 10
+    #     config["benchmark_repetitions"] = 3
+    # if dump_perf is not None:
+    #     perf_filename = request.node.name + ".json"
+    #     config["benchmark_results_file"] = os.path.join(
+    #         dump_perf, "tk_" + perf_filename)
 
     log2e = 1.44269504089
     dk_sqrt = math.sqrt(1.0 / shape.head_size)
@@ -143,9 +143,9 @@ def test_t5_rpe_attention(
         hyperparams,
         canonicalize=True,
         run=True,
-        run_bench=run_bench,
+        # run_bench=run_bench,
         run_config=config,
-        use_scheduling_barriers=enable_scheduling_barriers,
+        # use_scheduling_barriers=enable_scheduling_barriers,
     ):
         output = device_zeros(output_shape, dtype=torch.float32)
         # TODO: Add scaling of QK and t5_rpe as part of kernel.
@@ -161,3 +161,8 @@ def test_t5_rpe_attention(
     )
 
     validate_accuracy(query, key, value, rpe, output)
+
+
+test_t5_rpe_attention(  # standalone repro for the failing 32x32x8 variant
+    shapes[0], torch.float16, (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16)
+)
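
Note on the failure mode the debug prints are chasing: `gen_sympy_index` can only
lower symbols it has values for, so an index expression that still contains a
dynamic-value symbol after substitution surfaces as `CodegenError: Unknown symbol
$dynamic_val0`. The following is a minimal standalone sketch using only sympy, not
Wave code; the symbol names, the expression, and the `known` substitution map are
hypothetical stand-ins for the emitter's state:
```
# Hypothetical repro of "Unknown symbol": a dynamic value that was never
# registered with the emitter survives substitution. T0/WG0 stand in for
# thread/workgroup symbols the emitter does know how to resolve.
import sympy

T0 = sympy.Symbol("T0")
WG0 = sympy.Symbol("WG0")
dyn0 = sympy.Symbol("$dynamic_val0")

# An RPE-style start index: static terms plus a dynamic offset.
expr = 64 * WG0 + 4 * T0 + dyn0

# Substitutions the emitter has values for; dyn0 is missing, e.g. because
# transform_index_backwards dropped it from the mapping.
known = {T0: 3, WG0: 1}

resolved = expr.subs(known)
print(resolved)               # $dynamic_val0 + 76
print(resolved.free_symbols)  # {$dynamic_val0} -- the unresolved symbol

# A codegen-style walk over the expression then meets a Symbol it has no
# value for, which is where the emitter raises its error:
for term in sympy.postorder_traversal(resolved):
    if isinstance(term, sympy.Symbol):
        raise RuntimeError(f"Unknown symbol {term}")
```
If `$dynamic_val0` were present in `known`, `resolved.free_symbols` would be
empty and the walk would complete, which is why the prints above trace
`dynamic_values` from `_build_start_indices` into `add_emitter_subs`.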