From 670e27523bedec2ccc47a5433c67352b05da58c3 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 12 Feb 2025 03:00:25 +0100 Subject: [PATCH] fix lit Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/codegen.py | 86 ++++++++++++++++---------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 4005c95a7..47d799b18 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -162,6 +162,49 @@ def read_mapped(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): # CHECK-SAME: into vector<16xf16> +@run_test +def test_read_mapped_buffer(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + mapping = tkw.IndexMapping( + num_iterators=2, inputs={N: i, M: j}, outputs={N: i, M: j} + ) + + @tkw.wave(constraints) + def read_mapped_buffer(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): + tkw.read(a, mapping=mapping, elements_per_thread=16) + + with tk.gen.TestLaunchContext( + { + M: 16, + N: 16, + K: 16, + BLOCK_M: 16, + BLOCK_N: 16, + BLOCK_K: 16, + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + }, + use_buffer_load_ops=True, + use_buffer_store_ops=True, + ): + a = torch.randn(16, 16, dtype=torch.float16) + print(read_mapped_buffer(a).module_op) + + # CHECK-LABEL: func.func @read_mapped_buffer + # CHECK-COUNT-1: memref.reinterpret_cast + # CHECK-COUNT-16: amdgpu.raw_buffer_load + + @run_test def test_read_write(): constraints: list[tkw.Constraint] = [ @@ -354,49 +397,6 @@ def read_write_masked( # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> -@run_test -def test_read_write_buffer(): - constraints: list[tkw.Constraint] = [ - tkw.HardwareConstraint( - threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 4, N: 4} - ) - ] - constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WaveConstraint(M, BLOCK_M)] - constraints += [tkw.WaveConstraint(N, BLOCK_N)] - - @tkw.wave(constraints) - def read_write_buffer( - a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - ): - res = tkw.read(a, elements_per_thread=4) - tkw.write(res, b, elements_per_thread=4) - - with tk.gen.TestLaunchContext( - { - M: 1, - N: 3, - BLOCK_M: 4, - BLOCK_N: 4, - ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, - }, - canonicalize=True, - use_buffer_load_ops=True, - use_buffer_store_ops=True, - ): - a = torch.randn(4, 4, dtype=torch.float16) - b = torch.zeros(4, 4, dtype=torch.float16) - print(read_write_buffer(a, b).module_op) - - # CHECK-LABEL: func.func @read_write_buffer - # CHECK-COUNT-1: memref.reinterpret_cast - # CHECK-COUNT-4: amdgpu.raw_buffer_load - # CHECK-COUNT-1: memref.reinterpret_cast - # CHECK-COUNT-4: amdgpu.raw_buffer_store - - @run_test def test_read_write_masked_shared(): constraints: list[tkw.Constraint] = [