Skip to content

Commit

Permalink
refac
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Feb 14, 2025
1 parent 2d9d3aa commit 2007ad2
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions iree/turbine/kernel/wave/codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,23 @@ def _get_start_indices(
return start_indices


def _split_index(src: IndexExpr) -> tuple[IndexExpr, IndexExpr]:
def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
"""
Split index expr into thread-dependent and thread-independent parts
"""
subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0}
subs_th = {THREAD_0: 0, THREAD_1: 0, THREAD_2: 0}
# Replace all wg symbols with 0s to get thread-dependent index.
# All dynamic values will also be part of thread-index.
thread_dependend_index = safe_subs(src, subs_wg)
return (
sympy.simplify(safe_subs(src - thread_dependend_index, subs_th)),
thread_dependend_index,

# Compute thread-independent index as `orig_index - thread_dependend_index`
# All thread symbols should cancel-out in the result, but to be sure
# replace all thread symbols by 0 in the result.
thread_indepdndent_index = sympy.simplify(
safe_subs(src - thread_dependend_index, subs_th)
)
return thread_indepdndent_index, thread_dependend_index


def _build_start_indices(
Expand Down Expand Up @@ -411,11 +420,9 @@ def _create_vec_read(
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
if (
emitter.params.get("use_buffer_load_ops", False)
and all(isinstance(s, int) for s in strides)
and is_gather
):
use_buffer_ops = emitter.params.get("use_buffer_load_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if use_buffer_ops and has_int_strides and is_gather:
result = vector_d.splat(vector_type, zero)

strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
Expand Down Expand Up @@ -569,11 +576,9 @@ def _create_vec_write(
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
if (
emitter.params.get("use_buffer_store_ops", False)
and all(isinstance(s, int) for s in strides)
and is_scatter
):
use_buffer_ops = emitter.params.get("use_buffer_store_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if use_buffer_ops and has_int_strides and is_scatter:
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
Expand Down

0 comments on commit 2007ad2

Please sign in to comment.