Add T5 RPE variant
Signed-off-by: Nicolas Vasilache <[email protected]>
nicolasvasilache committed Feb 4, 2025
1 parent 7ba8dcf commit 06180db
Showing 3 changed files with 413 additions and 0 deletions.
Empty file added playground/__init__.py
Empty file.
228 changes: 228 additions & 0 deletions playground/attention_with_rpe_template.py
@@ -0,0 +1,228 @@
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from dataclasses import dataclass
from typing import Optional

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils import (
    get_mfma_load_elems_per_thread,
    get_mfma_store_elems_per_thread,
)
from iree.turbine.kernel.wave.templates.attention_common import AttentionShape


def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
                                 dynamic_dims: bool,
                                 max_context_length: Optional[int]):
    # RPE
    ZERO = tkl.sym.ZERO
    OFFSET = tkl.sym.OFFSET

    # Input sizes
    B = tkl.sym.B
    M = tkl.sym.M
    N = tkl.sym.N
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2
    # Workgroup tile sizes
    BLOCK_B = tkl.sym.BLOCK_B
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K2 = tkl.sym.BLOCK_K2
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK")
    LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV")
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(M, BLOCK_M, 0)
    ]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)]
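    # Note: with the hyperparameters below (BLOCK_M = 128, BLOCK_N = 64) and 4
    # waves per block along M, each wave covers a 32 x 64 output tile per
    # workgroup, while K2 is tiled by BLOCK_K2 and reduced inside the kernel.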

    if mfma_variant[1] == MMAType.F32_16x16x16_F16:
        Mvec = 16
        Nvec = 16
    if mfma_variant[1] == MMAType.F32_32x32x8_F16:
        Mvec = 32
        Nvec = 32

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(4, 1, 1),
            mma_type=mfma_variant[1],
            vector_shapes={
                B: 0,
                M: Mvec,
                N: Nvec
            },
        )
    ]

    if dynamic_dims:
        constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    k = tkw.IndexMapping.iterator(2)
    mapping = tkw.IndexMapping(num_iterators=3,
                               inputs={
                                   B: i,
                                   N: j,
                                   M: k
                               },
                               outputs={
                                   B: i,
                                   M: k,
                                   N: j
                               })

    offset_mapping = tkw.IndexMapping(
        num_iterators=1,
        inputs={K2: i + OFFSET},
        outputs={K2: i},
    )
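    # offset_mapping turns the read of the 1-D rpe tensor into a gather: OFFSET
    # is set per element inside the kernel (tkw.set_symbol below), so output
    # element i of the mapped read fetches rpe[i + OFFSET].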

    use_t5_rpe = max_context_length is not None
    if use_t5_rpe:
        rpe_layout = tkl.MemoryLayout(shape=[
            max_context_length,
        ])
    assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing"

    @tkw.wave(constraints)
    def base_attention(
        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
        # TODO: if not use_t5_rpe, this argument should be DCE'd; at the moment
        # DCE on block arguments crashes.
        rpe: tkl.Memory[K2, GLOBAL_ADDRESS_SPACE, tkl.f32, rpe_layout],
        c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
        init_sum = tkl.Register[B, M, tkl.f32](0.0)
        init_max = tkl.Register[B, M, tkl.f32](-1e6)

        # This microkernel encodes the fact that if the reduction
        # dimension were tiled, then we would need to materialize a loop.
        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
        def repeat(
            partial_max: tkl.Register[B, M, tkl.f32],
            partial_sum: tkl.Register[B, M, tkl.f32],
            acc: tkl.Register[B, N, M, tkl.f32],
        ):
            imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
            inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
            x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])

            ####################################################################
            # T5 RPE
            ####################################################################
            # Fused T5 RPE adds the attention bias before softmax normalization.
            # When fusing into the flash-attention variant, adding the bias
            # locally, before the running max and the partial softmax, is
            # equivalent.
            if use_t5_rpe:
                ZERO = tkl.Register[M, K2, tkl.i64](0)
                MAX = tkl.Register[M, K2, tkl.i64](max_context_length)
                # 1. Indices i and j broadcast along K2, with a twist: here we
                #    use *static* information that is *implicitly* encoded in
                #    the *transformation*: under the distribution constraints
                #    specified we know that the shape [M] will eventually
                #    resolve to [1] and can thus be "cast + broadcast" to [K2].
                i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
                i = tkw.broadcast(i, target_shape=[M, K2])
                j = tkw.self_index(K2, tkl.i64, elements_per_thread=1)
                # j = tkw.broadcast(j, target_shape=[M, K2])

                # 2. Clip i - j to the proper bucket in [0, max_context_length]
                #    to represent the following:
                #      - if 0 < i - j < max_context_length, then x_j += rpe_reg
                #      - otherwise (i.e. i - j clamps to 0 or
                #        max_context_length), then x_j += 0
                # TODO: we may need scaling adjustments depending on how we
                # want to do bucketing; at the moment it is bucketing of size 1.
                idx = tkw.minimum(tkw.maximum(i - j, ZERO), MAX)
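                # For example, with max_context_length = 4:
                #   i - j = -2  ->  idx = 0
                #   i - j =  2  ->  idx = 2
                #   i - j =  7  ->  idx = 4  (clamped)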

                # 3. Read indirectly into the 1-D rpe array via offset_mapping.
                tkw.set_symbol(OFFSET, idx)  # offset will have shape [M, K2]
                rpe_reg = tkw.read(
                    rpe,
                    mapping=offset_mapping,
                    elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)

                # 4. Add the RPE bias to the pre-softmax scores.
                x_j = x_j + rpe_reg

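            # Online (flash-attention style) softmax update: m_j is the new
            # running max, e_delta_max rescales the previous partial sum and
            # accumulator, and d_j accumulates the rescaled exponentials.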
            m_j = tkw.max(x_j, partial_max, dim=K2)
            e_delta_max = tkw.exp2(partial_max - m_j)
            e_delta = tkw.exp2(x_j - m_j)
            e_init = partial_sum * e_delta_max
            d_j = tkw.sum(e_delta, e_init, dim=K2)
            imm_f16 = tkw.cast(e_delta, tkl.f16)
            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV)
            new_acc = acc * e_delta_max
            acc = tkw.mma(v_reg, imm_f16, new_acc)
            return m_j, d_j, acc

        # repeat represents the results of the loop
        res_max, res_sum, res_mm = repeat
        reciprocal_sum = tkw.reciprocal(res_sum)
        res = res_mm * reciprocal_sum
        tkw.write(res,
                  c,
                  mapping=mapping,
                  elements_per_thread=STORE_ELEMS_PER_THREAD)

    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD_QK:
            get_mfma_load_elems_per_thread(mfma_variant[0]),
        LOAD_ELEMS_PER_THREAD_PV:
            get_mfma_load_elems_per_thread(mfma_variant[1]),
        STORE_ELEMS_PER_THREAD:
            get_mfma_store_elems_per_thread(mfma_variant[1]),
        BLOCK_B: 1,
        BLOCK_M: 128,
        BLOCK_N: 64,
        BLOCK_K2: 64,
        B: shape.num_query_heads,
        M: shape.query_seq_len,
        N: shape.head_size_kv,
        K1: shape.head_size,
        K2: shape.kv_seq_len,
    }

    dynamic_symbols = []
    dynamic_symbols_map = {}
    if dynamic_dims:
        dynamic_symbols_map[M] = hyperparams[M]
        dynamic_symbols_map[N] = hyperparams[N]
        dynamic_symbols_map[B] = hyperparams[B]
        dynamic_symbols_map[K2] = hyperparams[K2]
        dynamic_symbols.append(M)
        dynamic_symbols.append(N)
        dynamic_symbols.append(B)
        dynamic_symbols.append(K2)
        del hyperparams[M]
        del hyperparams[N]
        del hyperparams[B]
        del hyperparams[K2]

    return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map
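For reference, a minimal calling sketch of the new template follows. It is illustrative only: the AttentionShape field values are placeholders (the fields shown are the ones the template reads), mfma_variant is passed as a (QK, PV) pair because the template indexes mfma_variant[0] and mfma_variant[1], and compiling/launching the returned base_attention kernel with the hyperparameters is not shown.

shape = AttentionShape(
    num_query_heads=8,
    query_seq_len=128,
    kv_seq_len=256,
    head_size=64,
    head_size_kv=64,
)
mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map = (
    get_vanilla_attention_kernel(shape,
                                 mfma_variant,
                                 dynamic_dims=False,
                                 max_context_length=10))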