Skip to content

Commit

Permalink
[TKW] Work on buffer ops (iree-org#492)
Browse files Browse the repository at this point in the history
* Split read/write ops indexing into thread-dependent and
thread-independent parts.
* Use symbolic vals for strides instead of extracting them from memref.
* For now only replace gather/scatter ops with buffer ops.

---------

Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: xintin <[email protected]>
  • Loading branch information
Hardcode84 authored and xintin committed Feb 14, 2025
1 parent e0f3544 commit 64df31f
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 90 deletions.
192 changes: 145 additions & 47 deletions iree/turbine/kernel/wave/codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@
write,
)

from ..utils import subs_idxc, find_index_bounds
from ..utils import safe_subs, subs_idxc, find_index_bounds

from ..._support.indexing import IndexingContext, IndexExpr, IndexSequence, index_symbol
from ...lang.wave_types import IndexMapping
from ...lang.global_symbols import *

from .emitter import (
WaveEmitter,
Expand Down Expand Up @@ -75,15 +76,40 @@ def _get_start_indices(
return start_indices


def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
    """
    Split an index expression into thread-independent and thread-dependent parts.

    Args:
        src: The index expression (or plain int) to split.

    Returns:
        A ``(thread_independent, thread_dependent)`` tuple whose sum equals
        ``src``.
    """
    subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0}
    subs_th = {THREAD_0: 0, THREAD_1: 0, THREAD_2: 0}
    # Replace all workgroup symbols with 0s to get the thread-dependent index.
    # All dynamic values will also be part of the thread-dependent index.
    thread_dependent_index = safe_subs(src, subs_wg)

    # Compute the thread-independent index as `src - thread_dependent_index`.
    # All thread symbols should cancel out in the result, but to be sure,
    # replace all thread symbols by 0 in the result.
    # We cannot just replace all thread symbols without the subtraction, as
    # any constant or dynamic values would end up in both expressions.
    thread_independent_index = sympy.simplify(
        safe_subs(src - thread_dependent_index, subs_th)
    )
    return thread_independent_index, thread_dependent_index


def _build_start_indices(
    emitter: WaveEmitter,
    src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
    dynamic_values: dict[IndexExpr, Any] | None = None,
) -> tuple[list[OpResult], list[OpResult], list[OpResult]]:
    """
    Lower start indices to IR values, additionally splitting each index into
    its thread-independent (workgroup) and thread-dependent parts.

    Args:
        emitter: The wave emitter providing symbol substitutions.
        src_indices: Mapping from dimension symbols to index expressions.
        dynamic_values: Optional mapping from dynamic-value symbols to their
            IR values. Defaults to an empty mapping.

    Returns:
        A tuple ``(indices, indices_wg, indices_th)`` of lowered index values;
        for each position the full index equals the sum of the workgroup and
        thread parts.
    """
    # `None` sentinel instead of a mutable `{}` default to avoid sharing one
    # dict instance across calls.
    if dynamic_values is None:
        dynamic_values = {}
    start_indices = _get_start_indices(src_indices)
    split_indices = [_split_index(i) for i in start_indices]
    subs = add_emitter_subs(emitter, dynamic_values)
    indices = [gen_sympy_index(subs, i) for i in start_indices]
    indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices]
    indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices]

    return indices, indices_wg, indices_th


def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
Expand Down Expand Up @@ -154,7 +180,7 @@ def _construct_gather_scatter_indices(
is_read: bool,
dynamic_vals: tuple[Any, ...],
is_contiguous: bool,
) -> tuple[OpResult, OpResult, OpResult]:
) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]:
# Apply symbolc_shape order to indices, e.g. if original mapping is
# {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
# be (iter(1), iter(0))
Expand Down Expand Up @@ -200,10 +226,10 @@ def extract0(src):
for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals)
}
if is_contiguous:
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
return start_indices, None, mask
return start_indices, start_indices_wg, start_indices_th, None, mask

start_indices = _get_start_indices(result_index)
start_indices_orig = _get_start_indices(index)
Expand Down Expand Up @@ -255,7 +281,7 @@ def extract0(src):
# In case we need dynamic `offsets_vec`, set all `start_indices` to 0
# and encode entire index info in `offsets_vec`.
result_index = {key: 0 for key in symbolc_shape}
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
Expand All @@ -278,18 +304,18 @@ def extract0(src):
_compute_offset(indices, strides),
)
else:
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
if offsets == list(range(elements_per_thread)):
return start_indices, None, mask
return start_indices, start_indices_wg, start_indices_th, None, mask

offsets = [IntegerAttr.get(IndexType.get(), off) for off in offsets]
offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

return start_indices, offsets_vec, mask
return start_indices, start_indices_wg, start_indices_th, offsets_vec, mask


def _get_max_buffer_size(elem_type: IrType) -> int:
Expand All @@ -301,7 +327,12 @@ def _get_max_buffer_size(elem_type: IrType) -> int:
return ((1 << 31) - 1) // (elem_type.width // 8)


def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
def _linearize_memref(
mem: Value,
offsets_wg: tuple[Value | int],
offsets_th: tuple[Value | int],
strides: tuple[Value],
) -> tuple[Value, Value]:
"""
Convert n-D memref into 1-D memref, suitable for buffer ops.
Expand All @@ -310,23 +341,27 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
no-op.
"""
memref_type = mem.type
rank = memref_type.rank
results = memref_d.extract_strided_metadata(mem)
base = results[0]
offset = results[1]
results = results[2:]
sizes = results[:rank]
strides = results[rank:]
offset = None
offset_th = None
overflow_flags = arith_d.IntegerOverflowFlags.nsw
for ind, size, stride in zip(offsets, sizes, strides):
if isinstance(ind, int):
ind = arith_d.constant(IndexType.get(), ind)

offset = arith_d.addi(
offset,
arith_d.muli(ind, stride, overflow_flags=overflow_flags),
overflow_flags=overflow_flags,
)
for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides):
if isinstance(ind_wg, int):
ind_wg = arith_d.constant(IndexType.get(), ind_wg)

if isinstance(ind_th, int):
ind_th = arith_d.constant(IndexType.get(), ind_th)

off_wg = arith_d.muli(ind_wg, stride, overflow_flags=overflow_flags)
if offset is None:
offset = off_wg
else:
offset = arith_d.addi(offset, off_wg, overflow_flags=overflow_flags)

off_th = arith_d.muli(ind_th, stride, overflow_flags=overflow_flags)
if offset_th is None:
offset_th = off_th
else:
offset_th = arith_d.addi(offset_th, off_th, overflow_flags=overflow_flags)

size_full = arith_d.constant(
IndexType.get(), _get_max_buffer_size(memref_type.element_type) - 1
Expand All @@ -342,30 +377,39 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
layout=Attribute.parse("strided<[1], offset: ?>"),
memory_space=memory_space,
)
return memref_d.reinterpret_cast(
resut_type,
base,
offsets=[offset],
sizes=[size_full],
strides=[],
static_offsets=[dyn_val],
static_sizes=[dyn_val],
static_strides=[1],
return (
memref_d.reinterpret_cast(
resut_type,
mem,
offsets=[offset],
sizes=[size_full],
strides=[],
static_offsets=[dyn_val],
static_sizes=[dyn_val],
static_strides=[1],
),
offset_th,
)


def _create_vec_read(
emitter: WaveEmitter,
symbolic_shape: tuple[IndexExpr, ...],
mem: Value,
vector_type: IrType,
start_indices: tuple[Value],
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
mask: Optional[Value],
offsets_vec: Optional[Value],
) -> Value:
if mask is None and offsets_vec is None:
return vector_d.load(vector_type, mem, start_indices)

# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

element_type = vector_type.element_type
if offsets_vec is None:
offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get())
Expand All @@ -377,10 +421,20 @@ def _create_vec_read(
zero = get_constant_attr(0, element_type)
zero = arith_d.constant(element_type, zero)

if emitter.params.get("use_buffer_load_ops", False):
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
buffer_ops_enabled = emitter.params.get("use_buffer_load_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
result = vector_d.splat(vector_type, zero)

data = _linearize_memref(mem, start_indices)
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
offset_th = vector_d.splat(offsets_vec.type, offset_th)
offsets_vec = arith_d.addi(offsets_vec, offset_th)
if mask is not None:
i32 = IntegerType.get_signless(32)
i32vec = VectorType.get([elements_per_thread], i32)
Expand Down Expand Up @@ -441,17 +495,22 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
input_shape = _get_symbolic_shape(memory)
elements_per_thread = cast_py_literal(emitter, elements_per_thread)
if get_custom(node).has_identity_mapping():
start_indices = _build_start_indices(emitter, index)
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index
)
mask = _build_mask(
emitter,
index,
elements_per_thread,
)
result = _create_vec_read(
emitter,
input_shape,
kb_src,
vector_type,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec=None,
Expand All @@ -460,7 +519,13 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
dyn_vals = tuple(
cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
)
start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
(
start_indices,
start_indices_wg,
start_indices_th,
offsets_vec,
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
symbolc_shape=input_shape,
index=index,
Expand All @@ -472,9 +537,12 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
)
result = _create_vec_read(
emitter,
input_shape,
kb_src,
vector_type,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec,
Expand All @@ -485,9 +553,12 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):

def _create_vec_write(
emitter: WaveEmitter,
symbolic_shape: tuple[IndexExpr, ...],
mem: Value,
value: Value,
start_indices: tuple[Value],
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
mask: Optional[Value],
offsets_vec: Optional[Value],
Expand All @@ -496,6 +567,9 @@ def _create_vec_write(
vector_d.store(value, mem, start_indices)
return

# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

vector_type = value.type
element_type = vector_type.element_type
if offsets_vec is None:
Expand All @@ -505,8 +579,18 @@ def _create_vec_write(
offsets_vec_type, DenseElementsAttr.get(vals, offsets_vec_type)
)

if emitter.params.get("use_buffer_store_ops", False):
data = _linearize_memref(mem, start_indices)
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
buffer_ops_enabled = emitter.params.get("use_buffer_store_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
offset_th = vector_d.splat(offsets_vec.type, offset_th)
offsets_vec = arith_d.addi(offsets_vec, offset_th)
if mask is not None:
i32 = IntegerType.get_signless(32)
i32vec = VectorType.get([elements_per_thread], i32)
Expand Down Expand Up @@ -564,13 +648,18 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
output_shape = _get_symbolic_shape(memory)
elements_per_thread = cast_py_literal(emitter, elements_per_thread)
if get_custom(node).has_identity_mapping():
start_indices = _build_start_indices(emitter, index)
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index
)
mask = _build_mask(emitter, index, elements_per_thread)
_create_vec_write(
emitter,
output_shape,
kb_dest,
insert_vector,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec=None,
Expand All @@ -583,7 +672,13 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
dyn_vals = tuple(
cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
)
start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
(
start_indices,
start_indices_wg,
start_indices_th,
offsets_vec,
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
symbolc_shape=output_shape,
index=index,
Expand All @@ -596,9 +691,12 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):

_create_vec_write(
emitter,
output_shape,
kb_dest,
insert_vector,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec,
Expand Down
Loading

0 comments on commit 64df31f

Please sign in to comment.