diff --git a/iree/turbine/kernel/wave/codegen/read_write.py b/iree/turbine/kernel/wave/codegen/read_write.py
index 737807256..6010988a8 100644
--- a/iree/turbine/kernel/wave/codegen/read_write.py
+++ b/iree/turbine/kernel/wave/codegen/read_write.py
@@ -42,10 +42,11 @@
     write,
 )
 
-from ..utils import subs_idxc, find_index_bounds
+from ..utils import safe_subs, subs_idxc, find_index_bounds
 
 from ..._support.indexing import IndexingContext, IndexExpr, IndexSequence, index_symbol
 from ...lang.wave_types import IndexMapping
+from ...lang.global_symbols import *
 
 from .emitter import (
     WaveEmitter,
@@ -75,15 +76,40 @@ def _get_start_indices(
     return start_indices
 
 
+def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]:
+    """
+    Split an index expression into thread-dependent and thread-independent parts.
+    """
+    subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0}
+    subs_th = {THREAD_0: 0, THREAD_1: 0, THREAD_2: 0}
+    # Replace all wg symbols with 0s to get the thread-dependent index.
+    # All dynamic values will also be part of the thread-dependent index.
+    thread_dependent_index = safe_subs(src, subs_wg)
+
+    # Compute the thread-independent index as `orig_index - thread_dependent_index`.
+    # All thread symbols should cancel out in the result, but to be sure
+    # replace all thread symbols by 0 in the result.
+    # We cannot just replace all thread symbols without the subtraction as
+    # any constant or dynamic values will end up in both expressions.
+    thread_independent_index = sympy.simplify(
+        safe_subs(src - thread_dependent_index, subs_th)
+    )
+    return thread_independent_index, thread_dependent_index
+
+
 def _build_start_indices(
     emitter: WaveEmitter,
     src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
     dynamic_values: dict[IndexExpr, Any] = {},
-) -> list[OpResult]:
-    return [
-        gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
-        for i in _get_start_indices(src_indices)
-    ]
+) -> tuple[list[OpResult], list[OpResult], list[OpResult]]:
+    start_indices = _get_start_indices(src_indices)
+    split_indices = [_split_index(i) for i in start_indices]
+    subs = add_emitter_subs(emitter, dynamic_values)
+    indices = [gen_sympy_index(subs, i) for i in start_indices]
+    indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices]
+    indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices]
+
+    return indices, indices_wg, indices_th
 
 
 def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
@@ -154,7 +180,7 @@ def _construct_gather_scatter_indices(
     is_read: bool,
     dynamic_vals: tuple[Any, ...],
     is_contiguous: bool,
-) -> tuple[OpResult, OpResult, OpResult]:
+) -> tuple[list[OpResult], list[OpResult], list[OpResult], OpResult, OpResult]:
     # Apply symbolc_shape order to indices, e.g. if original mapping is
     # {M: iter(0), N: iter(1)} and symbolc_shape is (N, M), result will
     # be (iter(1), iter(0))
@@ -200,10 +226,10 @@ def extract0(src):
         for sym, val in zip(mapping.dynamic_val_indices.keys(), dynamic_vals)
     }
     if is_contiguous:
-        start_indices = _build_start_indices(
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
            emitter, result_index, dynamic_vals_map_start
         )
-        return start_indices, None, mask
+        return start_indices, start_indices_wg, start_indices_th, None, mask
 
     start_indices = _get_start_indices(result_index)
     start_indices_orig = _get_start_indices(index)
@@ -255,7 +281,7 @@ def extract0(src):
         # In case we need dynamic `offsets_vec`, set all `start_indices` to 0
         # and encode entire index info in `offsets_vec`.
         result_index = {key: 0 for key in symbolc_shape}
-        start_indices = _build_start_indices(
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
             emitter, result_index, dynamic_vals_map_start
         )
         subs = [(sym, idx) for sym, idx in zip(iters.keys(), start_indices_orig)]
@@ -278,18 +304,18 @@ def extract0(src):
             _compute_offset(indices, strides),
         )
     else:
-        start_indices = _build_start_indices(
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
             emitter, result_index, dynamic_vals_map_start
         )
         if offsets == list(range(elements_per_thread)):
-            return start_indices, None, mask
+            return start_indices, start_indices_wg, start_indices_th, None, mask
 
         offsets = [IntegerAttr.get(IndexType.get(), off) for off in offsets]
         offsets_vec = arith_d.ConstantOp(
             offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
         )
 
-    return start_indices, offsets_vec, mask
+    return start_indices, start_indices_wg, start_indices_th, offsets_vec, mask
 
 
 def _get_max_buffer_size(elem_type: IrType) -> int:
@@ -301,7 +327,12 @@
     return ((1 << 31) - 1) // (elem_type.width // 8)
 
 
-def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
+def _linearize_memref(
+    mem: Value,
+    offsets_wg: tuple[Value | int],
+    offsets_th: tuple[Value | int],
+    strides: tuple[Value],
+) -> tuple[Value, Value]:
     """
     Convert n-D memref into 1-D memref, suitable for buffer ops.
 
@@ -310,23 +341,27 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
     no-op.
     """
     memref_type = mem.type
-    rank = memref_type.rank
-    results = memref_d.extract_strided_metadata(mem)
-    base = results[0]
-    offset = results[1]
-    results = results[2:]
-    sizes = results[:rank]
-    strides = results[rank:]
+    offset = None
+    offset_th = None
     overflow_flags = arith_d.IntegerOverflowFlags.nsw
-    for ind, size, stride in zip(offsets, sizes, strides):
-        if isinstance(ind, int):
-            ind = arith_d.constant(IndexType.get(), ind)
-
-        offset = arith_d.addi(
-            offset,
-            arith_d.muli(ind, stride, overflow_flags=overflow_flags),
-            overflow_flags=overflow_flags,
-        )
+    for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides):
+        if isinstance(ind_wg, int):
+            ind_wg = arith_d.constant(IndexType.get(), ind_wg)
+
+        if isinstance(ind_th, int):
+            ind_th = arith_d.constant(IndexType.get(), ind_th)
+
+        off_wg = arith_d.muli(ind_wg, stride, overflow_flags=overflow_flags)
+        if offset is None:
+            offset = off_wg
+        else:
+            offset = arith_d.addi(offset, off_wg, overflow_flags=overflow_flags)
+
+        off_th = arith_d.muli(ind_th, stride, overflow_flags=overflow_flags)
+        if offset_th is None:
+            offset_th = off_th
+        else:
+            offset_th = arith_d.addi(offset_th, off_th, overflow_flags=overflow_flags)
 
     size_full = arith_d.constant(
         IndexType.get(), _get_max_buffer_size(memref_type.element_type) - 1
@@ -342,23 +377,29 @@ def _linearize_memref(mem: Value, offsets: tuple[Value | int]) -> Value:
         layout=Attribute.parse("strided<[1], offset: ?>"),
         memory_space=memory_space,
     )
-    return memref_d.reinterpret_cast(
-        resut_type,
-        base,
-        offsets=[offset],
-        sizes=[size_full],
-        strides=[],
-        static_offsets=[dyn_val],
-        static_sizes=[dyn_val],
-        static_strides=[1],
+    return (
+        memref_d.reinterpret_cast(
+            resut_type,
+            mem,
+            offsets=[offset],
+            sizes=[size_full],
+            strides=[],
+            static_offsets=[dyn_val],
+            static_sizes=[dyn_val],
+            static_strides=[1],
+        ),
+        offset_th,
     )
 
 
 def _create_vec_read(
     emitter: WaveEmitter,
+    symbolic_shape: tuple[IndexExpr, ...],
     mem: Value,
     vector_type: IrType,
     start_indices: tuple[Value],
+    start_indices_wg: tuple[Value],
+    start_indices_th: tuple[Value],
     elements_per_thread: int,
     mask: Optional[Value],
     offsets_vec: Optional[Value],
@@ -366,6 +407,9 @@
     if mask is None and offsets_vec is None:
         return vector_d.load(vector_type, mem, start_indices)
 
+    # Only use buffer ops if it's gather/scatter on global mem.
+    use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None
+
     element_type = vector_type.element_type
     if offsets_vec is None:
         offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get())
@@ -377,10 +421,20 @@
     zero = get_constant_attr(0, element_type)
     zero = arith_d.constant(element_type, zero)
 
-    if emitter.params.get("use_buffer_load_ops", False):
+    strides = strides_from_symbolic_shape(
+        IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
+    )
+    buffer_ops_enabled = emitter.params.get("use_buffer_load_ops", False)
+    has_int_strides = all(isinstance(s, int) for s in strides)
+    if buffer_ops_enabled and has_int_strides and use_buffer_ops:
         result = vector_d.splat(vector_type, zero)
-        data = _linearize_memref(mem, start_indices)
+        strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
+        data, offset_th = _linearize_memref(
+            mem, start_indices_wg, start_indices_th, strides
+        )
+        offset_th = vector_d.splat(offsets_vec.type, offset_th)
+        offsets_vec = arith_d.addi(offsets_vec, offset_th)
 
         if mask is not None:
             i32 = IntegerType.get_signless(32)
             i32vec = VectorType.get([elements_per_thread], i32)
@@ -441,7 +495,9 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     input_shape = _get_symbolic_shape(memory)
     elements_per_thread = cast_py_literal(emitter, elements_per_thread)
     if get_custom(node).has_identity_mapping():
-        start_indices = _build_start_indices(emitter, index)
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
+            emitter, index
+        )
         mask = _build_mask(
             emitter,
             index,
@@ -449,9 +505,12 @@
         )
         result = _create_vec_read(
             emitter,
+            input_shape,
             kb_src,
             vector_type,
             start_indices,
+            start_indices_wg,
+            start_indices_th,
             elements_per_thread,
             mask,
             offsets_vec=None,
@@ -460,7 +519,13 @@
         dyn_vals = tuple(
             cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
         )
-        start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
+        (
+            start_indices,
+            start_indices_wg,
+            start_indices_th,
+            offsets_vec,
+            mask,
+        ) = _construct_gather_scatter_indices(
             emitter=emitter,
             symbolc_shape=input_shape,
             index=index,
@@ -472,9 +537,12 @@
         )
         result = _create_vec_read(
             emitter,
+            input_shape,
             kb_src,
             vector_type,
             start_indices,
+            start_indices_wg,
+            start_indices_th,
             elements_per_thread,
             mask,
             offsets_vec,
@@ -485,9 +553,12 @@
 
 def _create_vec_write(
     emitter: WaveEmitter,
+    symbolic_shape: tuple[IndexExpr, ...],
     mem: Value,
     value: Value,
     start_indices: tuple[Value],
+    start_indices_wg: tuple[Value],
+    start_indices_th: tuple[Value],
     elements_per_thread: int,
     mask: Optional[Value],
     offsets_vec: Optional[Value],
@@ -496,6 +567,9 @@
         vector_d.store(value, mem, start_indices)
         return
 
+    # Only use buffer ops if it's gather/scatter on global mem.
+    use_buffer_ops = offsets_vec is not None and mem.type.memory_space is None
+
     vector_type = value.type
     element_type = vector_type.element_type
     if offsets_vec is None:
@@ -505,8 +579,18 @@
             offsets_vec_type, DenseElementsAttr.get(vals, offsets_vec_type)
         )
 
-    if emitter.params.get("use_buffer_store_ops", False):
-        data = _linearize_memref(mem, start_indices)
+    strides = strides_from_symbolic_shape(
+        IndexingContext.current(), symbolic_shape, allow_mixed_shapes=True
+    )
+    buffer_ops_enabled = emitter.params.get("use_buffer_store_ops", False)
+    has_int_strides = all(isinstance(s, int) for s in strides)
+    if buffer_ops_enabled and has_int_strides and use_buffer_ops:
+        strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]
+        data, offset_th = _linearize_memref(
+            mem, start_indices_wg, start_indices_th, strides
+        )
+        offset_th = vector_d.splat(offsets_vec.type, offset_th)
+        offsets_vec = arith_d.addi(offsets_vec, offset_th)
         if mask is not None:
             i32 = IntegerType.get_signless(32)
             i32vec = VectorType.get([elements_per_thread], i32)
@@ -564,13 +648,18 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):
     output_shape = _get_symbolic_shape(memory)
     elements_per_thread = cast_py_literal(emitter, elements_per_thread)
     if get_custom(node).has_identity_mapping():
-        start_indices = _build_start_indices(emitter, index)
+        start_indices, start_indices_wg, start_indices_th = _build_start_indices(
+            emitter, index
+        )
         mask = _build_mask(emitter, index, elements_per_thread)
         _create_vec_write(
             emitter,
+            output_shape,
             kb_dest,
             insert_vector,
             start_indices,
+            start_indices_wg,
+            start_indices_th,
             elements_per_thread,
             mask,
             offsets_vec=None,
@@ -583,7 +672,13 @@
         dyn_vals = tuple(
             cast_vector(emitter, reg, element_type=IndexType.get()) for reg in dyn_vals
         )
-        start_indices, offsets_vec, mask = _construct_gather_scatter_indices(
+        (
+            start_indices,
+            start_indices_wg,
+            start_indices_th,
+            offsets_vec,
+            mask,
+        ) = _construct_gather_scatter_indices(
             emitter=emitter,
             symbolc_shape=output_shape,
             index=index,
@@ -596,9 +691,12 @@
 
         _create_vec_write(
             emitter,
+            output_shape,
             kb_dest,
             insert_vector,
             start_indices,
+            start_indices_wg,
+            start_indices_th,
             elements_per_thread,
             mask,
             offsets_vec,
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 4005c95a7..47d799b18 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -162,6 +162,49 @@ def read_mapped(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
     # CHECK-SAME: into vector<16xf16>
 
 
+@run_test
+def test_read_mapped_buffer():
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    mapping = tkw.IndexMapping(
+        num_iterators=2, inputs={N: i, M: j}, outputs={N: i, M: j}
+    )
+
+    @tkw.wave(constraints)
+    def read_mapped_buffer(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
+        tkw.read(a, mapping=mapping, elements_per_thread=16)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 16,
+            N: 16,
+            K: 16,
+            BLOCK_M: 16,
+            BLOCK_N: 16,
+            BLOCK_K: 16,
+            ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
+        },
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+    ):
+        a = torch.randn(16, 16, dtype=torch.float16)
+        print(read_mapped_buffer(a).module_op)
+
+    # CHECK-LABEL: func.func @read_mapped_buffer
+    # CHECK-COUNT-1: memref.reinterpret_cast
+    # CHECK-COUNT-16: amdgpu.raw_buffer_load
+
+
 @run_test
 def test_read_write():
     constraints: list[tkw.Constraint] = [
@@ -354,49 +397,6 @@ def read_write_masked(
     # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16>
 
 
-@run_test
-def test_read_write_buffer():
-    constraints: list[tkw.Constraint] = [
-        tkw.HardwareConstraint(
-            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 4, N: 4}
-        )
-    ]
-    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
-    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
-    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
-
-    @tkw.wave(constraints)
-    def read_write_buffer(
-        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
-    ):
-        res = tkw.read(a, elements_per_thread=4)
-        tkw.write(res, b, elements_per_thread=4)
-
-    with tk.gen.TestLaunchContext(
-        {
-            M: 1,
-            N: 3,
-            BLOCK_M: 4,
-            BLOCK_N: 4,
-            ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
-        },
-        canonicalize=True,
-        use_buffer_load_ops=True,
-        use_buffer_store_ops=True,
-    ):
-        a = torch.randn(4, 4, dtype=torch.float16)
-        b = torch.zeros(4, 4, dtype=torch.float16)
-        print(read_write_buffer(a, b).module_op)
-
-    # CHECK-LABEL: func.func @read_write_buffer
-    # CHECK-COUNT-1: memref.reinterpret_cast
-    # CHECK-COUNT-4: amdgpu.raw_buffer_load
-    # CHECK-COUNT-1: memref.reinterpret_cast
-    # CHECK-COUNT-4: amdgpu.raw_buffer_store
-
-
 @run_test
 def test_read_write_masked_shared():
     constraints: list[tkw.Constraint] = [
diff --git a/tests/kernel/wave/attention/extend_attention_test.py b/tests/kernel/wave/attention/extend_attention_test.py
index 8e0b6561f..567238b57 100644
--- a/tests/kernel/wave/attention/extend_attention_test.py
+++ b/tests/kernel/wave/attention/extend_attention_test.py
@@ -239,6 +239,7 @@ def create_inputs(
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("enable_scheduling", [False])
 @pytest.mark.parametrize("is_causal", [False, True])
+@pytest.mark.parametrize("use_buffer_ops", [False, True])
 @pytest.mark.parametrize(
     "mfma_variant",
     [
@@ -251,6 +252,7 @@ def testExtendAttention(
     dtype: torch.dtype,
     enable_scheduling: bool,
     is_causal: bool,
+    use_buffer_ops: bool,
     mfma_variant: MMAType,
     request,
 ):
@@ -312,6 +314,8 @@ def testExtendAttention(
     if run_bench:
         config["benchmark_batch_size"] = 1000
         config["benchmark_repetitions"] = 3
+    config["dump_intermediates"] = "./inter"
+
     if dump_perf is not None:
         perf_filename = construct_test_name(
             "wave_extend_attention", mfma_variant, is_causal, shape
@@ -328,6 +332,8 @@ def testExtendAttention(
         use_scheduling_barriers=enable_scheduling_barriers,
         dynamic_symbols=dynamic_symbols,
         dynamic_symbols_map=dynamic_symbols_map,
+        use_buffer_load_ops=use_buffer_ops,
+        use_buffer_store_ops=use_buffer_ops,
     ):
         mb_qk = extend_attention(
             q_extend,
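Reviewer note (not part of the patch): a minimal sketch of the index split performed by `_split_index`, using plain sympy symbols as hypothetical stand-ins for `WORKGROUP_0`/`THREAD_0` and a tile size. It illustrates why the subtraction is needed: constants (and dynamic values) stay in the thread-dependent part instead of being duplicated in both parts.

import sympy

# Hypothetical stand-ins for WORKGROUP_0 / THREAD_0 and a workgroup tile size.
WG0, T0, BLOCK = sympy.symbols("WG0 T0 BLOCK")

# A typical per-lane start index: workgroup offset + per-thread offset + constant.
src = WG0 * BLOCK + T0 * 4 + 3

# Zeroing the workgroup symbols keeps the thread-dependent part (constant included).
thread_dependent = src.subs({WG0: 0})  # 4*T0 + 3

# Subtracting, then zeroing thread symbols, leaves the workgroup-only part.
thread_independent = sympy.simplify((src - thread_dependent).subs({T0: 0}))  # BLOCK*WG0

# The two parts recompose the original index exactly.
assert sympy.simplify(thread_independent + thread_dependent - src) == 0

In the patch itself, the workgroup-only part becomes the `memref.reinterpret_cast` offset inside `_linearize_memref`, while the thread-dependent part is splatted and folded into `offsets_vec` for the per-lane buffer loads/stores.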