diff --git a/playground/__init__.py b/playground/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/playground/attention_with_rpe_template.py b/playground/attention_with_rpe_template.py
new file mode 100644
index 000000000..6466e43cd
--- /dev/null
+++ b/playground/attention_with_rpe_template.py
@@ -0,0 +1,228 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from dataclasses import dataclass
+from typing import Optional
+
+import iree.turbine.kernel.lang as tkl
+import iree.turbine.kernel.wave as tkw
+from iree.turbine.kernel.lang.global_symbols import *
+from iree.turbine.kernel.wave.constraints import MMAType
+from iree.turbine.kernel.wave.utils import (
+    get_mfma_load_elems_per_thread,
+    get_mfma_store_elems_per_thread,
+)
+from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
+
+
+def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
+                                 dynamic_dims: bool,
+                                 max_context_length: Optional[int]):
+    # RPE
+    ZERO = tkl.sym.ZERO
+    OFFSET = tkl.sym.OFFSET
+
+    # Input sizes
+    B = tkl.sym.B
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    # Workgroup tile sizes
+    BLOCK_B = tkl.sym.BLOCK_B
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK")
+    LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV")
+    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
+
+    # Expose user constraints.
+    constraints: list[tkw.Constraint] = [
+        tkw.WorkgroupConstraint(M, BLOCK_M, 0)
+    ]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)]
+
+    if mfma_variant[1] == MMAType.F32_16x16x16_F16:
+        Mvec = 16
+        Nvec = 16
+    if mfma_variant[1] == MMAType.F32_32x32x8_F16:
+        Mvec = 32
+        Nvec = 32
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(4, 1, 1),
+            mma_type=mfma_variant[1],
+            vector_shapes={
+                B: 0,
+                M: Mvec,
+                N: Nvec
+            },
+        )
+    ]
+
+    if dynamic_dims:
+        constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.iterator(2)
+    mapping = tkw.IndexMapping(num_iterators=3,
+                               inputs={
+                                   B: i,
+                                   N: j,
+                                   M: k
+                               },
+                               outputs={
+                                   B: i,
+                                   M: k,
+                                   N: j
+                               })
+
+    offset_mapping = tkw.IndexMapping(
+        num_iterators=1,
+        inputs={K2: i + OFFSET},
+        outputs={K2: i},
+    )
+
+    use_t5_rpe = max_context_length is not None
+    if use_t5_rpe:
+        rpe_layout = tkl.MemoryLayout(shape=[
+            max_context_length,
+        ])
+    assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing"
+
+    @tkw.wave(constraints)
+    def base_attention(
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+        # TODO: if not use_t5_rpe, this arg would be DCE'd; at the moment, DCE on block args crashes.
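+        # rpe is a 1-D table of relative-position biases, gathered below via offset_mapping.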
+        rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout],
+        c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
+        init_sum = tkl.Register[B, M, tkl.f32](0.0)
+        init_max = tkl.Register[B, M, tkl.f32](-1e6)
+
+        # This microkernel encodes the fact that if the reduction
+        # dimension were tiled, then we would need to materialize a loop.
+        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
+        def repeat(
+            partial_max: tkl.Register[B, M, tkl.f32],
+            partial_sum: tkl.Register[B, M, tkl.f32],
+            acc: tkl.Register[B, N, M, tkl.f32],
+        ):
+            imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
+            x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
+
+            ####################################################################
+            # T5 RPE
+            ####################################################################
+            # Fused T5 RPE adds the attention bias before softmax normalization.
+            # When fusing into the FA variant, adding locally before the max and
+            # the partial softmax should be equivalent.
+            if use_t5_rpe:
+                ZERO = tkl.Register[M, K2, tkl.i64](0)
+                MAX = tkl.Register[M, K2, tkl.i64](max_context_length)
+                # 1. Indices i and j broadcasted along K2 with a twist:
+                # here we use *static* information that is *implicitly* encoded
+                # in the *transformation*: under the distribution constraints
+                # specified we know that the shape [M] will eventually resolve
+                # to [1] and can thus be "cast + broadcast" to [K2].
+                i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
+                i = tkw.broadcast(i, target_shape=[M, K2])
+                j = tkw.self_index(K2, tkl.i64, elements_per_thread=1)
+                # j = tkw.broadcast(j, target_shape=[M, K2])
+
+                # 2. Clip i - j to the proper bucket in [0, max_context_length]
+                # to represent the following:
+                #   - if 0 < i - j < max_context_length
+                #     then x_j += rpe_reg
+                #   - otherwise (i.e. i - j == {0, max_context_length})
+                #     then x_j += 0
+                # TODO: we may need scaling adjustments depending on how we want
+                # to do bucketing; at the moment it is bucketing of size 1.
+                idx = tkw.minimum(tkw.maximum(i - j, ZERO), MAX)
+
+                # 3. Read indirectly into the 1-D rpe array via offset_mapping.
+                tkw.set_symbol(OFFSET, idx)  # offset will have shape [M, K2]
+                rpe_reg = tkw.read(
+                    rpe,
+                    mapping=offset_mapping,
+                    elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+
+                # 4. Tadaaaa.
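+                # Add the gathered relative-position bias to the pre-softmax scores.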
+                x_j = x_j + rpe_reg
+
+            m_j = tkw.max(x_j, partial_max, dim=K2)
+            e_delta_max = tkw.exp2(partial_max - m_j)
+            e_delta = tkw.exp2(x_j - m_j)
+            e_init = partial_sum * e_delta_max
+            d_j = tkw.sum(e_delta, e_init, dim=K2)
+            imm_f16 = tkw.cast(e_delta, tkl.f16)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV)
+            new_acc = acc * e_delta_max
+            acc = tkw.mma(v_reg, imm_f16, new_acc)
+            return m_j, d_j, acc
+
+        # repeat represents the results of the loop
+        res_max, res_sum, res_mm = repeat
+        reciprocal_sum = tkw.reciprocal(res_sum)
+        res = res_mm * reciprocal_sum
+        tkw.write(res,
+                  c,
+                  mapping=mapping,
+                  elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        LOAD_ELEMS_PER_THREAD_QK:
+        get_mfma_load_elems_per_thread(mfma_variant[0]),
+        LOAD_ELEMS_PER_THREAD_PV:
+        get_mfma_load_elems_per_thread(mfma_variant[1]),
+        STORE_ELEMS_PER_THREAD:
+        get_mfma_store_elems_per_thread(mfma_variant[1]),
+        BLOCK_B: 1,
+        BLOCK_M: 128,
+        BLOCK_N: 64,
+        BLOCK_K2: 64,
+        B: shape.num_query_heads,
+        M: shape.query_seq_len,
+        N: shape.head_size_kv,
+        K1: shape.head_size,
+        K2: shape.kv_seq_len,
+    }
+
+    dynamic_symbols = []
+    dynamic_symbols_map = {}
+    if dynamic_dims:
+        dynamic_symbols_map[M] = hyperparams[M]
+        dynamic_symbols_map[N] = hyperparams[N]
+        dynamic_symbols_map[B] = hyperparams[B]
+        dynamic_symbols_map[K2] = hyperparams[K2]
+        dynamic_symbols.append(M)
+        dynamic_symbols.append(N)
+        dynamic_symbols.append(B)
+        dynamic_symbols.append(K2)
+        del hyperparams[M]
+        del hyperparams[N]
+        del hyperparams[B]
+        del hyperparams[K2]
+
+    return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map
diff --git a/playground/attention_with_rpe_test.py b/playground/attention_with_rpe_test.py
new file mode 100644
index 000000000..af49697ed
--- /dev/null
+++ b/playground/attention_with_rpe_test.py
@@ -0,0 +1,185 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import math
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
+from torch.nn import functional as F
+from torch.testing import assert_close
+from typing import Any, Callable
+
+from iree.turbine.kernel.gen import TestLaunchContext
+from iree.turbine.kernel.wave.constraints import MMAType
+from iree.turbine.kernel.wave.templates.attention_common import AttentionShape
+from iree.turbine.kernel.wave.templates.vanilla_attention import (
+    get_vanilla_attention_kernel as get_vanilla_tkw_attention_kernel)
+from iree.turbine.kernel.wave.utils import (
+    device_randn,
+    device_zeros,
+    get_default_run_config,
+    to_default_device,
+)
+from attention_with_rpe_template import (
+    get_vanilla_attention_kernel as
+    get_vanilla_tkw_attention_with_rpe_output_kernel)
+
+torch.manual_seed(0)
+torch.set_printoptions(
+    linewidth=1000000,
+    threshold=1000000,
+    precision=3,
+)
+
+
+### TKW Harness
+def run(fun: Callable, hparams, *args) -> Any:
+    with torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
+        with torch.no_grad():  # Disable gradient calculations
+            with TestLaunchContext(
+                    hparams,
+                    canonicalize=True,
+                    compile_config={"print-ir-after": "all"},
+                    run=True,
+                    run_config=get_default_run_config(),
+                    run_bench=False,
+                    schedule=False,
+                    use_scheduling_barriers=False,
+            ):
+                fun(*args)
+
+    print(
+        prof.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cuda_time_total", row_limit=10))
+
+
+#################################################################################
+# INIT VALS
+#################################################################################
+# num_query_heads, num_kv_heads, head_size, head_size_kv
+shape = AttentionShape(128, 128, 128, 128)
+shape.query_seq_len = 128
+shape.kv_seq_len = 128
+
+assert shape.num_query_heads == shape.num_kv_heads, \
+    "expected query and kv to have the same number of heads!"
+
+q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size)
+k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size)
+v_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size_kv)
+o_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size_kv)
+
+q = device_randn(q_shape, dtype=torch.float16)
+k = device_randn(k_shape, dtype=torch.float16)
+v = device_randn(v_shape, dtype=torch.float16)
+tkw_attention_output = device_zeros(o_shape, dtype=torch.float32)
+tkw_attention_with_rpe_output = device_zeros(o_shape, dtype=torch.float32)
+
+log2e = 1.44269504089
+dk_sqrt = math.sqrt(1.0 / q.shape[-1])
+
+#################################################################################
+# T5 RPE INIT VALS
+#################################################################################
+# T5 RPE parameter
+max_context_length = 33
+
+# Applied pre-softmax on the MMA'ed result so f32.
+# Provision more room for clipping and adding 0 at the boundaries.
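+# The table has max_context_length + 2 entries; entries 0 and max_context_length + 1
+# are zeroed below so that clamped out-of-range relative positions pick up a zero bias.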
+rpe = device_zeros(1000 + max_context_length + 2, dtype=torch.float32)
+rpe = rpe[:max_context_length + 2].view(max_context_length + 2)
+rpe.copy_(device_randn(max_context_length + 2, dtype=torch.float32))
+rpe[0] = 0
+rpe[max_context_length + 1] = 0
+
+
+def t5_rpe_masked_cond(rpe, max_context_length: int, sequence_length: int,
+                       dtype):
+    positions = to_default_device(torch.arange(sequence_length))
+    pos_diff = positions.unsqueeze(1) - positions.unsqueeze(0)
+    mask = to_default_device((pos_diff >= 0)
+                             & (pos_diff <= max_context_length))
+    rpe_cond = device_zeros(sequence_length, sequence_length, dtype=dtype)
+    rpe_cond[mask] = rpe[pos_diff[mask]]
+    return rpe_cond
+
+
+# rpe_cond is used by torch only
+rpe_cond = t5_rpe_masked_cond(rpe,
+                              max_context_length=max_context_length,
+                              sequence_length=shape.kv_seq_len,
+                              dtype=tkw_attention_with_rpe_output.dtype)
+
+#################################################################################
+# TORCH ATTENTION and ATTENTION + RPE
+#################################################################################
+torch_attention_ref_output = torch.nn.functional.scaled_dot_product_attention(
+    q, k, v, attn_mask=None)
+
+a = torch.matmul(q, k.transpose(-1, -2)) * dk_sqrt
+torch_attention_output = torch.matmul(torch.softmax(a, dim=-1), v)
+
+# Sanity check that torch_attention_output and torch_attention_ref_output are
+# the same so we can inject RPE pre-softmax and compute the delta.
+# We will test that the delta post-softmax is the same for torch and TKW.
+assert_close(torch_attention_output,
+             torch_attention_ref_output,
+             atol=2e-3,
+             rtol=2e-3)
+
+a += rpe_cond.unsqueeze(0)
+torch_attention_with_rpe_output = torch.matmul(F.softmax(a, dim=-1), v)
+torch_rpe_delta_output = torch_attention_with_rpe_output - torch_attention_output
+
+#################################################################################
+# TKW BASE ATTENTION
+#################################################################################
+### Reference version
+tkw_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = \
+    get_vanilla_tkw_attention_kernel(
+        shape,
+        mfma_variant=[MMAType.F32_16x16x16_F16,
+                      MMAType.F32_16x16x16_F16],
+        dynamic_dims=False)
+
+
+def attention(tq, tk, tv, toutput):
+    tkw_attention(tq, tk, tv, toutput)
+
+
+run(attention, hyperparams, q * dk_sqrt * log2e, k, v.permute([0, 2, 1]),
+    tkw_attention_output)
+
+assert_close(torch_attention_output.to(dtype=tkw_attention_output.dtype),
+             tkw_attention_output,
+             atol=2e-3,
+             rtol=2e-3)
+
+### RPE version
+tkw_attention_with_rpe, hyperparams, dynamic_symbols, dynamic_symbols_map = \
+    get_vanilla_tkw_attention_with_rpe_output_kernel(
+        shape,
+        mfma_variant=[MMAType.F32_16x16x16_F16,
+                      MMAType.F32_16x16x16_F16],
+        dynamic_dims=False,
+        max_context_length=max_context_length + 2)
+
+
+def attention_with_rpe(tq, tk, tv, trpe, toutput):
+    mb = tkw_attention_with_rpe(tq, tk, tv, trpe, toutput)
+    print(mb.module_op)
+
+
+run(attention_with_rpe, hyperparams, q * dk_sqrt * log2e, k,
+    v.permute([0, 2, 1]), rpe, tkw_attention_with_rpe_output)
+
+tkw_rpe_delta_output = tkw_attention_with_rpe_output - tkw_attention_output
+# print(tkw_rpe_delta_output)
+
+assert_close(torch_rpe_delta_output.to(dtype=tkw_rpe_delta_output.dtype),
+             tkw_rpe_delta_output,
+             atol=2e-3,
+             rtol=2e-3)
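+# The assert above checks that the post-softmax delta introduced by RPE matches
+# between the torch reference and the fused TKW kernel, within tolerance.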