fix offset calculation
Signed-off-by: Ivan Butygin <[email protected]>
Hardcode84 committed Feb 12, 2025
1 parent 98a6fcc commit eac4567
Showing 1 changed file with 26 additions and 24 deletions.
iree/turbine/kernel/wave/codegen/read_write.py (26 additions, 24 deletions)
@@ -90,13 +90,15 @@ def _build_start_indices(
     emitter: WaveEmitter,
     src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
     dynamic_values: dict[IndexExpr, Any] = {},
-) -> tuple[list[OpResult], list[OpResult]]:
-    split_indices = [_split_index(i) for i in _get_start_indices(src_indices)]
+) -> tuple[list[OpResult], list[OpResult], list[OpResult]]:
+    start_indices = _get_start_indices(src_indices)
+    split_indices = [_split_index(i) for i in start_indices]
     subs = add_emitter_subs(emitter, dynamic_values)
+    indices = [gen_sympy_index(subs, i) for i in start_indices]
     indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices]
     indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices]

-    return indices_wg, indices_th
+    return indices, indices_wg, indices_th


 def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
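
The hunk above is the heart of the change: `_build_start_indices` now also materializes the combined index list (`indices`) directly from the full start-index expressions, instead of leaving callers to re-add separately generated workgroup and thread parts. A minimal sympy sketch, illustrative only (it does not reproduce `_split_index`), of why split-then-add can disagree with the full expression when the index is not additively separable:

    # Illustrative only: floor() does not distribute over +, so an index
    # rebuilt by adding separately generated workgroup/thread halves can
    # differ from one generated from the full expression.
    import sympy

    WG, TH = sympy.symbols("WG TH", integer=True, nonnegative=True)
    combined = sympy.floor((WG * 64 + TH) / 128)  # full index expression
    wg_half = sympy.floor(WG * 64 / 128)          # workgroup-only half
    th_half = sympy.floor(TH / 128)               # thread-only half

    vals = {WG: 1, TH: 64}
    assert combined.subs(vals) == 1                      # floor(128/128) = 1
    assert wg_half.subs(vals) + th_half.subs(vals) == 0  # floor(1/2) + floor(1/2) = 0
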
@@ -167,7 +169,7 @@ def _construct_gather_scatter_indices(
     is_read: bool,
     dynamic_vals: tuple[Any, ...],
     is_contiguous: bool,
-) -> tuple[list[OpResult], list[OpResult], OpResult, OpResult]:
+) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]:
     # Apply symbolc_shape order to indices, e.g. if original mapping is
     # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
     # be (iter(1), iter(0))
@@ -213,10 +215,10 @@ def extract0(src):
         for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals)
     }
     if is_contiguous:
-        start_indices_wg, start_indices_th = _build_start_indices(
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
             emitter, result_index, dynamic_vals_map_start
         )
-        return start_indices_wg, start_indices_th, None, mask
+        return start_indices, start_indices_wg, start_indices_th, None, mask

     start_indices = _get_start_indices(result_index)
     start_indices_orig = _get_start_indices(index)
@@ -268,7 +270,7 @@ def extract0(src):
     # In case we need dynamic `offsets_vec`, set all `start_indices` to 0
     # and encode entire index info in `offsets_vec`.
     result_index = {key: 0 for key in symbolc_shape}
-    start_indices_wg, start_indices_th = _build_start_indices(
+    start_indices, start_indices_wg, start_indices_th = _build_start_indices(
         emitter, result_index, dynamic_vals_map_start
     )
     subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
@@ -291,18 +293,18 @@ def extract0(src):
             _compute_offset(indices, strides),
         )
     else:
-        start_indices_wg, start_indices_th = _build_start_indices(
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
             emitter, result_index, dynamic_vals_map_start
         )
         if offsets == list(range(elements_per_thread)):
-            return start_indices_wg, start_indices_th, None, mask
+            return start_indices, start_indices_wg, start_indices_th, None, mask

         offsets = [IntegerAttr.get(IndexType.get(), off) for off in offsets]
         offsets_vec = arith_d.ConstantOp(
             offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
         )

-    return start_indices_wg, start_indices_th, offsets_vec, mask
+    return start_indices, start_indices_wg, start_indices_th, offsets_vec, mask


 def _get_max_buffer_size(elem_type: IrType) -> int:
@@ -384,16 +386,14 @@ def _create_vec_read(
     symbolic_shape: tuple[IndexExpr, ...],
     mem: Value,
     vector_type: IrType,
+    start_indices: tuple[Value],
     start_indices_wg: tuple[Value],
     start_indices_th: tuple[Value],
     elements_per_thread: int,
     mask: Optional[Value],
     offsets_vec: Optional[Value],
 ) -> Value:
     if mask is None and offsets_vec is None:
-        start_indices = [
-            arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th)
-        ]
         return vector_d.load(vector_type, mem, start_indices)

     is_gather = offsets_vec is not None
@@ -457,9 +457,6 @@ def _create_vec_read(
         )
         mask = _constant_mask(mask_vec_type)

-    start_indices = [
-        arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th)
-    ]
     return vector_d.gather(
         vector_type, mem, start_indices, offsets_vec, mask, passthru
     )
@@ -487,7 +484,9 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     input_shape = _get_symbolic_shape(memory)
     elements_per_thread = cast_py_literal(emitter, elements_per_thread)
     if get_custom(node).has_identity_mapping():
-        start_indices_wg, start_indices_th = _build_start_indices(emitter, index)
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
+            emitter, index
+        )
         mask = _build_mask(
             emitter,
             index,
@@ -498,6 +497,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
             input_shape,
             kb_src,
             vector_type,
+            start_indices,
             start_indices_wg,
             start_indices_th,
             elements_per_thread,
@@ -509,6 +509,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
             cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
         )
         (
+            start_indices,
             start_indices_wg,
             start_indices_th,
             offsets_vec,
@@ -528,6 +529,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
             input_shape,
             kb_src,
             vector_type,
+            start_indices,
             start_indices_wg,
             start_indices_th,
             elements_per_thread,
Expand All @@ -543,16 +545,14 @@ def _create_vec_write(
symbolic_shape: tuple[IndexExpr, ...],
mem: Value,
value: Value,
start_indices: tuple[Value],
start_indices_wg: tuple[Value],
start_indices_th: tuple[Value],
elements_per_thread: int,
mask: Optional[Value],
offsets_vec: Optional[Value],
):
if mask is None and offsets_vec is None:
start_indices = [
arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th)
]
vector_d.store(value, mem, start_indices)
return

@@ -606,9 +606,6 @@ def _create_vec_write(
         )
         mask = _constant_mask(mask_vec_type)

-    start_indices = [
-        arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th)
-    ]
     vector_d.scatter(mem, start_indices, offsets_vec, mask, value)


@@ -640,13 +637,16 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
     output_shape = _get_symbolic_shape(memory)
     elements_per_thread = cast_py_literal(emitter, elements_per_thread)
     if get_custom(node).has_identity_mapping():
-        start_indices_wg, start_indices_th = _build_start_indices(emitter, index)
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
+            emitter, index
+        )
         mask = _build_mask(emitter, index, elements_per_thread)
         _create_vec_write(
             emitter,
             output_shape,
             kb_dest,
             insert_vector,
+            start_indices,
             start_indices_wg,
             start_indices_th,
             elements_per_thread,
@@ -662,6 +662,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
             cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
         )
         (
+            start_indices,
             start_indices_wg,
             start_indices_th,
             offsets_vec,
@@ -682,6 +683,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
             output_shape,
             kb_dest,
             insert_vector,
+            start_indices,
             start_indices_wg,
             start_indices_th,
             elements_per_thread,
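
Downstream, the fast (unmasked, non-gather) path now consumes the precomputed combined indices as-is. A schematic stand-in, with hypothetical names and plain Python in place of the real `vector_d`/`arith_d` MLIR ops:

    # Schematic stand-in for the post-commit fast path of _create_vec_read /
    # _create_vec_write: combined indices arrive precomputed, so nothing is
    # re-added at the access site.
    def vec_access_fastpath(mem, start_indices, mask=None, offsets_vec=None):
        if mask is None and offsets_vec is None:
            # Pre-commit, this branch rebuilt the indices first:
            #   start_indices = [wg + th for wg, th in zip(wg_parts, th_parts)]
            return mem[tuple(start_indices)]
        raise NotImplementedError("masked/gather paths elided in this sketch")

    grid = {(0, 0): 1.0, (2, 3): 4.0}
    print(vec_access_fastpath(grid, [2, 3]))  # 4.0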
