diff --git a/iree/turbine/kernel/wave/codegen/read_write.py b/iree/turbine/kernel/wave/codegen/read_write.py index 7a0f1fc2b..d552817e4 100644 --- a/iree/turbine/kernel/wave/codegen/read_write.py +++ b/iree/turbine/kernel/wave/codegen/read_write.py @@ -405,7 +405,9 @@ def _create_vec_read( if mask is None and offsets_vec is None: return vector_d.load(vector_type, mem, start_indices) - is_gather = offsets_vec is not None + # Only use buffer ops if it's gather/scatter on global mem. + use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None + element_type = vector_type.element_type if offsets_vec is None: offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get()) @@ -420,9 +422,9 @@ def _create_vec_read( strides = strides_from_symbolic_shape( IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True ) - use_buffer_ops = emitter.params.get("use_buffer_load_ops", False) + buffer_ops_enabled = emitter.params.get("use_buffer_load_ops", False) has_int_strides = all(isinstance(s, int) for s in strides) - if use_buffer_ops and has_int_strides and is_gather: + if buffer_ops_enabled and has_int_strides and use_buffer_ops: result = vector_d.splat(vector_type, zero) strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides] @@ -563,7 +565,9 @@ def _create_vec_write( vector_d.store(value, mem, start_indices) return - is_scatter = offsets_vec is not None + # Only use buffer ops if it's gather/scatter on global mem. + use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None + vector_type = value.type element_type = vector_type.element_type if offsets_vec is None: @@ -576,9 +580,9 @@ def _create_vec_write( strides = strides_from_symbolic_shape( IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True ) - use_buffer_ops = emitter.params.get("use_buffer_store_ops", False) + buffer_ops_enabled = emitter.params.get("use_buffer_store_ops", False) has_int_strides = all(isinstance(s, int) for s in strides) - if use_buffer_ops and has_int_strides and is_scatter: + if buffer_ops_enabled and has_int_strides and use_buffer_ops: strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides] data, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides