Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TKW] Work on buffer ops #492

Merged
merged 17 commits into from
Feb 14, 2025
192 changes: 145 additions & 47 deletions iree/turbine/kernel/wave/codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@
write,
)

from ..utils import subs_idxc, find_index_bounds
from ..utils import safe_subs, subs_idxc, find_index_bounds

from ..._support.indexing import IndexingContext, IndexExpr, IndexSequence, index_symbol
from ...lang.wave_types import IndexMapping
from ...lang.global_symbols import *

from .emitter import (
WaveEmitter,
Expand Down Expand Up @@ -75,15 +76,40 @@ def _get_start_indices(
return start_indices


def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
    """
    Split an index expression into its thread-independent and
    thread-dependent parts.

    Returns a tuple ``(thread_independent, thread_dependent)``.
    """
    wg_zeros = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0}
    th_zeros = {THREAD_0: 0, THREAD_1: 0, THREAD_2: 0}
    # Zeroing out every workgroup symbol yields the thread-dependent index.
    # Any dynamic values also land in this thread-dependent part.
    thread_dependent = safe_subs(src, wg_zeros)

    # The thread-independent part is `src - thread_dependent`. All thread
    # symbols should cancel out in that difference, but substitute 0 for
    # them anyway to be safe. We cannot simply zero the thread symbols in
    # `src` directly, because constants and dynamic values would then show
    # up in both parts.
    thread_independent = sympy.simplify(
        safe_subs(src - thread_dependent, th_zeros)
    )
    return thread_independent, thread_dependent


def _build_start_indices(
emitter: WaveEmitter,
src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
dynamic_values: dict[IndexExpr, Any] = {},
) -> list[OpResult]:
return [
gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
for i in _get_start_indices(src_indices)
]
) -> tuple[list[OpResult], list[OpResult], list[OpResult]]:
start_indices = _get_start_indices(src_indices)
split_indices = [_split_index(i) for i in start_indices]
subs = add_emitter_subs(emitter, dynamic_values)
indices = [gen_sympy_index(subs, i) for i in start_indices]
indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices]
indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices]

return indices, indices_wg, indices_th


def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
Expand Down Expand Up @@ -154,7 +180,7 @@ def _construct_gather_scatter_indices(
is_read: bool,
dynamic_vals: tuple[Any, ...],
is_contiguous: bool,
) -> tuple[OpResult, OpResult, OpResult]:
) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]:
# Apply symbolc_shape order to indices, e.g. if original mapping is
# {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
# be (iter(1), iter(0))
Expand Down Expand Up @@ -200,10 +226,10 @@ def extract0(src):
for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals)
}
if is_contiguous:
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
return start_indices, None, mask
return start_indices, start_indices_wg, start_indices_th, None, mask

start_indices = _get_start_indices(result_index)
start_indices_orig = _get_start_indices(index)
Expand Down Expand Up @@ -255,7 +281,7 @@ def extract0(src):
# In case we need dynamic `offsets_vec`, set all `start_indices` to 0
# and encode entire index info in `offsets_vec`.
result_index = {key: 0 for key in symbolc_shape}
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
Expand All @@ -278,18 +304,18 @@ def extract0(src):
_compute_offset(indices, strides),
)
else:
start_indices = _build_start_indices(
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, result_index, dynamic_vals_map_start
)
if offsets == list(range(elements_per_thread)):
return start_indices, None, mask
return start_indices, start_indices_wg, start_indices_th, None, mask

offsets = [IntegerAttr.get(IndexType.get(), off) for off in offsets]
offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)

return start_indices, offsets_vec, mask
return start_indices, start_indices_wg, start_indices_th, offsets_vec, mask


def _get_max_buffer_size(elem_type: IrType) -> int:
Expand All @@ -301,7 +327,12 @@ def _get_max_buffer_size(elem_type: IrType) -> int:
return ((1 << 31) - 1) // (elem_type.width // 8)


def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
def _linearize_memref(
mem: Value,
offsets_wg: tuple[Value | int],
offsets_th: tuple[Value | int],
strides: tuple[Value],
) -> tuple[Value, Value]:
"""
Convert n-D memref into 1-D memref, suitable for buffer ops.

Expand All @@ -310,23 +341,27 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
no-op.
"""
memref_type = mem.type
rank = memref_type.rank
results = memref_d.extract_strided_metadata(mem)
base = results[0]
offset = results[1]
results = results[2:]
sizes = results[:rank]
strides = results[rank:]
offset = None
offset_th = None
overflow_flags = arith_d.IntegerOverflowFlags.nsw
for ind, size, stride in zip(offsets, sizes, strides):
if isinstance(ind, int):
ind = arith_d.constant(IndexType.get(), ind)

offset = arith_d.addi(
offset,
arith_d.muli(ind, stride, overflow_flags=overflow_flags),
overflow_flags=overflow_flags,
)
for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides):
if isinstance(ind_wg, int):
ind_wg = arith_d.constant(IndexType.get(), ind_wg)

if isinstance(ind_th, int):
ind_th = arith_d.constant(IndexType.get(), ind_th)

off_wg = arith_d.muli(ind_wg, stride, overflow_flags=overflow_flags)
if offset is None:
offset = off_wg
else:
offset = arith_d.addi(offset, off_wg, overflow_flags=overflow_flags)

off_th = arith_d.muli(ind_th, stride, overflow_flags=overflow_flags)
if offset_th is None:
offset_th = off_th
else:
offset_th = arith_d.addi(offset_th, off_th, overflow_flags=overflow_flags)

size_full = arith_d.constant(
IndexType.get(), _get_max_buffer_size(memref_type.element_type) - 1
Expand All @@ -342,30 +377,39 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
layout=Attribute.parse("strided<[1], offset: ?>"),
memory_space=memory_space,
)
return memref_d.reinterpret_cast(
resut_type,
base,
offsets=[offset],
sizes=[size_full],
strides=[],
static_offsets=[dyn_val],
static_sizes=[dyn_val],
static_strides=[1],
return (
memref_d.reinterpret_cast(
resut_type,
mem,
offsets=[offset],
sizes=[size_full],
strides=[],
static_offsets=[dyn_val],
static_sizes=[dyn_val],
static_strides=[1],
),
offset_th,
)


def _create_vec_read(
emitter: WaveEmitter,
symbolic_shape: tuple[IndexExpr, ...],
mem: Value,
vector_type: IrType,
start_indices: tuple[Value],
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
mask: Optional[Value],
offsets_vec: Optional[Value],
) -> Value:
if mask is None and offsets_vec is None:
return vector_d.load(vector_type, mem, start_indices)

# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

element_type = vector_type.element_type
if offsets_vec is None:
offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get())
Expand All @@ -377,10 +421,20 @@ def _create_vec_read(
zero = get_constant_attr(0, element_type)
zero = arith_d.constant(element_type, zero)

if emitter.params.get("use_buffer_load_ops", False):
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
buffer_ops_enabled = emitter.params.get("use_buffer_load_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
result = vector_d.splat(vector_type, zero)

data = _linearize_memref(mem, start_indices)
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
offset_th = vector_d.splat(offsets_vec.type, offset_th)
offsets_vec = arith_d.addi(offsets_vec, offset_th)
if mask is not None:
i32 = IntegerType.get_signless(32)
i32vec = VectorType.get([elements_per_thread], i32)
Expand Down Expand Up @@ -441,17 +495,22 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
input_shape = _get_symbolic_shape(memory)
elements_per_thread = cast_py_literal(emitter, elements_per_thread)
if get_custom(node).has_identity_mapping():
start_indices = _build_start_indices(emitter, index)
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index
)
mask = _build_mask(
emitter,
index,
elements_per_thread,
)
result = _create_vec_read(
emitter,
input_shape,
kb_src,
vector_type,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec=None,
Expand All @@ -460,7 +519,13 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
dyn_vals = tuple(
cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
)
start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
(
start_indices,
start_indices_wg,
start_indices_th,
offsets_vec,
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
symbolc_shape=input_shape,
index=index,
Expand All @@ -472,9 +537,12 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
)
result = _create_vec_read(
emitter,
input_shape,
kb_src,
vector_type,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec,
Expand All @@ -485,9 +553,12 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):

def _create_vec_write(
emitter: WaveEmitter,
symbolic_shape: tuple[IndexExpr, ...],
mem: Value,
value: Value,
start_indices: tuple[Value],
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
mask: Optional[Value],
offsets_vec: Optional[Value],
Expand All @@ -496,6 +567,9 @@ def _create_vec_write(
vector_d.store(value, mem, start_indices)
return

# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

vector_type = value.type
element_type = vector_type.element_type
if offsets_vec is None:
Expand All @@ -505,8 +579,18 @@ def _create_vec_write(
offsets_vec_type, DenseElementsAttr.get(vals, offsets_vec_type)
)

if emitter.params.get("use_buffer_store_ops", False):
data = _linearize_memref(mem, start_indices)
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
buffer_ops_enabled = emitter.params.get("use_buffer_store_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
)
offset_th = vector_d.splat(offsets_vec.type, offset_th)
offsets_vec = arith_d.addi(offsets_vec, offset_th)
if mask is not None:
i32 = IntegerType.get_signless(32)
i32vec = VectorType.get([elements_per_thread], i32)
Expand Down Expand Up @@ -564,13 +648,18 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
output_shape = _get_symbolic_shape(memory)
elements_per_thread = cast_py_literal(emitter, elements_per_thread)
if get_custom(node).has_identity_mapping():
start_indices = _build_start_indices(emitter, index)
start_indices, start_indices_wg, start_indices_th = _build_start_indices(
emitter, index
)
mask = _build_mask(emitter, index, elements_per_thread)
_create_vec_write(
emitter,
output_shape,
kb_dest,
insert_vector,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec=None,
Expand All @@ -583,7 +672,13 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
dyn_vals = tuple(
cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
)
start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
(
start_indices,
start_indices_wg,
start_indices_th,
offsets_vec,
mask,
) = _construct_gather_scatter_indices(
emitter=emitter,
symbolc_shape=output_shape,
index=index,
Expand All @@ -596,9 +691,12 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):

_create_vec_write(
emitter,
output_shape,
kb_dest,
insert_vector,
start_indices,
start_indices_wg,
start_indices_th,
elements_per_thread,
mask,
offsets_vec,
Expand Down
Loading
Loading