From 4db2ad712c4ddd763f6622352d9ba31e6911b7a4 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 3 Feb 2025 18:57:50 +0100 Subject: [PATCH 1/7] ALiBi attention in a standalone test This implements (causal) ALiBi, attention with linear biases, following the paper "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" by Press et.al in ICLR 2022. Signed-off-by: Alex Zinenko --- iree/turbine/kernel/wave/codegen/handlers.py | 2 +- .../kernel/wave/templates/alibi_attention.py | 176 ++++++++++++++++++ .../wave/attention/alibi_attention_test.py | 168 +++++++++++++++++ 3 files changed, 345 insertions(+), 1 deletion(-) create mode 100644 iree/turbine/kernel/wave/templates/alibi_attention.py create mode 100644 tests/kernel/wave/attention/alibi_attention_test.py diff --git a/iree/turbine/kernel/wave/codegen/handlers.py b/iree/turbine/kernel/wave/codegen/handlers.py index 5e4b0b1d1..e47efb789 100644 --- a/iree/turbine/kernel/wave/codegen/handlers.py +++ b/iree/turbine/kernel/wave/codegen/handlers.py @@ -545,7 +545,7 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult: if _is_float_type(element_type): result = arith_d.minimumf(lhs, rhs) elif _is_integer_like_type(element_type) and ( - element_type.is_signed() or element_type.is_signless() + element_type.is_signed or element_type.is_signless ): result = arith_d.minsi(lhs, rhs) else: diff --git a/iree/turbine/kernel/wave/templates/alibi_attention.py b/iree/turbine/kernel/wave/templates/alibi_attention.py new file mode 100644 index 000000000..55fc02871 --- /dev/null +++ b/iree/turbine/kernel/wave/templates/alibi_attention.py @@ -0,0 +1,176 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from typing import Optional + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape + + +def get_alibi_attention_kernel( + shape: AttentionShape, mfma_variant: MMAType, dynamic_dims: bool +): + """Produces an attention kernel with linear biases (ALiBi). + + Note that this uses the numeric equality exp(x) = pow(2, x * log2(e)) for + efficiency reasons internally, therefore, either the Q or the K matrix + _and_ the linear biases must be pre-scaled by log2(e) before being passed + into the generated kernel. 
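+
+    For example (mirroring the accompanying test), the generated kernel is
+    expected to be invoked roughly as follows, with the scaling of Q and of
+    the ALiBi slopes by log2(e) applied on the caller side:
+
+        log2e = 1.44269504089
+        dk_sqrt = math.sqrt(1.0 / shape.head_size)
+        kernel(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]),
+               alibi_slopes * log2e, output)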
+ """ + + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK") + LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV") + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant[1] == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant[1] == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant[1], + vector_shapes={B: 0, M: Mvec, N: Nvec}, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave(constraints) + def alibi_attention( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + m: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + m_reg = tkw.read(m, elements_per_thread=1) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0]) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + + #################################################################### + # ALiBi + #################################################################### + # ALiBi is essentially adding a lower-triangular matrix of constants + # to the result of QK product. The triangular matrix is obtained by + # doing min(j - i, 0). 
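+            # For a 4x4 tile, min(j - i, 0) yields (cf. get_relative_positions
+            # in the accompanying test):
+            #   [ 0  0  0  0]
+            #   [-1  0  0  0]
+            #   [-2 -1  0  0]
+            #   [-3 -2 -1  0]
+            # Each entry is then scaled by the per-head slope (m_reg) and
+            # added to x_j below.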
+ i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[B, M, K2]) + j = tkw.self_index(K2, tkl.i64, elements_per_thread=4) + zero = tkl.Register[B, M, K2, tkl.i64](0) + idx = tkw.minimum(j - i, zero) + local_m = tkw.broadcast(m_reg, target_shape=[B, M, K2]) + bias = local_m * tkw.cast(idx, tkl.f32) + x_j = x_j + bias + + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape.num_query_heads, + M: shape.query_seq_len, + N: shape.head_size_kv, + K1: shape.head_size, + K2: shape.kv_seq_len, + } + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + return alibi_attention, hyperparams, dynamic_symbols, dynamic_symbols_map diff --git a/tests/kernel/wave/attention/alibi_attention_test.py b/tests/kernel/wave/attention/alibi_attention_test.py new file mode 100644 index 000000000..9da2d6c26 --- /dev/null +++ b/tests/kernel/wave/attention/alibi_attention_test.py @@ -0,0 +1,168 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import torch +import math +import iree.turbine.kernel as tk +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import ( + get_default_run_config, + get_default_scheduling_params, + device_arange, + device_full, + device_randn, + device_zeros, + to_default_device, +) +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.alibi_attention import ( + get_alibi_attention_kernel, +) +from iree.turbine.kernel.wave.templates.attention_common import ( + AttentionShape, +) +import os +from torch.testing import assert_close +from ..common.utils import ( + require_e2e, + require_cdna3, + enable_scheduling_barriers, + dump_generated_mlir, + get_default_arch, +) +from ..common.shapes import get_test_shapes +from typing import List, Optional, Tuple + +shapes = [(128, 128, 128, 128, 128, 128)] + +def get_relative_positions(seq_len: int, kv_seq_len: Optional[int] = None) -> torch.Tensor: + """Returns a lower-trinagular tensor with distance between rows and columns. 
+ + The tensor resembles the following: + + [ 0 0 0 0 0] + [-1 0 0 0 0] + [-2 -1 0 0 0] + [-3 -2 -1 0 0] + """ + if not kv_seq_len: kv_seq_len = seq_len + x = torch.arange(kv_seq_len)[None, :] + y = torch.arange(seq_len)[:, None] + return to_default_device(torch.minimum(x - y, torch.zeros(seq_len, kv_seq_len))) + + +def precompute_alibi_slopes(n_heads: int) -> torch.Tensor: + """Computes the constant slopes of linear biases to be added to the attention scores.""" + n = 2 ** math.floor(math.log2(n_heads)) + m_0 = 2.0 ** (-8.0 / n) + m = torch.pow(m_0, torch.arange(1, 1 + n)) + if n < n_heads: + m_hat_0 = 2.0 ** (-4.0 / n) + m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2)) + m = torch.cat([m, m_hat]) + return to_default_device(m) + + +def validate_accuracy( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor +) -> torch.Tensor: + # Precompute values. + dk_sqrt = math.sqrt(1.0 / query.shape[-1]) + alibi_slopes = precompute_alibi_slopes(query.shape[0]) + + # Straightforward implementation of attention with bias. + scores = torch.matmul(query, key.transpose(-1, -2)) * dk_sqrt + bias = alibi_slopes.unsqueeze(-1).unsqueeze(-1) * get_relative_positions(query.shape[1], key.shape[1]) + bias = bias.to(dtype=scores.dtype) + scores = scores + bias + reference = torch.matmul(torch.softmax(scores, dim=-1), value) + assert_close(reference, output, check_dtype=False, rtol=2e-3, atol=2e-3) + return reference + + +def create_inputs( + shape: AttentionShape, + dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) + k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) + v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) + q = device_randn(q_shape, dtype=dtype) + k = device_randn(k_shape, dtype=dtype) + v = device_randn(v_shape, dtype=dtype) + return (q, k, v) + +@require_e2e +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize( + "mfma_variant", + [(MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)], +) +def test_alibi_attention( + shape: tuple[int], + dtype: torch.dtype, + mfma_variant: MMAType, + request, +): + torch.manual_seed(0) + shape = AttentionShape( + num_query_heads=shape[0], + num_kv_heads=shape[1], + head_size=shape[2], + head_size_kv=shape[3], + query_seq_len=shape[4], + kv_seq_len=shape[5], + ) + assert shape.num_query_heads % shape.num_kv_heads == 0 + + (query, key, value) = create_inputs(shape, dtype) + alibi_attention, hyperparams, _, _ = get_alibi_attention_kernel(shape, mfma_variant, dynamic_dims=False) + output_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + + hyperparams.update(get_default_scheduling_params()) + config = get_default_run_config() + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape.head_size) + alibi_slopes = precompute_alibi_slopes(shape.head_size) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + 
use_scheduling_barriers=enable_scheduling_barriers, + ): + output = device_zeros(output_shape, dtype=torch.float32) + # TODO: Add scaling of QK and ALiBi as part of kernel. + alibi_attention( + query * dk_sqrt * log2e, + key, + value.permute([0, 2, 1]), + # NOTE: since the kernel uses exp2 instead of exp, the ALiBi slopes must be + # multiplied by the same factor as the Q matrix to preserve the result post + # softmax: exp(x + alibi) = exp2((x + alibi) * log2(e)) + alibi_slopes * log2e, + output + ) + + validate_accuracy(query, key, value, output) From d94665312281c1fad4fc898e0ca65dedfcf60267 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 10 Feb 2025 06:37:47 -0800 Subject: [PATCH 2/7] Fix access to potentially non-existing variable Signed-off-by: Alex Zinenko --- iree/turbine/kernel/wave/codegen/handlers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/iree/turbine/kernel/wave/codegen/handlers.py b/iree/turbine/kernel/wave/codegen/handlers.py index e47efb789..9bdd4704f 100644 --- a/iree/turbine/kernel/wave/codegen/handlers.py +++ b/iree/turbine/kernel/wave/codegen/handlers.py @@ -849,8 +849,12 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node): # Get thread_shape/size for broadcast. get_thread_shape = lambda index: max(subs_idxc(x.size) for x in index.values()) - src_thread_size = get_thread_shape(register.index) if register.index else None - target_thread_size = get_thread_shape(node.index) + src_thread_size = ( + get_thread_shape(register.index) + if hasattr(register, "index") and register.index + else None + ) + target_thread_size = get_thread_shape(node.index) if node.index else None # Check MLIR shape vector_src = cast_vector(emitter, register) From 6bc9ed2f318010a49ed8ed34d5667294a61392db Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 10 Feb 2025 10:15:14 -0800 Subject: [PATCH 3/7] format Signed-off-by: Alex Zinenko --- .../wave/attention/alibi_attention_test.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/kernel/wave/attention/alibi_attention_test.py b/tests/kernel/wave/attention/alibi_attention_test.py index 9da2d6c26..7f3f3a92a 100644 --- a/tests/kernel/wave/attention/alibi_attention_test.py +++ b/tests/kernel/wave/attention/alibi_attention_test.py @@ -39,7 +39,10 @@ shapes = [(128, 128, 128, 128, 128, 128)] -def get_relative_positions(seq_len: int, kv_seq_len: Optional[int] = None) -> torch.Tensor: + +def get_relative_positions( + seq_len: int, kv_seq_len: Optional[int] = None +) -> torch.Tensor: """Returns a lower-trinagular tensor with distance between rows and columns. The tensor resembles the following: @@ -49,7 +52,8 @@ def get_relative_positions(seq_len: int, kv_seq_len: Optional[int] = None) -> to [-2 -1 0 0 0] [-3 -2 -1 0 0] """ - if not kv_seq_len: kv_seq_len = seq_len + if not kv_seq_len: + kv_seq_len = seq_len x = torch.arange(kv_seq_len)[None, :] y = torch.arange(seq_len)[:, None] return to_default_device(torch.minimum(x - y, torch.zeros(seq_len, kv_seq_len))) @@ -68,10 +72,7 @@ def precompute_alibi_slopes(n_heads: int) -> torch.Tensor: def validate_accuracy( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, output: torch.Tensor ) -> torch.Tensor: # Precompute values. dk_sqrt = math.sqrt(1.0 / query.shape[-1]) @@ -79,7 +80,9 @@ def validate_accuracy( # Straightforward implementation of attention with bias. 
scores = torch.matmul(query, key.transpose(-1, -2)) * dk_sqrt - bias = alibi_slopes.unsqueeze(-1).unsqueeze(-1) * get_relative_positions(query.shape[1], key.shape[1]) + bias = alibi_slopes.unsqueeze(-1).unsqueeze(-1) * get_relative_positions( + query.shape[1], key.shape[1] + ) bias = bias.to(dtype=scores.dtype) scores = scores + bias reference = torch.matmul(torch.softmax(scores, dim=-1), value) @@ -88,8 +91,7 @@ def validate_accuracy( def create_inputs( - shape: AttentionShape, - dtype: torch.dtype + shape: AttentionShape, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) @@ -99,6 +101,7 @@ def create_inputs( v = device_randn(v_shape, dtype=dtype) return (q, k, v) + @require_e2e @pytest.mark.parametrize("shape", shapes) @pytest.mark.parametrize("dtype", [torch.float16]) @@ -124,7 +127,9 @@ def test_alibi_attention( assert shape.num_query_heads % shape.num_kv_heads == 0 (query, key, value) = create_inputs(shape, dtype) - alibi_attention, hyperparams, _, _ = get_alibi_attention_kernel(shape, mfma_variant, dynamic_dims=False) + alibi_attention, hyperparams, _, _ = get_alibi_attention_kernel( + shape, mfma_variant, dynamic_dims=False + ) output_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) hyperparams.update(get_default_scheduling_params()) @@ -162,7 +167,7 @@ def test_alibi_attention( # multiplied by the same factor as the Q matrix to preserve the result post # softmax: exp(x + alibi) = exp2((x + alibi) * log2(e)) alibi_slopes * log2e, - output + output, ) validate_accuracy(query, key, value, output) From 11801d14ef38c179c160c59bdcf15929024ce596 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 4 Feb 2025 09:16:58 -0800 Subject: [PATCH 4/7] Add T5 RPE variant Signed-off-by: Nicolas Vasilache --- iree/turbine/kernel/ops/wave_ops.py | 5 + iree/turbine/kernel/wave/codegen/handlers.py | 13 +- playground/__init__.py | 0 playground/attention_with_rpe_template.py | 234 +++++++++++++++++++ playground/attention_with_rpe_test.py | 187 +++++++++++++++ playground/causal_attention_template.py | 197 ++++++++++++++++ playground/causal_attention_test.py | 179 ++++++++++++++ playground/stress.py | 203 ++++++++++++++++ playground/triangular.py | 108 +++++++++ 9 files changed, 1125 insertions(+), 1 deletion(-) create mode 100644 playground/__init__.py create mode 100644 playground/attention_with_rpe_template.py create mode 100644 playground/attention_with_rpe_test.py create mode 100644 playground/causal_attention_template.py create mode 100644 playground/causal_attention_test.py create mode 100644 playground/stress.py create mode 100644 playground/triangular.py diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index b0e1da461..85fbc69c1 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -148,6 +148,10 @@ def minimum(lhs: "Register", rhs: "Register") -> "Register": ... +def and_op(lhs: "Register", rhs: "Register") -> "Register": + ... 
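+    # Elementwise logical AND; integer/boolean element types lower to
+    # arith.andi (see handle_and_op in handlers.py). Used for the windowed
+    # RPE range check, e.g. tkw.and_op(i - j >= ZERO, i - j <= MAX).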
+ + def broadcast( arg: "Register", target_shape: Optional[Sequence[IndexExpr | int]] = None ) -> "Register": @@ -769,6 +773,7 @@ def infer_shape(self) -> Any: @define_py_op(operator.truediv) @define_interface_op("maximum") @define_interface_op("minimum") +@define_interface_op("and_op") @dataclass class BinaryPyOp(BinaryOpBase, ABC): def infer_type(self): diff --git a/iree/turbine/kernel/wave/codegen/handlers.py b/iree/turbine/kernel/wave/codegen/handlers.py index 9bdd4704f..ed4bf284e 100644 --- a/iree/turbine/kernel/wave/codegen/handlers.py +++ b/iree/turbine/kernel/wave/codegen/handlers.py @@ -54,6 +54,7 @@ from ...ops.wave_ops import ( abs, allocate, + and_op, apply_expr, broadcast, cast, @@ -523,6 +524,17 @@ def handle_le(lhs: Value, rhs: Value) -> OpResult: return result +@handle_binary_op(and_op) +def handle_and_op(lhs: Value, rhs: Value) -> OpResult: + element_type = get_type_or_element_type(lhs.type) + if _is_integer_like_type(element_type): + result = arith_d.andi(lhs, rhs) + else: + raise ValidationError( + f"Found unhandled operand type for le: {element_type}") + return result + + @handle_binary_op(maximum) def handle_maximum(lhs: Value, rhs: Value) -> OpResult: element_type = get_type_or_element_type(lhs.type) @@ -554,7 +566,6 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult: ) return result - ############################################################################### # Unary math Ops ############################################################################### diff --git a/playground/__init__.py b/playground/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py new file mode 100644 index 000000000..0c11a111f --- /dev/null +++ b/playground/attention_with_rpe_template.py @@ -0,0 +1,234 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from typing import Optional + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape + + +def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, + dynamic_dims: bool, + max_context_length: Optional[int]): + # RPE + ZERO = tkl.sym.ZERO + OFFSET = tkl.sym.OFFSET + + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK") + LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV") + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0) + ] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant[1] == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant[1] == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant[1], + vector_shapes={ + B: 0, + M: Mvec, + N: Nvec + }, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping(num_iterators=3, + inputs={ + B: i, + N: j, + M: k + }, + outputs={ + B: i, + M: k, + N: j + }) + + offset_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={K2: i + OFFSET}, + outputs={K2: i}, + ) + + use_t5_rpe = max_context_length is not None + if use_t5_rpe: + rpe_layout = tkl.MemoryLayout(shape=[ + max_context_length, + ]) + assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing" + + @tkw.wave(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + # TODO: if not use_t5_rpe, this will DCE; atm DCE on blockargs crashes. + rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0]) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + + #################################################################### + # T5 RPE + #################################################################### + # Fused T5 RPE adds attention bias pre-softmax normalization. + # When fusing into the FA variant, adding locally before the max and + # the partial softmax should be equivalent. + if use_t5_rpe: + ZERO = tkl.Register[M, K2, tkl.i64](0) + MAX = tkl.Register[M, K2, tkl.i64](max_context_length) + # 1. Indices i and j broadcasted along K2 with a twist: + # here we use *static* information that is *implicitly* encoded + # in the *transformation*: under the distribution constraints + # specified we know that the shape [M] will eventually resolve + # to [1] and can thus be "cast + broadcast" to [K2]. + i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[M, K2]) + j = tkw.self_index(K2, tkl.i64, elements_per_thread=1) + + # 2. Clip i - j to the proper bucket in [0, max_context_length] + # to represent the following: + # - if 0 < i - j < max_context_length + # then x_j += rpe_reg + # - otherwise (i.e. i - j == {0, max_context_length}) + # then x_j += 0 + # TODO: we may need scaling adjustements depending on how we want + # to do bucketing; atm it is bucketing of size 1. + + # min/max variant + # idx = tkw.maximum(i - j, ZERO) + # idx = tkw.minimum(idx, MAX) + + # select variant. + idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), + i - j, ZERO) + + # 3. Read indirect into the 1-D rpe array via offset_mapping. + tkw.set_symbol(OFFSET, idx) # offset will have shape [M, K2] + rpe_reg = tkw.read( + rpe, + mapping=offset_mapping, + elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + + # 4. Tadaaaa. 
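+                # The bias materialized here is intended to match the torch
+                # reference t5_rpe_masked_cond in the test: bias[i, j] is
+                # rpe[i - j] when 0 <= i - j <= max_context_length and 0
+                # otherwise, added to x_j before the running max/softmax
+                # update below.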
+ x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.i64) + + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + tkw.write(res, + c, + mapping=mapping, + elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD_QK: + get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: + get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: + get_mfma_store_elems_per_thread(mfma_variant[1]), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape.num_query_heads, + M: shape.query_seq_len, + N: shape.head_size_kv, + K1: shape.head_size, + K2: shape.kv_seq_len, + } + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py new file mode 100644 index 000000000..b0d161ef6 --- /dev/null +++ b/playground/attention_with_rpe_test.py @@ -0,0 +1,187 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape +from iree.turbine.kernel.wave.templates.vanilla_attention import ( + get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel) +from iree.turbine.kernel.wave.utils import ( + device_randn, + device_zeros, + get_default_run_config, + to_default_device, +) +from attention_with_rpe_template import ( + get_vanilla_attention_kernel as + get_vanilla_tkw_attention_with_rpe_output_kernel) + +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + compile_config={"print-ir-after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + fun(*args) + + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +################################################################################# +# INIT VALS +################################################################################# +# num_query_heads, num_kv_heads, head_size, head_size_kv +shape = AttentionShape(128, 128, 128, 128) +shape.query_seq_len = 128 +shape.kv_seq_len = 128 + +assert shape.num_query_heads == shape.num_kv_heads, \ + "expected query and kv to have the same number of heads!" + +q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) +k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) +v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) +o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + +q = device_randn(q_shape, dtype=torch.float16) +k = device_randn(k_shape, dtype=torch.float16) +v = device_randn(v_shape, dtype=torch.float16) +tkw_attention_output = device_zeros(o_shape, dtype=torch.float32) +tkw_attention_with_rpe_output = device_zeros(o_shape, dtype=torch.float32) + +log2e = 1.44269504089 +dk_sqrt = math.sqrt(1.0 / q.shape[-1]) + +################################################################################# +# T5 RPE INIT VALS +################################################################################# +# T5 RPE parameter +max_context_length = 33 + +# Applied pre-softmax on the MMA'ed result so f32. +# Provision more room for clipping and adding 0 at the boundaries. 
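+# The first and last slots (rpe[0] and rpe[max_context_length + 1]) are zeroed
+# below so that clipped, out-of-window relative positions contribute no bias.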
+rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32) +rpe = rpe[:max_context_length + 2].view(max_context_length + 2) +rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32)) +rpe[0] = 0 +rpe[max_context_length + 1] = 0 + + +def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, + dtype): + positions = to_default_device(torch.arange(sequence_length)) + pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0) + mask = to_default_device((pos_diff >= 0) + & (pos_diff <= max_context_length)) + rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype) + rpe_cond[mask] = rpe[pos_diff[mask]] + return rpe_cond + + +# rpe_cond is used by torch only +rpe_cond = t5_rpe_masked_cond(rpe, + max_context_length=max_context_length, + sequence_length=shape.kv_seq_len, + dtype=tkw_attention_with_rpe_output.dtype) + +################################################################################# +# TKW BASE ATTENTION +################################################################################# +### RPE version +tkw_attention_with_rpe, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_with_rpe_output_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False, + max_context_length = max_context_length + 2) + + +def attention_with_rpe(tq, tk, tv, trpe, toutput): + mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput) + print(mb.module_op) + + +run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k, + v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe_output) + +### Reference version +tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def attention(tq, tk, tv, toutput): + tkw_attention(tq, tk, tv, toutput) + + +run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), + tkw_attention_output) + +tkw_rpe_delta_output = tkw_attention_with_rpe_output - tkw_attention_output +# print(tkw_rpe_delta_output) + +################################################################################# +# TORCH ATTENTION and ATTENTION + RPE +################################################################################# +torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None) + +a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt +torch_attention_output = torch.matmul(torch.softmax(a, dim=-1), v) + +# Sanity check that torch_attention_output and torch_attention_ref_output are +# the same so we can inject RPE pre-softmax and compute the delta. +# We will test that the delta post-softmax is the same for torch and TKW. +assert_close(torch_attention_output, + torch_attention_ref_output, + atol=2e-3, + rtol=2e-3) + +a += rpe_cond.unsqueeze(0) +torch_attention_with_rpe_output = torch.matmul(F.softmax(a, dim=-1), v) +torch_rpe_delta_output = torch_attention_with_rpe_output - torch_attention_output + +# Check basic attentions match as we expect. +assert_close(torch_attention_output.to(dtype=tkw_attention_output.dtype), + tkw_attention_output, + atol=2e-3, + rtol=2e-3) + +# Check RPE attentions match as we expect. 
+assert_close(torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype), + tkw_rpe_delta_output, + atol=2e-3, + rtol=2e-3) diff --git a/playground/causal_attention_template.py b/playground/causal_attention_template.py new file mode 100644 index 000000000..f2de1fe25 --- /dev/null +++ b/playground/causal_attention_template.py @@ -0,0 +1,197 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +import sympy +import sys +from typing import Optional + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape + + +def get_causal_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, + dynamic_dims: bool): + ZERO = tkl.sym.ZERO + + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK") + LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV") + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0) + ] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant[1] == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant[1] == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant[1], + vector_shapes={ + B: 0, + M: Mvec, + N: Nvec + }, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping(num_iterators=3, + inputs={ + B: i, + N: j, + M: k + }, + outputs={ + B: i, + M: k, + N: j + }) + + @tkw.wave(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + tmp: tkl.Memory[K2, ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + imm_reg = tkw.permute(imm_reg, target_shape=[B, M, K2]) + #################################################################### + # Causal mask + #################################################################### + # Indices i and j broadcasted along K2 with a twist: + # here we use *static* information that is *implicitly* encoded + # in the *transformation*: under the distribution constraints + # specified we know that the shape [M] will eventually resolve + # to [1] and can thus be "cast + broadcast" to [K2]. + i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[K2]) + j = tkw.self_index(K2, tkl.i64, elements_per_thread=1) + ZERO = tkl.Register[K2, tkl.i64](0) + ONE = tkl.Register[K2, tkl.i64](1) + ZEROF = tkl.Register[K2, tkl.f32](0.0) + MIN_INF = tkl.Register[K2, tkl.f32](float('-inf')) + idx = j - i - ONE + bias = tkw.select(tkw.slt(idx, ZERO), ZEROF, MIN_INF) + ### Apply causality mask to imm_reg. + imm_reg = imm_reg + bias + imm_reg = tkw.permute(imm_reg, target_shape=[B, K2, M]) + + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0]) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + tkw.write(res, + c, + mapping=mapping, + elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD_QK: + get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: + get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: + get_mfma_store_elems_per_thread(mfma_variant[1]), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape.num_query_heads, + M: shape.query_seq_len, + N: shape.head_size_kv, + K1: shape.head_size, + K2: shape.kv_seq_len, + ZERO: 0, + } + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map diff --git a/playground/causal_attention_test.py b/playground/causal_attention_test.py new file mode 100644 index 000000000..77a9a4ee8 --- /dev/null +++ b/playground/causal_attention_test.py @@ -0,0 +1,179 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape +from iree.turbine.kernel.wave.utils import ( + device_randn, + device_zeros, + get_default_run_config, + to_default_device, +) +from causal_attention_template import (get_causal_attention_kernel as + get_tkw_causal_attention_kernel) +from iree.turbine.kernel.wave.templates.vanilla_attention import ( + get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel) + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + + +def find_different_coordinates(tensor1, tensor2, rtol=1e-5, atol=1e-8): + # Calculate the difference in float32 to avoid range issues + diff = (tensor1.float() - tensor2.float()).abs() + + # Create a mask where the difference exceeds the tolerance + tolerance = atol + rtol * tensor2.float().abs() + diff_mask = diff > tolerance + + if not diff_mask.any(): # Tensors are close if the mask is all False + print("Tensors are close.") + return [] + + diff_indices = torch.nonzero(diff_mask) + + print("Tensors are different at the following coordinates:") + for coords in diff_indices: + print(tuple(coords.tolist())) + + return diff_indices + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + # compile_config={"print_ir_after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + fun(*args) + + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +################################################################################# +# INIT VALS +################################################################################# +# num_query_heads, num_kv_heads, head_size, head_size_kv +shape = AttentionShape(128, 128, 128, 128) +shape.query_seq_len = 128 +shape.kv_seq_len = 128 + +assert shape.num_query_heads == shape.num_kv_heads, \ + "expected query and kv to have the same number of heads!" 
+ +q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) +k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) +v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) +o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + +q = device_randn(q_shape, dtype=torch.float16) +k = device_randn(k_shape, dtype=torch.float16) +v = device_randn(v_shape, dtype=torch.float16) +tkw_attention_output = device_zeros(o_shape, dtype=torch.float32) +tkw_causal_attention_output = device_zeros(o_shape, dtype=torch.float32) + +log2e = 1.44269504089 +dk_sqrt = math.sqrt(1.0 / q.shape[-1]) + +################################################################################# +# TORCH ATTENTION +################################################################################# +torch_causal_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True) +torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None) +torch_delta = torch_attention_ref_output - torch_causal_attention_ref_output +print(torch_delta) + +################################################################################# +# TKW ATTENTION +################################################################################# +### Reference version +tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def attention(tq, tk, tv, toutput): + tkw_attention(tq, tk, tv, toutput) + + +run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), + tkw_attention_output) + +assert_close(torch_attention_ref_output.to(dtype=tkw_attention_output.dtype), + tkw_attention_output, + atol=2e-3, + rtol=2e-3) + +### Causal version +tkw_causal_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_tkw_causal_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def causal_attention(tq, tk, tv, toutput): + mb = tkw_causal_attention(tq, tk, tv, toutput) + print(mb.module_op) + + +run(causal_attention, hyperparams, q * dk_sqrt * log2e, k, + v.permute([0, 2, 1]), tkw_causal_attention_output) + +# tkw_delta = tkw_causal_attention_output - tkw_attention_output +# print(tkw_delta) +# print(torch_causal_attention_ref_output[67, 16]) +# print(tkw_causal_attention_output[67, 16]) + +# Coordinates where we see discrepancies are: +# (*, 16:31, *) +# (*, 48:63, *) +# (*, 80:95, *) +# (*, 80:95, *) +# (*, 112:127, *) +# different_coords = find_different_coordinates( +# torch_causal_attention_ref_output, +# tkw_causal_attention_output, +# rtol=2e-3, +# atol=2e-3) +# print(different_coords) + +assert_close(torch_causal_attention_ref_output.to( + dtype=tkw_causal_attention_output.dtype), + tkw_causal_attention_output, + atol=2e-3, + rtol=2e-3) diff --git a/playground/stress.py b/playground/stress.py new file mode 100644 index 000000000..ac77b7613 --- /dev/null +++ b/playground/stress.py @@ -0,0 +1,203 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pdb import run +import pytest +import torch +from typing import Callable + +from iree.turbine.kernel._support.tracing import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import (device_randn, device_zeros, + get_default_run_config) +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw + +torch.set_printoptions(linewidth=300) + + +# We want each row to contain [0 .. num_cols] +def reference_row(rows: int, cols: int): + row_indices = torch.arange(cols).unsqueeze(0).expand(rows, cols) + print(row_indices.shape) + return row_indices + + +# We want each col to contain [0 .. num_rows] +def reference_col(rows: int, cols: int): + col_indices = torch.arange(rows).unsqueeze(1).expand(rows, cols) + return col_indices + + +def reference_row_plus_col(rows: int, cols: int): + return reference_row(rows, cols) + reference_col(rows, cols) + + +# Input sizes +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +# Workgroup tile sizes +ITERATIONS_OF_M_PER_WAVE = tkl.sym.ITERATIONS_OF_M_PER_WAVE +ITERATIONS_OF_N_PER_WAVE = tkl.sym.ITERATIONS_OF_N_PER_WAVE +BLOCK_K = tkl.sym.BLOCK_K +# Address space (for GPU, shared(1) or global(0)) +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +# Other hyperparameters +LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD +STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + +# yapf: disable +def run_harness(fun: Callable, vM: int, vN: int, *args) -> bool: + config = get_default_run_config() + # Override manually to run. + config = {"backend": "rocm", "device": "hip", "target": "gfx90a"} + with TestLaunchContext({M: vM, N: vN}, + canonicalize=True, + run=True, + run_config=config): + return fun(*args) + + +# yapf: disable +### Setting all of the following at the same time to agree on the same value works: +# - ITERATIONS_OF_N_PER_WAVE +# - VECTOR_SHAPE_N +# - ELEMENTS_PER_THREAD_STORE +params = [ +# SIZE_M, SIZE_N, ITERATIONS_OF_M_PER_WAVE, ITERATIONS_OF_N_PER_WAVE, VECTOR_SHAPE_M, VECTOR_SHAPE_N, ELEMENTS_PER_THREAD_STORE + [ 4, 8, 1, 1, 1, 1, 1, ], + [ 4, 8, 1, 2, 1, 2, 2, ], + [ 4, 8, 1, 3, 1, 3, 3, ], + [ 4, 8, 1, 4, 1, 4, 4, ], +] +### However, The slightest discrepancy throws the TK compiler off: +# params = [ +# SIZE_M, SIZE_N, ITERATIONS_OF_M_PER_WAVE, ITERATIONS_OF_N_PER_WAVE, VECTOR_SHAPE_M, VECTOR_SHAPE_N, ELEMENTS_PER_THREAD_STORE +# [ 4, 8, 1, 1, 1, 4, 4, 4, ], # Tile size must be divisible by wave count and vector size, got: tile_size=1, wave_count=1, vector_size=4 +# [ 4, 8, 1, 4, 1, 1, 4, 4, ], # MISCOMPILE INCORRECT RESULTS +# [ 4, 8, 1, 4, 1, 4, 1, 4, ], # CRASH TK COMPILER: Shape doesn't match: (1,) and (4,) in register cast_M:0_N:0 and elements_per_thread 4 +# [ 4, 8, 1, 4, 1, 4, 4, 1, ], # CRASH TK COMPILER: Shape doesn't match: (4,) and (1,) in register cast_M:0_N:0 and elements_per_thread 1 +# ] + +for p in params: + SIZE_M, \ + SIZE_N, \ + ITERATIONS_OF_M_PER_WAVE, \ + ITERATIONS_OF_N_PER_WAVE, \ + VECTOR_SHAPE_M, \ + VECTOR_SHAPE_N, \ + ELEMENTS_PER_THREAD_STORE = p + + workgroup_constraints = [ + [tkw.WorkgroupConstraint(M, ITERATIONS_OF_M_PER_WAVE, 0)], + [tkw.WorkgroupConstraint(N, ITERATIONS_OF_N_PER_WAVE, 1)], + [ + tkw.WorkgroupConstraint(M, ITERATIONS_OF_M_PER_WAVE, 0), + tkw.WorkgroupConstraint(N, ITERATIONS_OF_N_PER_WAVE, 1) + ], + ] + wave_constraints = [ + [], + [tkw.WaveConstraint(M, 1)], + [tkw.WaveConstraint(N, 1)], + [tkw.WaveConstraint(M, 1), tkw.WaveConstraint(N, 1)], + 
[tkw.WaveConstraint(M, 2)], + [tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 1), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 1)], + ] + # yapf: enable + + for wgs in workgroup_constraints: + for wvs in wave_constraints: + unroll_N = True + # In these stress tests compute self_index(N) and we want to distinguish + # between the cases: + # 1. there is a WorkgroupConstraint on N, therefore N is distributed + # and using ELEMENTS_PER_THREAD_INDEX == 1 results in proper + # propagations + # 2. there is no WorkgroupConstraint on N, therefore N is unrolled and + # we have to use ELEMENTS_PER_THREAD_INDEX == ELEMENTS_PER_THREAD_STORE + # otherwise the TK compiler gets confused atm. + # Ideally, in the future, things would just work out of the box without + # having to adjust ELEMENTS_PER_THREAD_INDEX + for wg in wgs: + if wg.dim == N: + unroll_N = False + + # Skip this particular constraint if a WaveConstraint is set without + # first setting the corresponding WorkgroupConstraint: + # TK does not handle that case + skip = False + for wv in wvs: + skip_wv = True + for wg in wgs: + if wg.dim == wv.dim: + skip_wv = False + if skip_wv: + skip = True + if skip: + continue + + ELEMENTS_PER_THREAD_INDEX = ELEMENTS_PER_THREAD_STORE if unroll_N else 1 + + ###### User constraints + constraints: list[tkw.Constraint] = [] + constraints += wgs + constraints += wvs + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={ + M: VECTOR_SHAPE_M, + N: VECTOR_SHAPE_N + }, + ) + ] + + ###### Known cases to skip: + # When we unroll N, TK does not handle imperfect unrolling (with a + # remainder). + if unroll_N and SIZE_N % ITERATIONS_OF_N_PER_WAVE != 0: + continue + + @tkw.wave(constraints) + def row(c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32]): + i = tkw.self_index( + N, tkl.i64, elements_per_thread=ELEMENTS_PER_THREAD_INDEX) + res = tkw.cast(i, tkl.f32) + tkw.write(res, + c, + elements_per_thread=ELEMENTS_PER_THREAD_STORE) + + def fun_row(debug: bool = False) -> bool: + c = device_zeros(SIZE_M, SIZE_N, dtype=torch.float32) + if debug: + print(row(c).module_op) + return True + else: + row(c) + correct = torch.all( + torch.isclose(reference_row(SIZE_M, SIZE_N), + c.cpu().to(dtype=torch.int64))).item() + if not correct: + print(f"reference:\n{reference_row(SIZE_M, SIZE_N)}") + print(f"actual:\n{c.cpu().to(dtype=torch.int64)}") + print( + f"delta:\n{c.cpu().to(dtype=torch.int64) - reference_row(SIZE_M, SIZE_N)}" + ) + return correct + + correct = run_harness(fun_row, SIZE_M, SIZE_N) + if not correct: + print(f"\nError under stress test constraints: {constraints}") + run_harness(fun_row, SIZE_M, SIZE_N, True) + assert correct, "Incorrect execution: ran in debug mode now stop" diff --git a/playground/triangular.py b/playground/triangular.py new file mode 100644 index 000000000..58e32e5fc --- /dev/null +++ b/playground/triangular.py @@ -0,0 +1,108 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE +from iree.turbine.kernel.wave.utils import ( + device_zeros, + get_default_run_config, + to_default_device, +) +from causal_attention_template import (get_causal_attention_kernel as + get_tkw_causal_attention_kernel) + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw + +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + +vM, vN = 10, 10 +torch_o = to_default_device(torch.ones(vM, vN)) +temp_mask = to_default_device( + torch.ones(vM, vN, dtype=torch.bool).tril(diagonal=0)) +torch_o.masked_fill_(temp_mask.logical_not(), float("-inf")) + +M = tkl.sym.M +N = tkl.sym.N +ONE = tkl.sym.ONE +# Expose user-constraints +constraints: list[tkw.Constraint] = [] +constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={ + M: 1, + N: 1, + }, + ) +] + +constraints += [tkw.WorkgroupConstraint(M, 1, 0)] +constraints += [tkw.WorkgroupConstraint(N, 1, 1)] + +# WARNING: these constraints generate wrong code +# constraints += [tkw.WorkgroupConstraint(M, 2, 0)] +# constraints += [tkw.WorkgroupConstraint(N, 2, 1)] +# constraints += [tkw.WaveConstraint(M, 1)] +# constraints += [tkw.WaveConstraint(N, 1)] + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + # compile_config={"print_ir_after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + mb = fun(*args) + print(mb.module_op) + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +@tkw.wave(constraints) +def test(o: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32]): + i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[N]) + j = tkw.self_index(N, tkl.i64, elements_per_thread=1) + ZERO = tkl.Register[N, tkl.i64](0) + ONE = tkl.Register[N, tkl.i64](1) + ZEROF = tkl.Register[N, tkl.f32](0.0) + MIN_INF = tkl.Register[N, tkl.f32](float('-inf')) + idx = j - i - ONE + res = tkw.select(tkw.slt(idx, ZERO), ZEROF, MIN_INF) + val = tkw.read(o, elements_per_thread=1) + res += val + tkw.write(res, o, elements_per_thread=1) + + +o = to_default_device(torch.ones(vM, vN)) +run(test, {M: vM, N: vN, ONE: 1}, o) + +# print(o) +assert_close(torch_o.to(dtype=o.dtype), o, atol=2e-3, rtol=2e-3) From 92107e89f4b5d9fcb9bee3a863341702480cd2b7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 5 Feb 2025 01:20:15 -0800 Subject: [PATCH 5/7] Debug OFFSET does not accept vector --- iree/turbine/kernel/compiler/vector_codegen.py | 8 +++++--- iree/turbine/kernel/wave/codegen/handlers.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/iree/turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py index ae3d21ab0..6becdab94 100644 --- a/iree/turbine/kernel/compiler/vector_codegen.py +++ 
b/iree/turbine/kernel/compiler/vector_codegen.py @@ -854,9 +854,10 @@ def cast_kernel_buffer( return value, MemRefType(ir_type), py_type -def cast_vector( - emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None -): +def cast_vector(emitter: ThreadEmitter, + value, + *, + element_type: Optional[IrType] = None): proxy_value = cast_py_value(emitter, value) # Cast scalar types correctly first. @@ -864,6 +865,7 @@ def cast_vector( # Implicit scalar type promotion. proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type) + print(f"proxy_value {proxy_value}") value = proxy_value.ir_value # After scalar promotion, promote to vector. diff --git a/iree/turbine/kernel/wave/codegen/handlers.py b/iree/turbine/kernel/wave/codegen/handlers.py index ed4bf284e..3c9098e02 100644 --- a/iree/turbine/kernel/wave/codegen/handlers.py +++ b/iree/turbine/kernel/wave/codegen/handlers.py @@ -262,7 +262,8 @@ def handle_set_symbol(emitter: WaveEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e register = cast_vector(emitter, register, element_type=IndexType.get()) - emitter.dynamic_dims[symbol] = _to_scalar(register) + # emitter.dynamic_dims[symbol] = _to_scalar(register) + emitter.dynamic_dims[symbol] = register ############################################################################### From 8882a50e582f49e0b7ee3c92823bbe00a45c7dbc Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 7 Feb 2025 14:01:39 -0800 Subject: [PATCH 6/7] make rpe compile, still produces garbage --- .../turbine/kernel/compiler/vector_codegen.py | 9 +- iree/turbine/kernel/wave/codegen/handlers.py | 4 +- .../turbine/kernel/wave/codegen/read_write.py | 109 ++++++++------ iree/turbine/kernel/wave/utils.py | 10 ++ playground/attention_with_rpe_template.py | 73 ++++------ playground/attention_with_rpe_test.py | 136 +++++++++++------- 6 files changed, 194 insertions(+), 147 deletions(-) diff --git a/iree/turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py index 6becdab94..5586c8507 100644 --- a/iree/turbine/kernel/compiler/vector_codegen.py +++ b/iree/turbine/kernel/compiler/vector_codegen.py @@ -854,10 +854,9 @@ def cast_kernel_buffer( return value, MemRefType(ir_type), py_type -def cast_vector(emitter: ThreadEmitter, - value, - *, - element_type: Optional[IrType] = None): +def cast_vector( + emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None +): proxy_value = cast_py_value(emitter, value) # Cast scalar types correctly first. @@ -865,7 +864,7 @@ def cast_vector(emitter: ThreadEmitter, # Implicit scalar type promotion. proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type) - print(f"proxy_value {proxy_value}") + # print(f"proxy_value {proxy_value}") value = proxy_value.ir_value # After scalar promotion, promote to vector. 
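# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch; names such as rpe_table and
# lane_offsets are made up): why a symbol set from a per-element index cannot
# be collapsed to a scalar. Once tkw.set_symbol(OFFSET, idx) receives one
# value per lane, every lane of a subsequent read carries its own offset, so
# the access becomes a gather rather than a single contiguous vector load.
import torch

rpe_table = torch.arange(16, dtype=torch.float32)   # 1-D table indexed by OFFSET
elements_per_thread = 4

# Scalar symbol: one base offset for the whole vector -> contiguous load.
base = 3
contiguous = rpe_table[base : base + elements_per_thread]   # tensor([3., 4., 5., 6.])

# Vector-shaped symbol: one offset per lane -> gather.
lane_offsets = torch.tensor([3, 7, 7, 0])                   # e.g. clipped i - j per lane
gathered = rpe_table[lane_offsets]                          # tensor([3., 7., 7., 0.])
# ---------------------------------------------------------------------------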
diff --git a/iree/turbine/kernel/wave/codegen/handlers.py b/iree/turbine/kernel/wave/codegen/handlers.py index 3c9098e02..1284c8884 100644 --- a/iree/turbine/kernel/wave/codegen/handlers.py +++ b/iree/turbine/kernel/wave/codegen/handlers.py @@ -531,8 +531,7 @@ def handle_and_op(lhs: Value, rhs: Value) -> OpResult: if _is_integer_like_type(element_type): result = arith_d.andi(lhs, rhs) else: - raise ValidationError( - f"Found unhandled operand type for le: {element_type}") + raise ValidationError(f"Found unhandled operand type for le: {element_type}") return result @@ -567,6 +566,7 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult: ) return result + ############################################################################### # Unary math Ops ############################################################################### diff --git a/iree/turbine/kernel/wave/codegen/read_write.py b/iree/turbine/kernel/wave/codegen/read_write.py index 737807256..da12b2cb8 100644 --- a/iree/turbine/kernel/wave/codegen/read_write.py +++ b/iree/turbine/kernel/wave/codegen/read_write.py @@ -7,6 +7,7 @@ import sympy import functools from typing import Any, Callable, ClassVar, Optional, List, Type, Dict +import math import torch.fx as fx @@ -134,19 +135,9 @@ def _build_mask( return mask -def _get_splat_const(vec_type: IrType, value: Any) -> Value: - splat = DenseElementsAttr.get_splat( - vec_type, get_constant_attr(value, vec_type.element_type) - ) - return arith_d.constant(vec_type, splat) - - -def _constant_mask(vec_type: IrType) -> Value: - return _get_splat_const(vec_type, 1) - - def _construct_gather_scatter_indices( emitter: WaveEmitter, + # TODO TODO TODO fix typo symbolc_shape: tuple[IndexExpr], index: tuple[IndexExpr], mapping: IndexMapping, @@ -154,6 +145,7 @@ def _construct_gather_scatter_indices( is_read: bool, dynamic_vals: tuple[Any, ...], is_contiguous: bool, + vector_shaped_symbols={}, ) -> tuple[OpResult, OpResult, OpResult]: # Apply symbolc_shape order to indices, e.g. if original mapping is # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will @@ -189,7 +181,7 @@ def _construct_gather_scatter_indices( mask_vec_type = VectorType.get( [elements_per_thread], IntegerType.get_signless(1) ) - mask = _constant_mask(mask_vec_type) + mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) def extract0(src): static_pos = [0] * src.type.rank @@ -221,34 +213,46 @@ def extract0(src): offsets = [] strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True) start_indices_offset = _compute_offset(start_indices, strides) - for i in range(elements_per_thread): - # Update fastest dim, i.e. in case of identity mapping it will - # be equivalent to just vector.load - subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)] - subs[fastest_dim] = (subs[fastest_dim][0], start_indices_orig[fastest_dim] + i) - indices = [i.subs(subs) for i in index_mapping] - # First, we build indices as if resulting gather/scatter `start_indices` - # are 0 as mapping expression may depend on absolute value of index - # (e.g. `index % 32`). Then we adjust for the non-0 `start_indices` by - # subtracting computed previously linear `start_indices_offset`. For - # simple cases like transpose, the resulting expression should fold into - # simple constant while more complex expressions may requires actual - # arith ops on dynamic values. 
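# Illustrative sketch (not part of the patch; lin and strides are made-up
# names, sympy stands in for the index expressions): the offset folding the
# comment above describes. Linearize an index against the strides, subtract
# the linearized start offset, and check whether the result is a plain number.
import sympy

M, N = sympy.symbols("M N", integer=True)
strides = (64, 1)  # row-major [*, 64] buffer

def lin(idx):
    return sum(i * s for i, s in zip(idx, strides))

start_offset = lin((M, N))
print(sympy.simplify(lin((M, N + 1)) - start_offset))  # 1: identity mapping folds
print(sympy.simplify(lin((M + 1, N)) - start_offset))  # 64: transpose-like case also folds
# An expression such as lin((M, (N + 1) % 32)) - start_offset stays symbolic,
# which is the case the surrounding code flags as need_dynamic_offsets.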
- offset = _compute_offset(indices, strides) - start_indices_offset - offset = subs_idxc(offset) - - if offset.is_number: - # If resulted offset sympy expr is convertible to int constant it - # will be directly encoded into `arith.constant`. - # For non-constant expressions, we will generate a real sequence of - # arith ops and then `vector.insertelement` them into offsets vec. - offset = int(offset) - else: - need_dynamic_offsets = True - break + # TODO TODO TODO we don't necessarily care if they are vector shaped, but + # if they are indxed by the fastest varying dimension? + # Note that we may want to "expand" the symbol to per-element + # copies and trigger `need_dynamic_offests` below + if len(start_indices_offset.free_symbols.intersection(vector_shaped_symbols)) != 0: + need_dynamic_offsets = True - offsets.append(offset) + if not need_dynamic_offsets: + for i in range(elements_per_thread): + # Update fastest dim, i.e. in case of identity mapping it will + # be equivalent to just vector.load + subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)] + subs[fastest_dim] = ( + subs[fastest_dim][0], + start_indices_orig[fastest_dim] + i, + ) + indices = [i.subs(subs) for i in index_mapping] + + # First, we build indices as if resulting gather/scatter `start_indices` + # are 0 as mapping expression may depend on absolute value of index + # (e.g. `index % 32`). Then we adjust for the non-0 `start_indices` by + # subtracting computed previously linear `start_indices_offset`. For + # simple cases like transpose, the resulting expression should fold into + # simple constant while more complex expressions may requires actual + # arith ops on dynamic values. + offset = _compute_offset(indices, strides) - start_indices_offset + offset = subs_idxc(offset) + + if offset.is_number: + # If resulted offset sympy expr is convertible to int constant it + # will be directly encoded into `arith.constant`. + # For non-constant expressions, we will generate a real sequence of + # arith ops and then `vector.insertelement` them into offsets vec. + offset = int(offset) + else: + need_dynamic_offsets = True + break + + offsets.append(offset) offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get()) if need_dynamic_offsets: @@ -260,13 +264,22 @@ def extract0(src): ) subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)] # Last item in `subs` corresponds to last item in `start_indices_orig` - # which is fastest changing dim. - # Replacing last element with `idxc.iota(elements_per_thread)` will - # generate vectorized index code, each element in it corresponding to - # individual vector element index. + # which is fastest changing dim. Replacing last element with + # `idxc.iota(elements_per_thread)` will generate vectorized index code, + # each element in it corresponding to individual vector element index. + # + # TODO TODO TODO: vector shaped symbol means we can't just iota here + # instead we should just take the value of the symbol; we should instead + # somehow get different values of OFFSET (or other vector shaped + # symbols) into the `get_sympy_index` below + # + # we also need to take care if there are several symbols, especially a + # mix of constant and non-constant symbols... + # + # what happens when there are no symbols? 
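# Illustrative sketch (not part of the patch; numpy names are made up): the
# problem the TODO above describes. The iota substitution assumes the fastest
# dimension walks consecutive elements; once a vector-shaped symbol such as
# OFFSET already carries one value per lane, layering iota on top of it
# double-counts the lane index, so the symbol's lanes should be taken as the
# offsets as-is.
import numpy as np

elements_per_thread = 4
offset_lanes = np.array([2, 5, 5, 0])                           # lanes of OFFSET

plain_read = np.arange(elements_per_thread)                     # [0, 1, 2, 3] -> contiguous
double_counted = offset_lanes + np.arange(elements_per_thread)  # [2, 6, 7, 3]  (wrong)
as_is = offset_lanes                                            # what the gather should use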
subs[-1] = ( subs[-1][0], - start_indices_orig[-1] + idxc.iota(elements_per_thread), + start_indices_orig[-1], # + idxc.iota(elements_per_thread), ) dynamic_vals_map = { sym: val @@ -460,6 +473,15 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): dyn_vals = tuple( cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals ) + + # TODO TODO TODO we can sink this down, actually... + vector_shaped_symbols = set( + sym + for sym, value in emitter.dynamic_dims.items() + if isinstance(value.type, ShapedType) + and math.prod(ShapedType(value.type).shape) != 1 + ) + start_indices, offsets_vec, mask = _construct_gather_scatter_indices( emitter=emitter, symbolc_shape=input_shape, @@ -469,6 +491,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_read=True, dynamic_vals=dyn_vals, is_contiguous=get_custom(node).is_contiguous_vec(), + vector_shaped_symbols=vector_shaped_symbols, ) result = _create_vec_read( emitter, diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index b72990051..412ffe84b 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -1424,6 +1424,7 @@ def check_is_mapping_contiguous( return True # TODO: Better dyn vals analysis. + # TODO TODO TODO: we also need to check if there are additional sybols in the mapping if mapping.num_dynamic_vals != 0: return False @@ -1441,6 +1442,15 @@ def check_is_mapping_contiguous( index_mapping = tuple(subs_idxc(i) for i in index_mapping) iters = mapping.iters + # TODO TODO TODO at this point, if the symbols present in the index are not + # known to be scalars or themselves contiguous, we shouldn't say the read is contiguous + # + # TODO TODO TODO we should thread through the fact that _some_ symbols have vector shape + for expr in index_mapping: + if len(expr.free_symbols - mapping.iters.keys()) != 0: + print("### avoided") + return False + subs = [(sym, sym + int(i == len(iters) - 1)) for i, sym in enumerate(iters)] diff = [ approximate_difference( diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py index 0c11a111f..a4c78c6aa 100644 --- a/playground/attention_with_rpe_template.py +++ b/playground/attention_with_rpe_template.py @@ -18,9 +18,12 @@ from iree.turbine.kernel.wave.templates.attention_common import AttentionShape -def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, - dynamic_dims: bool, - max_context_length: Optional[int]): +def get_vanilla_attention_kernel( + shape: AttentionShape, + mfma_variant: MMAType, + dynamic_dims: bool, + max_context_length: Optional[int], +): # RPE ZERO = tkl.sym.ZERO OFFSET = tkl.sym.OFFSET @@ -44,9 +47,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD # Expose user-constraints - constraints: list[tkw.Constraint] = [ - tkw.WorkgroupConstraint(M, BLOCK_M, 0) - ] + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] @@ -65,11 +66,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, threads_per_wave=64, waves_per_block=(4, 1, 1), mma_type=mfma_variant[1], - vector_shapes={ - B: 0, - M: Mvec, - N: Nvec - }, + vector_shapes={B: 0, M: Mvec, N: Nvec}, ) ] @@ -79,17 +76,9 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: 
MMAType, i = tkw.IndexMapping.iterator(0) j = tkw.IndexMapping.iterator(1) k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping(num_iterators=3, - inputs={ - B: i, - N: j, - M: k - }, - outputs={ - B: i, - M: k, - N: j - }) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) offset_mapping = tkw.IndexMapping( num_iterators=1, @@ -99,9 +88,11 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, use_t5_rpe = max_context_length is not None if use_t5_rpe: - rpe_layout = tkl.MemoryLayout(shape=[ - max_context_length, - ]) + rpe_layout = tkl.MemoryLayout( + shape=[ + max_context_length, + ] + ) assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing" @tkw.wave(constraints) @@ -138,16 +129,18 @@ def repeat( # When fusing into the FA variant, adding locally before the max and # the partial softmax should be equivalent. if use_t5_rpe: - ZERO = tkl.Register[M, K2, tkl.i64](0) - MAX = tkl.Register[M, K2, tkl.i64](max_context_length) + ZERO = tkl.Register[B, M, K2, tkl.i64](0) + MAX = tkl.Register[B, M, K2, tkl.i64](max_context_length) # 1. Indices i and j broadcasted along K2 with a twist: # here we use *static* information that is *implicitly* encoded # in the *transformation*: under the distribution constraints # specified we know that the shape [M] will eventually resolve # to [1] and can thus be "cast + broadcast" to [K2]. i = tkw.self_index(M, tkl.i64, elements_per_thread=1) - i = tkw.broadcast(i, target_shape=[M, K2]) - j = tkw.self_index(K2, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[B, M, K2]) + j = tkw.self_index( + K2, tkl.i64, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK + ) # 2. Clip i - j to the proper bucket in [0, max_context_length] # to represent the following: @@ -163,18 +156,18 @@ def repeat( # idx = tkw.minimum(idx, MAX) # select variant. - idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), - i - j, ZERO) + idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO) # 3. Read indirect into the 1-D rpe array via offset_mapping. tkw.set_symbol(OFFSET, idx) # offset will have shape [M, K2] rpe_reg = tkw.read( rpe, mapping=offset_mapping, - elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + elements_per_thread=LOAD_ELEMS_PER_THREAD_QK, + ) # 4. Tadaaaa. 
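            # Illustrative aside (not kernel code; plain torch, names made up):
            # how the two bucketing variants above relate. They agree for
            # i - j < 0 (both give 0) and for 0 <= i - j <= max_context_length;
            # they only differ for i - j > max_context_length, where min/max
            # clamps to the last bucket while select falls back to bucket 0.
            # Since the harness pins rpe[0] to 0, the select variant also
            # zeroes every bias outside [0, max_context_length], which is what
            # the torch reference t5_rpe_masked_cond does.
            #
            #   import torch
            #   MAX = 3
            #   i = torch.arange(8).unsqueeze(1)
            #   j = torch.arange(8).unsqueeze(0)
            #   d = i - j
            #   minmax_idx = torch.clamp(d, min=0, max=MAX)
            #   select_idx = torch.where((d >= 0) & (d <= MAX), d, torch.zeros_like(d))
            #   assert torch.equal(minmax_idx[d <= MAX], select_idx[d <= MAX])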
- x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.i64) + x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32) m_j = tkw.max(x_j, partial_max, dim=K2) e_delta_max = tkw.exp2(partial_max - m_j) @@ -191,19 +184,13 @@ def repeat( res_max, res_sum, res_mm = repeat reciprocal_sum = tkw.reciprocal(res_sum) res = res_mm * reciprocal_sum - tkw.write(res, - c, - mapping=mapping, - elements_per_thread=STORE_ELEMS_PER_THREAD) + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD_QK: - get_mfma_load_elems_per_thread(mfma_variant[0]), - LOAD_ELEMS_PER_THREAD_PV: - get_mfma_load_elems_per_thread(mfma_variant[1]), - STORE_ELEMS_PER_THREAD: - get_mfma_store_elems_per_thread(mfma_variant[1]), + LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]), BLOCK_B: 1, BLOCK_M: 128, BLOCK_N: 64, diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py index b0d161ef6..f9fc3ebef 100644 --- a/playground/attention_with_rpe_test.py +++ b/playground/attention_with_rpe_test.py @@ -15,7 +15,8 @@ from iree.turbine.kernel.wave.constraints import MMAType from iree.turbine.kernel.wave.templates.attention_common import AttentionShape from iree.turbine.kernel.wave.templates.vanilla_attention import ( - get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel) + get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel, +) from iree.turbine.kernel.wave.utils import ( device_randn, device_zeros, @@ -23,8 +24,8 @@ to_default_device, ) from attention_with_rpe_template import ( - get_vanilla_attention_kernel as - get_vanilla_tkw_attention_with_rpe_output_kernel) + get_vanilla_attention_kernel as get_vanilla_tkw_attention_with_rpe_output_kernel, +) torch.manual_seed(0) torch.set_printoptions( @@ -37,23 +38,26 @@ ### TKW Harness def run(fun: Callable, hparams, *args) -> Any: with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + activities=[torch.profiler.ProfilerActivity.CUDA] + ) as prof: with torch.no_grad(): # Disable gradient calculations with TestLaunchContext( - hparams, - canonicalize=True, - compile_config={"print-ir-after": "all"}, - run=True, - run_config=get_default_run_config(), - run_bench=False, - schedule=False, - use_scheduling_barriers=False, + hparams, + canonicalize=True, + compile_config={"print-ir-after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, ): fun(*args) print( prof.key_averages(group_by_input_shape=True).table( - sort_by="self_cuda_time_total", row_limit=10)) + sort_by="self_cuda_time_total", row_limit=10 + ) + ) ################################################################################# @@ -64,8 +68,9 @@ def run(fun: Callable, hparams, *args) -> Any: shape.query_seq_len = 128 shape.kv_seq_len = 128 -assert shape.num_query_heads == shape.num_kv_heads, \ - "expected query and kv to have the same number of heads!" +assert ( + shape.num_query_heads == shape.num_kv_heads +), "expected query and kv to have the same number of heads!" 
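# Illustrative numerics check (not part of the test; torch only): the kernel's
# softmax is built on tkw.exp2, relying on exp(x) == 2 ** (x * log2(e)). That
# is why the calls below feed it q * dk_sqrt * log2e and, in this version, an
# rpe bias scaled the same way; dk_sqrt is assumed to carry the usual
# 1/sqrt(head_size) attention scaling.
import math
import torch

log2e = math.log2(math.e)
x = torch.randn(4, 8)
ref = torch.softmax(x, dim=-1)
via_exp2 = torch.exp2(x * log2e)
via_exp2 = via_exp2 / via_exp2.sum(dim=-1, keepdim=True)
torch.testing.assert_close(ref, via_exp2)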
q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) @@ -90,40 +95,44 @@ def run(fun: Callable, hparams, *args) -> Any: # Applied pre-softmax on the MMA'ed result so f32. # Provision more room for clipping and adding 0 at the boundaries. rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32) -rpe = rpe[:max_context_length + 2].view(max_context_length + 2) +rpe = rpe[: max_context_length + 2].view(max_context_length + 2) rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32)) rpe[0] = 0 rpe[max_context_length + 1] = 0 -def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, - dtype): +def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype): positions = to_default_device(torch.arange(sequence_length)) pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0) - mask = to_default_device((pos_diff >= 0) - & (pos_diff <= max_context_length)) + mask = to_default_device((pos_diff >= 0) & (pos_diff <= max_context_length)) rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype) rpe_cond[mask] = rpe[pos_diff[mask]] return rpe_cond # rpe_cond is used by torch only -rpe_cond = t5_rpe_masked_cond(rpe, - max_context_length=max_context_length, - sequence_length=shape.kv_seq_len, - dtype=tkw_attention_with_rpe_output.dtype) +rpe_cond = t5_rpe_masked_cond( + rpe, + max_context_length=max_context_length, + sequence_length=shape.kv_seq_len, + dtype=tkw_attention_with_rpe_output.dtype, +) ################################################################################# # TKW BASE ATTENTION ################################################################################# ### RPE version -tkw_attention_with_rpe, hyperparams, dynamic_symbols, dynamic_symbols_map = \ - get_vanilla_tkw_attention_with_rpe_output_kernel( - shape, - mfma_variant=[MMAType.F32_16x16x16_F16, - MMAType.F32_16x16x16_F16], - dynamic_dims=False, - max_context_length = max_context_length + 2) +( + tkw_attention_with_rpe, + hyperparams, + dynamic_symbols, + dynamic_symbols_map, +) = get_vanilla_tkw_attention_with_rpe_output_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16], + dynamic_dims=False, + max_context_length=max_context_length + 2, +) def attention_with_rpe(tq, tk, tv, trpe, toutput): @@ -131,24 +140,41 @@ def attention_with_rpe(tq, tk, tv, trpe, toutput): print(mb.module_op) -run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k, - v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe_output) +run( + attention_with_rpe, + hyperparams, + q * dk_sqrt * log2e, + k, + v.permute([0, 2, 1]), + rpe * log2e, + tkw_attention_with_rpe_output, +) ### Reference version -tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ - get_vanilla_tkw_attention_kernel( - shape, - mfma_variant=[MMAType.F32_16x16x16_F16, - MMAType.F32_16x16x16_F16], - dynamic_dims=False) +( + tkw_attention, + hyperparams, + dynamic_symbols, + dynamic_symbols_map, +) = get_vanilla_tkw_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16], + dynamic_dims=False, +) def attention(tq, tk, tv, toutput): tkw_attention(tq, tk, tv, toutput) -run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), - tkw_attention_output) +run( + attention, + hyperparams, + q * dk_sqrt * log2e, + k, + v.permute([0, 2, 1]), + tkw_attention_output, +) tkw_rpe_delta_output = tkw_attention_with_rpe_output - 
tkw_attention_output # print(tkw_rpe_delta_output) @@ -157,7 +183,8 @@ def attention(tq, tk, tv, toutput): # TORCH ATTENTION and ATTENTION + RPE ################################################################################# torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None) + q, k, v, attn_mask=None +) a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt torch_attention_output = torch.matmul(torch.softmax(a, dim=-1), v) @@ -165,23 +192,24 @@ def attention(tq, tk, tv, toutput): # Sanity check that torch_attention_output and torch_attention_ref_output are # the same so we can inject RPE pre-softmax and compute the delta. # We will test that the delta post-softmax is the same for torch and TKW. -assert_close(torch_attention_output, - torch_attention_ref_output, - atol=2e-3, - rtol=2e-3) +assert_close(torch_attention_output, torch_attention_ref_output, atol=2e-3, rtol=2e-3) a += rpe_cond.unsqueeze(0) torch_attention_with_rpe_output = torch.matmul(F.softmax(a, dim=-1), v) torch_rpe_delta_output = torch_attention_with_rpe_output - torch_attention_output # Check basic attentions match as we expect. -assert_close(torch_attention_output.to(dtype=tkw_attention_output.dtype), - tkw_attention_output, - atol=2e-3, - rtol=2e-3) +assert_close( + torch_attention_output.to(dtype=tkw_attention_output.dtype), + tkw_attention_output, + atol=2e-3, + rtol=2e-3, +) # Check RPE attentions match as we expect. -assert_close(torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype), - tkw_rpe_delta_output, - atol=2e-3, - rtol=2e-3) +assert_close( + torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype), + tkw_rpe_delta_output, + atol=2e-3, + rtol=2e-3, +) From a5a181897e893a238f62f1896730e911dd9ee59e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 10 Feb 2025 14:37:16 -0800 Subject: [PATCH 7/7] debug: gather is wrong --- playground/attention_with_rpe_template.py | 29 ++++++++++++++++-- playground/attention_with_rpe_test.py | 36 ++++++++++++++--------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py index a4c78c6aa..5c7fd72da 100644 --- a/playground/attention_with_rpe_template.py +++ b/playground/attention_with_rpe_template.py @@ -86,6 +86,14 @@ def get_vanilla_attention_kernel( outputs={K2: i}, ) + # d = tkw.IndexMapping.dynamic_val(0) + # dynamic_mapping = tkw.IndexMapping( + # num_iterators=3, + # inputs = {B: d}, + # outputs = {B: i, M: j, K2: k}, + # dynamic_val_mappings = {B: i, M: j, K2: k} + # ) + use_t5_rpe = max_context_length is not None if use_t5_rpe: rpe_layout = tkl.MemoryLayout( @@ -103,6 +111,7 @@ def base_attention( # TODO: if not use_t5_rpe, this will DCE; atm DCE on blockargs crashes. rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout], c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + debug_out: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], ): c_reg = tkl.Register[B, N, M, tkl.f32](0.0) init_sum = tkl.Register[B, M, tkl.f32](0.0) @@ -152,11 +161,22 @@ def repeat( # to do bucketing; atm it is bucketing of size 1. # min/max variant - # idx = tkw.maximum(i - j, ZERO) - # idx = tkw.minimum(idx, MAX) + idx = tkw.maximum(i - j, ZERO) + idx = tkw.minimum(idx, MAX) # select variant. 
- idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO) + # idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO) + + idx = tkw.broadcast(idx, target_shape=[B,M,K2]) + + ### alternative + # rpe_reg = tkw.read( + # rpe, + # mapping=dynamic_mapping, + # mapping_dynamic_vals=(idx,), + # elements_per_thread=LOAD_ELEMS_PER_THREAD_QK, + # ) + ### # 3. Read indirect into the 1-D rpe array via offset_mapping. tkw.set_symbol(OFFSET, idx) # offset will have shape [M, K2] @@ -165,6 +185,9 @@ def repeat( mapping=offset_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK, ) + rpe_reg = tkw.broadcast(rpe_reg, target_shape=[B,M,K2]) + tkw.write(rpe_reg, debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + # tkw.write(tkw.cast(idx, tkl.f32), debug_out, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) # 4. Tadaaaa. x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32) diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py index f9fc3ebef..7f166aedc 100644 --- a/playground/attention_with_rpe_test.py +++ b/playground/attention_with_rpe_test.py @@ -37,14 +37,14 @@ ### TKW Harness def run(fun: Callable, hparams, *args) -> Any: - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA] - ) as prof: + # with torch.profiler.profile( + # activities=[torch.profiler.ProfilerActivity.CUDA] + # ) as prof: with torch.no_grad(): # Disable gradient calculations with TestLaunchContext( hparams, canonicalize=True, - compile_config={"print-ir-after": "all"}, + # compile_config={"print-ir-after": "all"}, run=True, run_config=get_default_run_config(), run_bench=False, @@ -53,11 +53,11 @@ def run(fun: Callable, hparams, *args) -> Any: ): fun(*args) - print( - prof.key_averages(group_by_input_shape=True).table( - sort_by="self_cuda_time_total", row_limit=10 - ) - ) + # print( + # prof.key_averages(group_by_input_shape=True).table( + # sort_by="self_cuda_time_total", row_limit=10 + # ) + # ) ################################################################################# @@ -90,7 +90,7 @@ def run(fun: Callable, hparams, *args) -> Any: # T5 RPE INIT VALS ################################################################################# # T5 RPE parameter -max_context_length = 33 +max_context_length = 30 # Applied pre-softmax on the MMA'ed result so f32. # Provision more room for clipping and adding 0 at the boundaries. 
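# Illustrative sketch (not part of the test; plain torch): what the
# "provision more room" comment above amounts to. The table has
# max_context_length + 2 entries and the two boundary entries are pinned to
# zero, so any relative distance clipped to either end of the table
# contributes no bias -- consistent with t5_rpe_masked_cond, which zeroes
# biases outside [0, max_context_length].
import torch

mcl = 4
rpe_demo = torch.randn(mcl + 2)
rpe_demo[0] = 0.0
rpe_demo[mcl + 1] = 0.0

pos = torch.arange(6)
pos_diff = pos.unsqueeze(1) - pos.unsqueeze(0)            # i - j
in_range = (pos_diff >= 0) & (pos_diff <= mcl)
idx = torch.where(in_range, pos_diff, torch.zeros_like(pos_diff))
bias = rpe_demo[idx]                                      # out-of-range positions hit rpe_demo[0] == 0
assert torch.all(bias[~in_range] == 0)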
@@ -100,6 +100,7 @@ def run(fun: Callable, hparams, *args) -> Any: rpe[0] = 0 rpe[max_context_length + 1] = 0 +tmp_out = device_zeros(shape.num_query_heads, shape.query_seq_len, shape.kv_seq_len, dtype=torch.float32) def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype): positions = to_default_device(torch.arange(sequence_length)) @@ -118,6 +119,10 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype dtype=tkw_attention_with_rpe_output.dtype, ) +# print(rpe) +print(rpe_cond.shape) +print(rpe_cond) + ################################################################################# # TKW BASE ATTENTION ################################################################################# @@ -135,21 +140,24 @@ def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, dtype ) -def attention_with_rpe(tq, tk, tv, trpe, toutput): - mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput) +def attention_with_rpe(*args): + mb = tkw_attention_with_rpe(*args) print(mb.module_op) - run( attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), - rpe * log2e, + rpe, # * log2e, tkw_attention_with_rpe_output, + tmp_out ) +print(tmp_out.shape) +print(tmp_out[0, :, :]) + ### Reference version ( tkw_attention,