make rpe compile, still produces garbage
ftynse committed Feb 7, 2025
1 parent f4ec891 commit 3015755
Showing 5 changed files with 193 additions and 136 deletions.
9 changes: 4 additions & 5 deletions iree/turbine/kernel/compiler/vector_codegen.py
@@ -854,18 +854,17 @@ def cast_kernel_buffer(
return value, MemRefType(ir_type), py_type


def cast_vector(emitter: ThreadEmitter,
value,
*,
element_type: Optional[IrType] = None):
def cast_vector(
emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None
):
proxy_value = cast_py_value(emitter, value)

# Cast scalar types correctly first.
if element_type and not ShapedType.isinstance(proxy_value.ir_value.type):
# Implicit scalar type promotion.
proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type)

print(f"proxy_value {proxy_value}")
# print(f"proxy_value {proxy_value}")
value = proxy_value.ir_value

# After scalar promotion, promote to vector.
101 changes: 67 additions & 34 deletions iree/turbine/kernel/wave/codegen.py
@@ -721,13 +721,15 @@ def _build_mask(

def _construct_gather_scatter_indices(
emitter: WaveEmitter,
# TODO: fix the typo in `symbolc_shape` (should be `symbolic_shape`).
symbolc_shape: tuple[IndexExpr],
index: tuple[IndexExpr],
mapping: IndexMapping,
elements_per_thread: int,
is_read: bool,
dynamic_vals: tuple[Any, ...],
is_contiguous: bool,
vector_shaped_symbols=frozenset(),
) -> tuple[OpResult, OpResult, OpResult]:
# Apply symbolc_shape order to indices, e.g. if original mapping is
# {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
@@ -795,34 +797,46 @@ def extract0(src):
offsets = []
strides = strides_from_symbolic_shape(idxc, symbolc_shape, allow_mixed_shapes=True)
start_indices_offset = _compute_offset(start_indices, strides)
for i in range(elements_per_thread):
# Update fastest dim, i.e. in case of identity mapping it will
# be equivalent to just vector.load
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
subs[fastest_dim] = (subs[fastest_dim][0], start_indices_orig[fastest_dim] + i)
indices = [i.subs(subs) for i in index_mapping]

# First, we build indices as if resulting gather/scatter `start_indices`
# are 0 as mapping expression may depend on absolute value of index
# (e.g. `index % 32`). Then we adjust for the non-0 `start_indices` by
# subtracting computed previously linear `start_indices_offset`. For
# simple cases like transpose, the resulting expression should fold into
# simple constant while more complex expressions may requires actual
# arith ops on dynamic values.
offset = _compute_offset(indices, strides) - start_indices_offset
offset = subs_idxc(offset)

if offset.is_number:
# If resulted offset sympy expr is convertible to int constant it
# will be directly encoded into `arith.constant`.
# For non-constant expressions, we will generate a real sequence of
# arith ops and then `vector.insertelement` them into offsets vec.
offset = int(offset)
else:
need_dynamic_offsets = True
break
# TODO: we don't necessarily care whether the symbols are vector shaped, but
# rather whether they are indexed by the fastest-varying dimension. Note that
# we may want to "expand" the symbol to per-element copies and trigger
# `need_dynamic_offsets` below.
if len(start_indices_offset.free_symbols.intersection(vector_shaped_symbols)) != 0:
need_dynamic_offsets = True

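For reference, the new guard boils down to a plain sympy free-symbol check; here is a minimal standalone sketch, where the symbol names and the offset expression are made up for illustration:

import sympy

M, OFFSET = sympy.symbols("M OFFSET", integer=True)
vector_shaped_symbols = {OFFSET}        # symbols known to carry per-lane values
start_indices_offset = 16 * M + OFFSET  # hypothetical linearized start offset

# A single scalar base offset cannot represent an expression that mentions a
# vector-shaped symbol, so the per-element (dynamic) offset path is required.
need_dynamic_offsets = bool(
    start_indices_offset.free_symbols & vector_shaped_symbols
)
assert need_dynamic_offsets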
if not need_dynamic_offsets:
for i in range(elements_per_thread):
# Update the fastest dim, i.e. in the case of an identity mapping this will
# be equivalent to just vector.load.
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
subs[fastest_dim] = (
subs[fastest_dim][0],
start_indices_orig[fastest_dim] + i,
)
indices = [i.subs(subs) for i in index_mapping]

# First, we build indices as if the resulting gather/scatter `start_indices`
# were 0, since the mapping expression may depend on the absolute value of
# the index (e.g. `index % 32`). Then we adjust for the non-0 `start_indices`
# by subtracting the previously computed linear `start_indices_offset`. For
# simple cases like transpose, the resulting expression should fold into a
# simple constant, while more complex expressions may require actual
# arith ops on dynamic values.
offset = _compute_offset(indices, strides) - start_indices_offset
offset = subs_idxc(offset)

if offset.is_number:
# If the resulting offset sympy expr is convertible to an int constant, it
# will be directly encoded into `arith.constant`.
# For non-constant expressions, we will generate a real sequence of
# arith ops and then `vector.insertelement` them into offsets vec.
offset = int(offset)
else:
need_dynamic_offsets = True
break

offsets.append(offset)
offsets.append(offset)

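As a sanity check of the constant-folding path above, a small sympy analogue; the `linear_offset` helper and the shapes are stand-ins for illustration, not the real `_compute_offset`:

import sympy

def linear_offset(indices, strides):
    # Analogue of _compute_offset: dot product of per-dim indices and strides.
    return sum(idx * s for idx, s in zip(indices, strides))

n, m = sympy.symbols("n m", integer=True)
strides = (64, 1)                # hypothetical row-major strides for shape (N, 64)
start_indices = (n, m)           # identity-like mapping of the iteration space
base = linear_offset(start_indices, strides)

offsets = []
for i in range(4):               # elements_per_thread = 4
    indices = (n, m + i)         # bump the fastest-varying dimension
    offsets.append(sympy.simplify(linear_offset(indices, strides) - base))

# For this identity-like mapping the per-element offsets fold to the constants
# 0..3, i.e. the gather degenerates into a plain contiguous vector.load.
assert offsets == [0, 1, 2, 3]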
offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())
if need_dynamic_offsets:
@@ -834,13 +848,22 @@ def extract0(src):
)
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
# Last item in `subs` corresponds to last item in `start_indices_orig`
# which is fastest changing dim.
# Replacing last element with `idxc.iota(elements_per_thread)` will
# generate vectorized index code, each element in it corresponding to
# individual vector element index.
# which is the fastest changing dim. Replacing the last element with
# `idxc.iota(elements_per_thread)` will generate vectorized index code,
# each element of it corresponding to an individual vector element index.
#
# TODO: a vector-shaped symbol means we can't just use iota here; instead
# we should take the value of the symbol itself and somehow get the
# different values of OFFSET (or other vector-shaped symbols) into
# `get_sympy_index` below.
#
# We also need to take care of the case with several symbols, especially a
# mix of constant and non-constant symbols.
#
# What happens when there are no symbols?
subs[-1] = (
subs[-1][0],
start_indices_orig[-1] + idxc.iota(elements_per_thread),
start_indices_orig[-1], # + idxc.iota(elements_per_thread),
)
dynamic_vals_map = {
sym: val
@@ -905,6 +928,15 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
dyn_vals = tuple(
cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
)

# TODO: this could be sunk further down, actually.
vector_shaped_symbols = set(
sym
for sym, value in emitter.dynamic_dims.items()
if isinstance(value.type, ShapedType)
and math.prod(ShapedType(value.type).shape) != 1
)

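The filter above only keeps dynamic dims whose IR values are genuinely vector shaped; a minimal sketch of the same logic with stand-in types (the real code uses MLIR's ShapedType, the classes below are hypothetical):

import math
from dataclasses import dataclass

@dataclass
class FakeShapedType:
    shape: tuple[int, ...]

@dataclass
class FakeValue:
    type: object

dynamic_dims = {
    "ZERO": FakeValue(FakeShapedType((1,))),      # degenerate: product of dims == 1
    "OFFSET": FakeValue(FakeShapedType((1, 4))),  # genuinely vector shaped
}

vector_shaped_symbols = {
    sym
    for sym, value in dynamic_dims.items()
    if isinstance(value.type, FakeShapedType) and math.prod(value.type.shape) != 1
}
assert vector_shaped_symbols == {"OFFSET"}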
start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
emitter=emitter,
symbolc_shape=input_shape,
@@ -914,6 +946,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
is_read=True,
dynamic_vals=dyn_vals,
is_contiguous=get_custom(node).is_contiguous_vec(),
vector_shaped_symbols=vector_shaped_symbols,
)

zero = get_constant_attr(0, element_type)
@@ -1326,8 +1359,7 @@ def handle_and_op(lhs: Value, rhs: Value) -> OpResult:
if _is_integer_like_type(element_type):
result = arith_d.andi(lhs, rhs)
else:
raise ValidationError(
f"Found unhandled operand type for le: {element_type}")
raise ValidationError(f"Found unhandled operand type for and: {element_type}")
return result


@@ -1362,6 +1394,7 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
)
return result


###############################################################################
# Unary math Ops
###############################################################################
@@ -1750,7 +1783,7 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type)

if src_vector_type == dst_vector_type:
emitter.bind_node_proxy(node, vector_src)
emitter.bind_node_proxy(node, IRProxyValue(vector_src))
return

is_src_float = _is_float_type(src_elem_type)
10 changes: 10 additions & 0 deletions iree/turbine/kernel/wave/utils.py
@@ -1420,6 +1420,7 @@ def check_is_mapping_contiguous(
return True

# TODO: Better dyn vals analysis.
# TODO: we also need to check whether there are additional symbols in the mapping.
if mapping.num_dynamic_vals != 0:
return False

@@ -1437,6 +1438,15 @@ def check_is_mapping_contiguous(
index_mapping = tuple(subs_idxc(i) for i in index_mapping)
iters = mapping.iters

# TODO: at this point, if the symbols present in the index are not known to
# be scalars or themselves contiguous, we shouldn't report the read as
# contiguous.
#
# TODO: we should thread through the fact that _some_ symbols have a vector shape.
for expr in index_mapping:
if len(expr.free_symbols - mapping.iters.keys()) != 0:
return False

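A standalone sympy sketch of the new early exit, with made-up iterator and symbol names: any free symbol that is not a mapping iterator makes the access potentially non-contiguous.

import sympy

i = sympy.Symbol("i", integer=True)            # mapping iterator
OFFSET = sympy.Symbol("OFFSET", integer=True)  # extra, non-iterator symbol

iterator_symbols = {i}
index_mapping = (i + OFFSET,)                  # e.g. an indirect / RPE-style access

has_extra_symbols = any(
    expr.free_symbols - iterator_symbols for expr in index_mapping
)
assert has_extra_symbols  # -> conservatively treat the mapping as non-contiguous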
subs = [(sym, sym + int(i == len(iters) - 1)) for i, sym in enumerate(iters)]
diff = [
approximate_difference(
73 changes: 30 additions & 43 deletions playground/attention_with_rpe_template.py
@@ -18,9 +18,12 @@
from iree.turbine.kernel.wave.templates.attention_common import AttentionShape


def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
dynamic_dims: bool,
max_context_length: Optional[int]):
def get_vanilla_attention_kernel(
shape: AttentionShape,
mfma_variant: MMAType,
dynamic_dims: bool,
max_context_length: Optional[int],
):
# RPE
ZERO = tkl.sym.ZERO
OFFSET = tkl.sym.OFFSET
@@ -44,9 +47,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0)
]
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
@@ -65,11 +66,7 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
threads_per_wave=64,
waves_per_block=(4, 1, 1),
mma_type=mfma_variant[1],
vector_shapes={
B: 0,
M: Mvec,
N: Nvec
},
vector_shapes={B: 0, M: Mvec, N: Nvec},
)
]

@@ -79,17 +76,9 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,
i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.iterator(2)
mapping = tkw.IndexMapping(num_iterators=3,
inputs={
B: i,
N: j,
M: k
},
outputs={
B: i,
M: k,
N: j
})
mapping = tkw.IndexMapping(
num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
)

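In plain index terms, the mapping above simply transposes the last two dimensions between the input layout (B, N, M) and the output layout (B, M, N); a tiny sketch of the permutation it encodes:

def apply_mapping(b: int, n: int, m: int) -> tuple[tuple[int, int, int], tuple[int, int, int]]:
    # inputs={B: i, N: j, M: k} versus outputs={B: i, M: k, N: j}.
    input_index = (b, n, m)
    output_index = (b, m, n)
    return input_index, output_index

assert apply_mapping(0, 3, 5) == ((0, 3, 5), (0, 5, 3))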
offset_mapping = tkw.IndexMapping(
num_iterators=1,
@@ -99,9 +88,11 @@ def get_vanilla_attention_kernel(shape: AttentionShape, mfma_variant: MMAType,

use_t5_rpe = max_context_length is not None
if use_t5_rpe:
rpe_layout = tkl.MemoryLayout(shape=[
max_context_length,
])
rpe_layout = tkl.MemoryLayout(
shape=[
max_context_length,
]
)
assert use_t5_rpe, "use_t5_rpe needed until rpe arg can DCE without crashing"

@tkw.wave(constraints)
@@ -138,16 +129,18 @@ def repeat(
# When fusing into the FA variant, adding locally before the max and
# the partial softmax should be equivalent.
if use_t5_rpe:
ZERO = tkl.Register[M, K2, tkl.i64](0)
MAX = tkl.Register[M, K2, tkl.i64](max_context_length)
ZERO = tkl.Register[B, M, K2, tkl.i64](0)
MAX = tkl.Register[B, M, K2, tkl.i64](max_context_length)
# 1. Indices i and j broadcasted along K2 with a twist:
# here we use *static* information that is *implicitly* encoded
# in the *transformation*: under the distribution constraints
# specified we know that the shape [M] will eventually resolve
# to [1] and can thus be "cast + broadcast" to [K2].
i = tkw.self_index(M, tkl.i64, elements_per_thread=1)
i = tkw.broadcast(i, target_shape=[M, K2])
j = tkw.self_index(K2, tkl.i64, elements_per_thread=1)
i = tkw.broadcast(i, target_shape=[B, M, K2])
j = tkw.self_index(
K2, tkl.i64, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK
)

# 2. Clip i - j to the proper bucket in [0, max_context_length]
# to represent the following:
Expand All @@ -163,18 +156,18 @@ def repeat(
# idx = tkw.minimum(idx, MAX)

# select variant.
idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX),
i - j, ZERO)
idx = tkw.select(tkw.and_op(i - j >= ZERO, i - j <= MAX), i - j, ZERO)

# 3. Read indirectly into the 1-D rpe array via offset_mapping.
tkw.set_symbol(OFFSET, idx)  # offset will have shape [B, M, K2]
rpe_reg = tkw.read(
rpe,
mapping=offset_mapping,
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
elements_per_thread=LOAD_ELEMS_PER_THREAD_QK,
)

# 4. Tadaaaa.
x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.i64)
x_j = x_j + rpe_reg + tkw.cast(ZERO * idx, tkl.f32)

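For reference, the clipping in steps 2-4 amounts to a T5-style relative-position lookup; a scalar Python sketch of the same selection logic, with a made-up table length and contents:

max_context_length = 8                                   # hypothetical bucket count
rpe = [float(k) for k in range(max_context_length + 1)]  # stand-in 1-D RPE table

def rpe_bias(i: int, j: int) -> float:
    # Keep the relative distance i - j only when it lies in
    # [0, max_context_length]; everything else falls back to bucket 0,
    # mirroring the select() above.
    d = i - j
    idx = d if 0 <= d <= max_context_length else 0
    return rpe[idx]

# Query position 5 attending to key positions 0..5: distances 5, 4, ..., 0.
assert [rpe_bias(5, j) for j in range(6)] == [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]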
m_j = tkw.max(x_j, partial_max, dim=K2)
e_delta_max = tkw.exp2(partial_max - m_j)
@@ -191,19 +184,13 @@ def repeat(
res_max, res_sum, res_mm = repeat
reciprocal_sum = tkw.reciprocal(res_sum)
res = res_mm * reciprocal_sum
tkw.write(res,
c,
mapping=mapping,
elements_per_thread=STORE_ELEMS_PER_THREAD)
tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD)

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD_QK:
get_mfma_load_elems_per_thread(mfma_variant[0]),
LOAD_ELEMS_PER_THREAD_PV:
get_mfma_load_elems_per_thread(mfma_variant[1]),
STORE_ELEMS_PER_THREAD:
get_mfma_store_elems_per_thread(mfma_variant[1]),
LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]),
LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]),
STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]),
BLOCK_B: 1,
BLOCK_M: 128,
BLOCK_N: 64,