From eac4567f4253dbaa2d9fe4899e33696721cb0f51 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 12 Feb 2025 02:31:03 +0100 Subject: [PATCH] fix offset calculation Signed-off-by: Ivan Butygin --- .../turbine/kernel/wave/codegen/read_write.py | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/iree/turbine/kernel/wave/codegen/read_write.py b/iree/turbine/kernel/wave/codegen/read_write.py index e4d56fe74..397e40b05 100644 --- a/iree/turbine/kernel/wave/codegen/read_write.py +++ b/iree/turbine/kernel/wave/codegen/read_write.py @@ -90,13 +90,15 @@ def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr], dynamic_values: dict[IndexExpr, Any] = {}, -) -> tuple[list[OpResult], list[OpResult]]: - split_indices = [_split_index(i) for i in _get_start_indices(src_indices)] +) -> tuple[list[OpResult], list[OpResult], list[OpResult]]: + start_indices = _get_start_indices(src_indices) + split_indices = [_split_index(i) for i in start_indices] subs = add_emitter_subs(emitter, dynamic_values) + indices = [gen_sympy_index(subs, i) for i in start_indices] indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices] indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices] - return indices_wg, indices_th + return indices, indices_wg, indices_th def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]): @@ -167,7 +169,7 @@ def _construct_gather_scatter_indices( is_read: bool, dynamic_vals: tuple[Any, ...], is_contiguous: bool, -) -> tuple[list[OpResult], list[OpResult], OpResult, OpResult]: +) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]: # Apply symbolc_shape order to indices, e.g. if original mapping is # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will # be (iter(1), iter(0)) @@ -213,10 +215,10 @@ def extract0(src): for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals) } if is_contiguous: - start_indices_wg, start_indices_th = _build_start_indices( + start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, result_index, dynamic_vals_map_start ) - return start_indices_wg, start_indices_th, None, mask + return start_indices, start_indices_wg, start_indices_th, None, mask start_indices = _get_start_indices(result_index) start_indices_orig = _get_start_indices(index) @@ -268,7 +270,7 @@ def extract0(src): # In case we need dynamic `offsets_vec`, set all `start_indices` to 0 # and encode entire index info in `offsets_vec`. result_index = {key: 0 for key in symbolc_shape} - start_indices_wg, start_indices_th = _build_start_indices( + start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, result_index, dynamic_vals_map_start ) subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)] @@ -291,18 +293,18 @@ def extract0(src): _compute_offset(indices, strides), ) else: - start_indices_wg, start_indices_th = _build_start_indices( + start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, result_index, dynamic_vals_map_start ) if offsets == list(range(elements_per_thread)): - return start_indices_wg, start_indices_th, None, mask + return start_indices, start_indices_wg, start_indices_th, None, mask offsets = [IntegerAttr.get(IndexType.get(), off) for off in offsets] offsets_vec = arith_d.ConstantOp( offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type) ) - return start_indices_wg, start_indices_th, offsets_vec, mask + return start_indices, start_indices_wg, start_indices_th, offsets_vec, mask def _get_max_buffer_size(elem_type: IrType) -> int: @@ -384,6 +386,7 @@ def _create_vec_read( symbolic_shape: tuple[IndexExpr, ...], mem: Value, vector_type: IrType, + start_indices: tuple[Value], start_indices_wg: tuple[Value], start_indices_th: tuple[Value], elements_per_thread: int, @@ -391,9 +394,6 @@ def _create_vec_read( offsets_vec: Optional[Value], ) -> Value: if mask is None and offsets_vec is None: - start_indices = [ - arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th) - ] return vector_d.load(vector_type, mem, start_indices) is_gather = offsets_vec is not None @@ -457,9 +457,6 @@ def _create_vec_read( ) mask = _constant_mask(mask_vec_type) - start_indices = [ - arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th) - ] return vector_d.gather( vector_type, mem, start_indices, offsets_vec, mask, passthru ) @@ -487,7 +484,9 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): input_shape = _get_symbolic_shape(memory) elements_per_thread = cast_py_literal(emitter, elements_per_thread) if get_custom(node).has_identity_mapping(): - start_indices_wg, start_indices_th = _build_start_indices(emitter, index) + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index + ) mask = _build_mask( emitter, index, @@ -498,6 +497,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): input_shape, kb_src, vector_type, + start_indices, start_indices_wg, start_indices_th, elements_per_thread, @@ -509,6 +509,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals ) ( + start_indices, start_indices_wg, start_indices_th, offsets_vec, @@ -528,6 +529,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): input_shape, kb_src, vector_type, + start_indices, start_indices_wg, start_indices_th, elements_per_thread, @@ -543,6 +545,7 @@ def _create_vec_write( symbolic_shape: tuple[IndexExpr, ...], mem: Value, value: Value, + start_indices: tuple[Value], start_indices_wg: tuple[Value], start_indices_th: tuple[Value], elements_per_thread: int, @@ -550,9 +553,6 @@ def _create_vec_write( offsets_vec: Optional[Value], ): if mask is None and offsets_vec is None: - start_indices = [ - arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th) - ] vector_d.store(value, mem, start_indices) return @@ -606,9 +606,6 @@ def _create_vec_write( ) mask = _constant_mask(mask_vec_type) - start_indices = [ - arith_d.addi(i, j) for i, j in zip(start_indices_wg, start_indices_th) - ] vector_d.scatter(mem, start_indices, offsets_vec, mask, value) @@ -640,13 +637,16 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): output_shape = _get_symbolic_shape(memory) elements_per_thread = cast_py_literal(emitter, elements_per_thread) if get_custom(node).has_identity_mapping(): - start_indices_wg, start_indices_th = _build_start_indices(emitter, index) + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index + ) mask = _build_mask(emitter, index, elements_per_thread) _create_vec_write( emitter, output_shape, kb_dest, insert_vector, + start_indices, start_indices_wg, start_indices_th, elements_per_thread, @@ -662,6 +662,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals ) ( + start_indices, start_indices_wg, start_indices_th, offsets_vec, @@ -682,6 +683,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): output_shape, kb_dest, insert_vector, + start_indices, start_indices_wg, start_indices_th, elements_per_thread,