Skip to content

Commit

Permalink
check global mem
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Feb 14, 2025
1 parent 2007ad2 commit f0cacae
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions iree/turbine/kernel/wave/codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ def _create_vec_read(
if mask is None and offsets_vec is None:
return vector_d.load(vector_type, mem, start_indices)

is_gather = offsets_vec is not None
# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

element_type = vector_type.element_type
if offsets_vec is None:
offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get())
Expand All @@ -420,9 +422,9 @@ def _create_vec_read(
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
use_buffer_ops = emitter.params.get("use_buffer_load_ops", False)
buffer_ops_enabled = emitter.params.get("use_buffer_load_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if use_buffer_ops and has_int_strides and is_gather:
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
result = vector_d.splat(vector_type, zero)

strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
Expand Down Expand Up @@ -563,7 +565,9 @@ def _create_vec_write(
vector_d.store(value, mem, start_indices)
return

is_scatter = offsets_vec is not None
# Only use buffer ops if it's gather/scatter on global mem.
use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None

vector_type = value.type
element_type = vector_type.element_type
if offsets_vec is None:
Expand All @@ -576,9 +580,9 @@ def _create_vec_write(
strides = strides_from_symbolic_shape(
IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
)
use_buffer_ops = emitter.params.get("use_buffer_store_ops", False)
buffer_ops_enabled = emitter.params.get("use_buffer_store_ops", False)
has_int_strides = all(isinstance(s, int) for s in strides)
if use_buffer_ops and has_int_strides and is_scatter:
if buffer_ops_enabled and has_int_strides and use_buffer_ops:
strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
data, offset_th = _linearize_memref(
mem, start_indices_wg, start_indices_th, strides
Expand Down

0 comments on commit f0cacae

Please sign in to comment.