diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 3465c0716..418e86e5e 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -148,6 +148,10 @@ def minimum(lhs: "Register", rhs: "Register") -> "Register": ... +def and_op(lhs: "Register", rhs: "Register") -> "Register": + ... + + def broadcast( arg: "Register", target_shape: Optional[Sequence[IndexExpr | int]] = None ) -> "Register": @@ -748,6 +752,7 @@ def infer_shape(self) -> Any: @define_py_op(operator.truediv) @define_interface_op("maximum") @define_interface_op("minimum") +@define_interface_op("and_op") @dataclass class BinaryPyOp(BinaryOpBase, ABC): def infer_type(self): diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 473857ad0..44796b13b 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -58,6 +58,7 @@ CustomOp, abs, allocate, + and_op, apply_expr, broadcast, cast, @@ -1276,7 +1277,7 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult: element_type.is_signed or element_type.is_signless ): result = arith_d.divsi(lhs, rhs) - elif _is_integer_like_type(element_type) and element_type.is_unsigned(): + elif _is_integer_like_type(element_type) and element_type.is_unsigned: result = arith_d.divui(lhs, rhs) else: raise ValidationError(f"Found unhandled operand type for div: {element_type}") @@ -1292,7 +1293,7 @@ def handle_gt(lhs: Value, rhs: Value) -> OpResult: element_type.is_signed or element_type.is_signless ): result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs) - elif _is_integer_like_type(element_type) and element_type.is_unsigned(): + elif _is_integer_like_type(element_type) and element_type.is_unsigned: result = arith_d.cmpi(arith_d.CmpIPredicate.ugt, lhs, rhs) else: raise ValidationError(f"Found unhandled operand type for gt: {element_type}") @@ -1308,7 +1309,7 @@ def handle_ge(lhs: Value, rhs: Value) -> OpResult: element_type.is_signed or element_type.is_signless ): result = arith_d.cmpi(arith_d.CmpIPredicate.sge, lhs, rhs) - elif _is_integer_like_type(element_type) and element_type.is_unsigned(): + elif _is_integer_like_type(element_type) and element_type.is_unsigned: result = arith_d.cmpi(arith_d.CmpIPredicate.uge, lhs, rhs) else: raise ValidationError(f"Found unhandled operand type for ge: {element_type}") @@ -1324,7 +1325,7 @@ def handle_lt(lhs: Value, rhs: Value) -> OpResult: element_type.is_signed or element_type.is_signless ): result = arith_d.cmpi(arith_d.CmpIPredicate.slt, lhs, rhs) - elif _is_integer_like_type(element_type) and element_type.is_unsigned(): + elif _is_integer_like_type(element_type) and element_type.is_unsigned: result = arith_d.cmpi(arith_d.CmpIPredicate.ult, lhs, rhs) else: raise ValidationError(f"Found unhandled operand type for lt: {element_type}") @@ -1340,13 +1341,24 @@ def handle_le(lhs: Value, rhs: Value) -> OpResult: element_type.is_signed or element_type.is_signless ): result = arith_d.cmpi(arith_d.CmpIPredicate.sle, lhs, rhs) - elif _is_integer_like_type(element_type) and element_type.is_unsigned(): + elif _is_integer_like_type(element_type) and element_type.is_unsigned: result = arith_d.cmpi(arith_d.CmpIPredicate.ule, lhs, rhs) else: raise ValidationError(f"Found unhandled operand type for le: {element_type}") return result +@handle_binary_op(and_op) +def handle_and_op(lhs: Value, rhs: Value) -> OpResult: + element_type = get_type_or_element_type(lhs.type) + if _is_integer_like_type(element_type): + result = 
arith_d.andi(lhs, rhs)
+    else:
+        raise ValidationError(
+            f"Found unhandled operand type for and_op: {element_type}")
+    return result
+
+
 @handle_binary_op(maximum)
 def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
     element_type = get_type_or_element_type(lhs.type)
@@ -1356,7 +1368,7 @@ def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
         element_type.is_signed or element_type.is_signless
     ):
         result = arith_d.maxsi(lhs, rhs)
-    elif _is_integer_like_type(element_type) and element_type.is_unsigned():
+    elif _is_integer_like_type(element_type) and element_type.is_unsigned:
         result = arith_d.maxui(lhs, rhs)
     else:
         raise ValidationError(
@@ -1371,10 +1383,10 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
     if _is_float_type(element_type):
         result = arith_d.minimumf(lhs, rhs)
     elif _is_integer_like_type(element_type) and (
-        element_type.is_signed() or element_type.is_signless()
+        element_type.is_signed or element_type.is_signless
     ):
         result = arith_d.minsi(lhs, rhs)
-    elif _is_integer_like_type(element_type) and element_type.is_unsigned():
+    elif _is_integer_like_type(element_type) and element_type.is_unsigned:
         result = arith_d.minui(lhs, rhs)
     else:
         raise ValidationError(
@@ -1382,7 +1394,6 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
         )
     return result
 
-
 ###############################################################################
 # Unary math Ops
 ###############################################################################
diff --git a/playground/__init__.py b/playground/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py
new file mode 100644
index 000000000..eb0e7c508
--- /dev/null
+++ b/playground/attention_with_rpe_template.py
@@ -0,0 +1,232 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from typing import Optional + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape + + +def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, + dynamic_dims: bool, + max_context_length: Optional[int]): + # RPE + ZERO = tkl.sym.ZERO + OFFSET = tkl.sym.OFFSET + + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK") + LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV") + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0) + ] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant[1] == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant[1] == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant[1], + vector_shapes={ + B: 0, + M: Mvec, + N: Nvec + }, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping(num_iterators=3, + inputs={ + B: i, + N: j, + M: k + }, + outputs={ + B: i, + M: k, + N: j + }) + + offset_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={K2: i + OFFSET}, + outputs={K2: i}, + ) + + use_t5_rpe = max_context_length is not None + if use_t5_rpe: + rpe_layout = tkl.MemoryLayout(shape=[ + max_context_length, + ]) + assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing" + + @tkw.wave(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + # TODO: if not use_t5_rpe, this will DCE; atm DCE on blockargs crashes. + rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
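+        # The loop-carried values (init_args) implement the usual online-softmax
+        # recurrence: a running row max, a running sum of exponentials that is
+        # rescaled whenever the max changes, and the rescaled partial matmul
+        # accumulator.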
+        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
+        def repeat(
+            partial_max: tkl.Register[B, M, tkl.f32],
+            partial_sum: tkl.Register[B, M, tkl.f32],
+            acc: tkl.Register[B, N, M, tkl.f32],
+        ):
+            imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
+            x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
+
+            ####################################################################
+            # T5 RPE
+            ####################################################################
+            # Fused T5 RPE adds the attention bias before softmax normalization.
+            # When fusing into the FA variant, adding locally before the max and
+            # the partial softmax should be equivalent.
+            if use_t5_rpe:
+                ZERO = tkl.Register[M, K2, tkl.i64](0)
+                MAX = tkl.Register[M, K2, tkl.i64](max_context_length)
+                # 1. Indices i and j broadcasted along K2 with a twist:
+                #    here we use *static* information that is *implicitly* encoded
+                #    in the *transformation*: under the distribution constraints
+                #    specified we know that the shape [M] will eventually resolve
+                #    to [1] and can thus be "cast + broadcast" to [K2].
+                i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
+                i = tkw.broadcast(i, target_shape=[M, K2])
+                j = tkw.self_index(K2, tkl.i64, elements_per_thread=1)
+
+                # 2. Clip i - j to the proper bucket in [0, max_context_length]
+                #    to represent the following:
+                #      - if 0 < i - j < max_context_length
+                #        then x_j += rpe_reg
+                #      - otherwise (i.e. i - j == {0, max_context_length})
+                #        then x_j += 0
+                # TODO: we may need scaling adjustments depending on how we want
+                # to do bucketing; atm it is bucketing of size 1.
+                idx = tkw.maximum(i - j, ZERO)
+                idx = tkw.minimum(idx, MAX)
+                # tkw.select complains that:
+                #   `expected vector_type.shape[0] == 1 but got vector<4xi64>`
+                # idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX),
+                #                  i - j, ZERO)
+
+                # 3. Read indirect into the 1-D rpe array via offset_mapping.
+                tkw.set_symbol(OFFSET, idx)  # offset will have shape [M, K2]
+                rpe_reg = tkw.read(
+                    rpe,
+                    mapping=offset_mapping,
+                    elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+
+                # 4. Tadaaaa.
+ x_j = x_j + rpe_reg + tkw.cast(idx, tkl.f32) + + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + tkw.write(res, + c, + mapping=mapping, + elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD_QK: + get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: + get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: + get_mfma_store_elems_per_thread(mfma_variant[1]), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape.num_query_heads, + M: shape.query_seq_len, + N: shape.head_size_kv, + K1: shape.head_size, + K2: shape.kv_seq_len, + } + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py new file mode 100644 index 000000000..af49697ed --- /dev/null +++ b/playground/attention_with_rpe_test.py @@ -0,0 +1,185 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape +from iree.turbine.kernel.wave.templates.vanilla_attention import ( + get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel) +from iree.turbine.kernel.wave.utils import ( + device_randn, + device_zeros, + get_default_run_config, + to_default_device, +) +from attention_with_rpe_template import ( + get_vanilla_attention_kernel as + get_vanilla_tkw_attention_with_rpe_output_kernel) + +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + compile_config={"print-ir-after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + fun(*args) + + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +################################################################################# +# INIT VALS +################################################################################# +# num_query_heads, num_kv_heads, head_size, head_size_kv +shape = AttentionShape(128, 128, 128, 128) +shape.query_seq_len = 128 +shape.kv_seq_len = 128 + +assert shape.num_query_heads == shape.num_kv_heads, \ + "expected query and kv to have the same number of heads!" + +q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) +k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) +v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) +o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + +q = device_randn(q_shape, dtype=torch.float16) +k = device_randn(k_shape, dtype=torch.float16) +v = device_randn(v_shape, dtype=torch.float16) +tkw_attention_output = device_zeros(o_shape, dtype=torch.float32) +tkw_attention_with_rpe_output = device_zeros(o_shape, dtype=torch.float32) + +log2e = 1.44269504089 +dk_sqrt = math.sqrt(1.0 / q.shape[-1]) + +################################################################################# +# T5 RPE INIT VALS +################################################################################# +# T5 RPE parameter +max_context_length = 33 + +# Applied pre-softmax on the MMA'ed result so f32. +# Provision more room for clipping and adding 0 at the boundaries. 
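+# The first and last entries (index 0 and max_context_length + 1) are pinned to
+# 0 below so that clipped, out-of-range relative positions pick up a zero bias,
+# mirroring the mask used in the torch-side t5_rpe_masked_cond reference.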
+rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32) +rpe = rpe[:max_context_length + 2].view(max_context_length + 2) +rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32)) +rpe[0] = 0 +rpe[max_context_length + 1] = 0 + + +def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int, + dtype): + positions = to_default_device(torch.arange(sequence_length)) + pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0) + mask = to_default_device((pos_diff >= 0) + & (pos_diff <= max_context_length)) + rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype) + rpe_cond[mask] = rpe[pos_diff[mask]] + return rpe_cond + + +# rpe_cond is used by torch only +rpe_cond = t5_rpe_masked_cond(rpe, + max_context_length=max_context_length, + sequence_length=shape.kv_seq_len, + dtype=tkw_attention_with_rpe_output.dtype) + +################################################################################# +# TORCH ATTENTION and ATTENTION + RPE +################################################################################# +torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None) + +a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt +torch_attention_output = torch.matmul(torch.softmax(a, dim=-1), v) + +# Sanity check that torch_attention_output and torch_attention_ref_output are +# the same so we can inject RPE pre-softmax and compute the delta. +# We will test that the delta post-softmax is the same for torch and TKW. +assert_close(torch_attention_output, + torch_attention_ref_output, + atol=2e-3, + rtol=2e-3) + +a += rpe_cond.unsqueeze(0) +torch_attention_with_rpe_output = torch.matmul(F.softmax(a, dim=-1), v) +torch_rpe_delta_output = torch_attention_with_rpe_output - torch_attention_output + +################################################################################# +# TKW BASE ATTENTION +################################################################################# +### Reference version +tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def attention(tq, tk, tv, toutput): + tkw_attention(tq, tk, tv, toutput) + + +run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), + tkw_attention_output) + +assert_close(torch_attention_output.to(dtype=tkw_attention_output.dtype), + tkw_attention_output, + atol=2e-3, + rtol=2e-3) + +### RPE version +tkw_attention_with_rpe, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_with_rpe_output_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False, + max_context_length = max_context_length + 2) + + +def attention_with_rpe(tq, tk, tv, trpe, toutput): + mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput) + print(mb.module_op) + + +run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k, + v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe_output) + +tkw_rpe_delta_output = tkw_attention_with_rpe_output - tkw_attention_output +# print(tkw_rpe_delta_output) + +assert_close(torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype), + tkw_rpe_delta_output, + atol=2e-3, + rtol=2e-3) diff --git a/playground/causal_attention_template.py b/playground/causal_attention_template.py new file mode 100644 index 000000000..f2de1fe25 --- /dev/null +++ b/playground/causal_attention_template.py @@ -0,0 
+1,197 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +import sympy +import sys +from typing import Optional + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape + + +def get_causal_attention_kernel(shape: AttentionShape, mfma_variant: MMAType, + dynamic_dims: bool): + ZERO = tkl.sym.ZERO + + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK") + LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV") + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0) + ] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant[1] == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant[1] == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant[1], + vector_shapes={ + B: 0, + M: Mvec, + N: Nvec + }, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping(num_iterators=3, + inputs={ + B: i, + N: j, + M: k + }, + outputs={ + B: i, + M: k, + N: j + }) + + @tkw.wave(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + tmp: tkl.Memory[K2, ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
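+        # The causal mask is folded in as an additive 0 / -inf bias on the
+        # score accumulator before the QK^T mma, so the softmax max/sum
+        # updates below need no further changes.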
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + imm_reg = tkw.permute(imm_reg, target_shape=[B, M, K2]) + #################################################################### + # Causal mask + #################################################################### + # Indices i and j broadcasted along K2 with a twist: + # here we use *static* information that is *implicitly* encoded + # in the *transformation*: under the distribution constraints + # specified we know that the shape [M] will eventually resolve + # to [1] and can thus be "cast + broadcast" to [K2]. + i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[K2]) + j = tkw.self_index(K2, tkl.i64, elements_per_thread=1) + ZERO = tkl.Register[K2, tkl.i64](0) + ONE = tkl.Register[K2, tkl.i64](1) + ZEROF = tkl.Register[K2, tkl.f32](0.0) + MIN_INF = tkl.Register[K2, tkl.f32](float('-inf')) + idx = j - i - ONE + bias = tkw.select(tkw.slt(idx, ZERO), ZEROF, MIN_INF) + ### Apply causality mask to imm_reg. + imm_reg = imm_reg + bias + imm_reg = tkw.permute(imm_reg, target_shape=[B, K2, M]) + + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0]) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + tkw.write(res, + c, + mapping=mapping, + elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD_QK: + get_mfma_load_elems_per_thread(mfma_variant[0]), + LOAD_ELEMS_PER_THREAD_PV: + get_mfma_load_elems_per_thread(mfma_variant[1]), + STORE_ELEMS_PER_THREAD: + get_mfma_store_elems_per_thread(mfma_variant[1]), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape.num_query_heads, + M: shape.query_seq_len, + N: shape.head_size_kv, + K1: shape.head_size, + K2: shape.kv_seq_len, + ZERO: 0, + } + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map diff --git a/playground/causal_attention_test.py b/playground/causal_attention_test.py new file mode 100644 index 000000000..77a9a4ee8 --- /dev/null +++ b/playground/causal_attention_test.py @@ -0,0 +1,179 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.templates.attention_common import AttentionShape +from iree.turbine.kernel.wave.utils import ( + device_randn, + device_zeros, + get_default_run_config, + to_default_device, +) +from causal_attention_template import (get_causal_attention_kernel as + get_tkw_causal_attention_kernel) +from iree.turbine.kernel.wave.templates.vanilla_attention import ( + get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel) + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + + +def find_different_coordinates(tensor1, tensor2, rtol=1e-5, atol=1e-8): + # Calculate the difference in float32 to avoid range issues + diff = (tensor1.float() - tensor2.float()).abs() + + # Create a mask where the difference exceeds the tolerance + tolerance = atol + rtol * tensor2.float().abs() + diff_mask = diff > tolerance + + if not diff_mask.any(): # Tensors are close if the mask is all False + print("Tensors are close.") + return [] + + diff_indices = torch.nonzero(diff_mask) + + print("Tensors are different at the following coordinates:") + for coords in diff_indices: + print(tuple(coords.tolist())) + + return diff_indices + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + # compile_config={"print_ir_after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + fun(*args) + + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +################################################################################# +# INIT VALS +################################################################################# +# num_query_heads, num_kv_heads, head_size, head_size_kv +shape = AttentionShape(128, 128, 128, 128) +shape.query_seq_len = 128 +shape.kv_seq_len = 128 + +assert shape.num_query_heads == shape.num_kv_heads, \ + "expected query and kv to have the same number of heads!" 
+ +q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size) +k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size) +v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv) +o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv) + +q = device_randn(q_shape, dtype=torch.float16) +k = device_randn(k_shape, dtype=torch.float16) +v = device_randn(v_shape, dtype=torch.float16) +tkw_attention_output = device_zeros(o_shape, dtype=torch.float32) +tkw_causal_attention_output = device_zeros(o_shape, dtype=torch.float32) + +log2e = 1.44269504089 +dk_sqrt = math.sqrt(1.0 / q.shape[-1]) + +################################################################################# +# TORCH ATTENTION +################################################################################# +torch_causal_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True) +torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None) +torch_delta = torch_attention_ref_output - torch_causal_attention_ref_output +print(torch_delta) + +################################################################################# +# TKW ATTENTION +################################################################################# +### Reference version +tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_vanilla_tkw_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def attention(tq, tk, tv, toutput): + tkw_attention(tq, tk, tv, toutput) + + +run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), + tkw_attention_output) + +assert_close(torch_attention_ref_output.to(dtype=tkw_attention_output.dtype), + tkw_attention_output, + atol=2e-3, + rtol=2e-3) + +### Causal version +tkw_causal_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \ + get_tkw_causal_attention_kernel( + shape, + mfma_variant=[MMAType.F32_16x16x16_F16, + MMAType.F32_16x16x16_F16], + dynamic_dims=False) + + +def causal_attention(tq, tk, tv, toutput): + mb = tkw_causal_attention(tq, tk, tv, toutput) + print(mb.module_op) + + +run(causal_attention, hyperparams, q * dk_sqrt * log2e, k, + v.permute([0, 2, 1]), tkw_causal_attention_output) + +# tkw_delta = tkw_causal_attention_output - tkw_attention_output +# print(tkw_delta) +# print(torch_causal_attention_ref_output[67, 16]) +# print(tkw_causal_attention_output[67, 16]) + +# Coordinates where we see discrepancies are: +# (*, 16:31, *) +# (*, 48:63, *) +# (*, 80:95, *) +# (*, 80:95, *) +# (*, 112:127, *) +# different_coords = find_different_coordinates( +# torch_causal_attention_ref_output, +# tkw_causal_attention_output, +# rtol=2e-3, +# atol=2e-3) +# print(different_coords) + +assert_close(torch_causal_attention_ref_output.to( + dtype=tkw_causal_attention_output.dtype), + tkw_causal_attention_output, + atol=2e-3, + rtol=2e-3) diff --git a/playground/stress.py b/playground/stress.py new file mode 100644 index 000000000..ac77b7613 --- /dev/null +++ b/playground/stress.py @@ -0,0 +1,203 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pdb import run +import pytest +import torch +from typing import Callable + +from iree.turbine.kernel._support.tracing import TestLaunchContext +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.utils import (device_randn, device_zeros, + get_default_run_config) +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw + +torch.set_printoptions(linewidth=300) + + +# We want each row to contain [0 .. num_cols] +def reference_row(rows: int, cols: int): + row_indices = torch.arange(cols).unsqueeze(0).expand(rows, cols) + print(row_indices.shape) + return row_indices + + +# We want each col to contain [0 .. num_rows] +def reference_col(rows: int, cols: int): + col_indices = torch.arange(rows).unsqueeze(1).expand(rows, cols) + return col_indices + + +def reference_row_plus_col(rows: int, cols: int): + return reference_row(rows, cols) + reference_col(rows, cols) + + +# Input sizes +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +# Workgroup tile sizes +ITERATIONS_OF_M_PER_WAVE = tkl.sym.ITERATIONS_OF_M_PER_WAVE +ITERATIONS_OF_N_PER_WAVE = tkl.sym.ITERATIONS_OF_N_PER_WAVE +BLOCK_K = tkl.sym.BLOCK_K +# Address space (for GPU, shared(1) or global(0)) +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +# Other hyperparameters +LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD +STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + +# yapf: disable +def run_harness(fun: Callable, vM: int, vN: int, *args) -> bool: + config = get_default_run_config() + # Override manually to run. + config = {"backend": "rocm", "device": "hip", "target": "gfx90a"} + with TestLaunchContext({M: vM, N: vN}, + canonicalize=True, + run=True, + run_config=config): + return fun(*args) + + +# yapf: disable +### Setting all of the following at the same time to agree on the same value works: +# - ITERATIONS_OF_N_PER_WAVE +# - VECTOR_SHAPE_N +# - ELEMENTS_PER_THREAD_STORE +params = [ +# SIZE_M, SIZE_N, ITERATIONS_OF_M_PER_WAVE, ITERATIONS_OF_N_PER_WAVE, VECTOR_SHAPE_M, VECTOR_SHAPE_N, ELEMENTS_PER_THREAD_STORE + [ 4, 8, 1, 1, 1, 1, 1, ], + [ 4, 8, 1, 2, 1, 2, 2, ], + [ 4, 8, 1, 3, 1, 3, 3, ], + [ 4, 8, 1, 4, 1, 4, 4, ], +] +### However, The slightest discrepancy throws the TK compiler off: +# params = [ +# SIZE_M, SIZE_N, ITERATIONS_OF_M_PER_WAVE, ITERATIONS_OF_N_PER_WAVE, VECTOR_SHAPE_M, VECTOR_SHAPE_N, ELEMENTS_PER_THREAD_STORE +# [ 4, 8, 1, 1, 1, 4, 4, 4, ], # Tile size must be divisible by wave count and vector size, got: tile_size=1, wave_count=1, vector_size=4 +# [ 4, 8, 1, 4, 1, 1, 4, 4, ], # MISCOMPILE INCORRECT RESULTS +# [ 4, 8, 1, 4, 1, 4, 1, 4, ], # CRASH TK COMPILER: Shape doesn't match: (1,) and (4,) in register cast_M:0_N:0 and elements_per_thread 4 +# [ 4, 8, 1, 4, 1, 4, 4, 1, ], # CRASH TK COMPILER: Shape doesn't match: (4,) and (1,) in register cast_M:0_N:0 and elements_per_thread 1 +# ] + +for p in params: + SIZE_M, \ + SIZE_N, \ + ITERATIONS_OF_M_PER_WAVE, \ + ITERATIONS_OF_N_PER_WAVE, \ + VECTOR_SHAPE_M, \ + VECTOR_SHAPE_N, \ + ELEMENTS_PER_THREAD_STORE = p + + workgroup_constraints = [ + [tkw.WorkgroupConstraint(M, ITERATIONS_OF_M_PER_WAVE, 0)], + [tkw.WorkgroupConstraint(N, ITERATIONS_OF_N_PER_WAVE, 1)], + [ + tkw.WorkgroupConstraint(M, ITERATIONS_OF_M_PER_WAVE, 0), + tkw.WorkgroupConstraint(N, ITERATIONS_OF_N_PER_WAVE, 1) + ], + ] + wave_constraints = [ + [], + [tkw.WaveConstraint(M, 1)], + [tkw.WaveConstraint(N, 1)], + [tkw.WaveConstraint(M, 1), tkw.WaveConstraint(N, 1)], + 
[tkw.WaveConstraint(M, 2)], + [tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 1), tkw.WaveConstraint(N, 2)], + [tkw.WaveConstraint(M, 2), tkw.WaveConstraint(N, 1)], + ] + # yapf: enable + + for wgs in workgroup_constraints: + for wvs in wave_constraints: + unroll_N = True + # In these stress tests compute self_index(N) and we want to distinguish + # between the cases: + # 1. there is a WorkgroupConstraint on N, therefore N is distributed + # and using ELEMENTS_PER_THREAD_INDEX == 1 results in proper + # propagations + # 2. there is no WorkgroupConstraint on N, therefore N is unrolled and + # we have to use ELEMENTS_PER_THREAD_INDEX == ELEMENTS_PER_THREAD_STORE + # otherwise the TK compiler gets confused atm. + # Ideally, in the future, things would just work out of the box without + # having to adjust ELEMENTS_PER_THREAD_INDEX + for wg in wgs: + if wg.dim == N: + unroll_N = False + + # Skip this particular constraint if a WaveConstraint is set without + # first setting the corresponding WorkgroupConstraint: + # TK does not handle that case + skip = False + for wv in wvs: + skip_wv = True + for wg in wgs: + if wg.dim == wv.dim: + skip_wv = False + if skip_wv: + skip = True + if skip: + continue + + ELEMENTS_PER_THREAD_INDEX = ELEMENTS_PER_THREAD_STORE if unroll_N else 1 + + ###### User constraints + constraints: list[tkw.Constraint] = [] + constraints += wgs + constraints += wvs + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={ + M: VECTOR_SHAPE_M, + N: VECTOR_SHAPE_N + }, + ) + ] + + ###### Known cases to skip: + # When we unroll N, TK does not handle imperfect unrolling (with a + # remainder). + if unroll_N and SIZE_N % ITERATIONS_OF_N_PER_WAVE != 0: + continue + + @tkw.wave(constraints) + def row(c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32]): + i = tkw.self_index( + N, tkl.i64, elements_per_thread=ELEMENTS_PER_THREAD_INDEX) + res = tkw.cast(i, tkl.f32) + tkw.write(res, + c, + elements_per_thread=ELEMENTS_PER_THREAD_STORE) + + def fun_row(debug: bool = False) -> bool: + c = device_zeros(SIZE_M, SIZE_N, dtype=torch.float32) + if debug: + print(row(c).module_op) + return True + else: + row(c) + correct = torch.all( + torch.isclose(reference_row(SIZE_M, SIZE_N), + c.cpu().to(dtype=torch.int64))).item() + if not correct: + print(f"reference:\n{reference_row(SIZE_M, SIZE_N)}") + print(f"actual:\n{c.cpu().to(dtype=torch.int64)}") + print( + f"delta:\n{c.cpu().to(dtype=torch.int64) - reference_row(SIZE_M, SIZE_N)}" + ) + return correct + + correct = run_harness(fun_row, SIZE_M, SIZE_N) + if not correct: + print(f"\nError under stress test constraints: {constraints}") + run_harness(fun_row, SIZE_M, SIZE_N, True) + assert correct, "Incorrect execution: ran in debug mode now stop" diff --git a/playground/triangular.py b/playground/triangular.py new file mode 100644 index 000000000..58e32e5fc --- /dev/null +++ b/playground/triangular.py @@ -0,0 +1,108 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math +import torch +from torch.profiler import profile, record_function, ProfilerActivity +from torch.nn import functional as F +from torch.testing import assert_close +from typing import Any, Callable + +from iree.turbine.kernel.gen import TestLaunchContext +from iree.turbine.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE +from iree.turbine.kernel.wave.utils import ( + device_zeros, + get_default_run_config, + to_default_device, +) +from causal_attention_template import (get_causal_attention_kernel as + get_tkw_causal_attention_kernel) + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw + +torch.manual_seed(0) +torch.set_printoptions( + linewidth=1000000, + threshold=1000000, + precision=3, +) + +vM, vN = 10, 10 +torch_o = to_default_device(torch.ones(vM, vN)) +temp_mask = to_default_device( + torch.ones(vM, vN, dtype=torch.bool).tril(diagonal=0)) +torch_o.masked_fill_(temp_mask.logical_not(), float("-inf")) + +M = tkl.sym.M +N = tkl.sym.N +ONE = tkl.sym.ONE +# Expose user-constraints +constraints: list[tkw.Constraint] = [] +constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={ + M: 1, + N: 1, + }, + ) +] + +constraints += [tkw.WorkgroupConstraint(M, 1, 0)] +constraints += [tkw.WorkgroupConstraint(N, 1, 1)] + +# WARNING: these constraints generate wrong code +# constraints += [tkw.WorkgroupConstraint(M, 2, 0)] +# constraints += [tkw.WorkgroupConstraint(N, 2, 1)] +# constraints += [tkw.WaveConstraint(M, 1)] +# constraints += [tkw.WaveConstraint(N, 1)] + + +### TKW Harness +def run(fun: Callable, hparams, *args) -> Any: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + with torch.no_grad(): # Disable gradient calculations + with TestLaunchContext( + hparams, + canonicalize=True, + # compile_config={"print_ir_after": "all"}, + run=True, + run_config=get_default_run_config(), + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + mb = fun(*args) + print(mb.module_op) + print( + prof.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total", row_limit=10)) + + +@tkw.wave(constraints) +def test(o: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32]): + i = tkw.self_index(M, tkl.i64, elements_per_thread=1) + i = tkw.broadcast(i, target_shape=[N]) + j = tkw.self_index(N, tkl.i64, elements_per_thread=1) + ZERO = tkl.Register[N, tkl.i64](0) + ONE = tkl.Register[N, tkl.i64](1) + ZEROF = tkl.Register[N, tkl.f32](0.0) + MIN_INF = tkl.Register[N, tkl.f32](float('-inf')) + idx = j - i - ONE + res = tkw.select(tkw.slt(idx, ZERO), ZEROF, MIN_INF) + val = tkw.read(o, elements_per_thread=1) + res += val + tkw.write(res, o, elements_per_thread=1) + + +o = to_default_device(torch.ones(vM, vN)) +run(test, {M: vM, N: vN, ONE: 1}, o) + +# print(o) +assert_close(torch_o.to(dtype=o.dtype), o, atol=2e-3, rtol=2e-3)