Skip to content

Commit

Permalink
fix lit
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Feb 12, 2025
1 parent eac4567 commit 670e275
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,49 @@ def read_mapped(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
# CHECK-SAME: into vector<16xf16>


@run_test
def test_read_mapped_buffer():
    """Lit test: a mapped read compiled with buffer ops enabled.

    Builds a 16x16 kernel that reads memory through an IndexMapping and
    prints the lowered MLIR module; the CHECK lines below are matched by
    FileCheck against that output (one reinterpret_cast, sixteen
    amdgpu.raw_buffer_load ops).
    """
    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
        ),
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
        tkw.WaveConstraint(M, BLOCK_M),
        tkw.WaveConstraint(N, BLOCK_N),
    ]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    # Identity mapping expressed in (N, M) order on both sides; the memory
    # itself is declared [M, N] below — NOTE(review): presumably this swaps
    # the iteration order relative to the layout, confirm against the DSL.
    mapping = tkw.IndexMapping(
        num_iterators=2, inputs={N: i, M: j}, outputs={N: i, M: j}
    )

    @tkw.wave(constraints)
    def read_mapped_buffer(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
        tkw.read(a, mapping=mapping, elements_per_thread=16)

    # Buffer load/store lowering is opted into explicitly; all symbolic
    # dims and block sizes are bound to 16 so one wave covers the tile.
    with tk.gen.TestLaunchContext(
        {
            M: 16,
            N: 16,
            K: 16,
            BLOCK_M: 16,
            BLOCK_N: 16,
            BLOCK_K: 16,
            ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
        },
        use_buffer_load_ops=True,
        use_buffer_store_ops=True,
    ):
        a = torch.randn(16, 16, dtype=torch.float16)
        print(read_mapped_buffer(a).module_op)

    # CHECK-LABEL: func.func @read_mapped_buffer
    # CHECK-COUNT-1: memref.reinterpret_cast
    # CHECK-COUNT-16: amdgpu.raw_buffer_load


@run_test
def test_read_write():
constraints: list[tkw.Constraint] = [
Expand Down Expand Up @@ -354,49 +397,6 @@ def read_write_masked(
# CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16>


@run_test
def test_read_write_buffer():
    """Lit test: a read/write round-trip compiled with buffer ops enabled.

    Compiles a copy kernel and prints the lowered MLIR module; FileCheck
    matches the CHECK lines below against it (a reinterpret_cast plus four
    raw_buffer_load ops for the read, and the same pair for the write).
    """
    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 4, N: 4}
        ),
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
        tkw.WaveConstraint(M, BLOCK_M),
        tkw.WaveConstraint(N, BLOCK_N),
    ]

    @tkw.wave(constraints)
    def read_write_buffer(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        res = tkw.read(a, elements_per_thread=4)
        tkw.write(res, b, elements_per_thread=4)

    # M=1, N=3 are smaller than BLOCK_M/BLOCK_N (4) — NOTE(review): this
    # looks intentional, to exercise partially out-of-bounds buffer
    # accesses; confirm against the buffer-op lowering.
    with tk.gen.TestLaunchContext(
        {
            M: 1,
            N: 3,
            BLOCK_M: 4,
            BLOCK_N: 4,
            ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
        },
        canonicalize=True,
        use_buffer_load_ops=True,
        use_buffer_store_ops=True,
    ):
        a = torch.randn(4, 4, dtype=torch.float16)
        b = torch.zeros(4, 4, dtype=torch.float16)
        print(read_write_buffer(a, b).module_op)

    # CHECK-LABEL: func.func @read_write_buffer
    # CHECK-COUNT-1: memref.reinterpret_cast
    # CHECK-COUNT-4: amdgpu.raw_buffer_load
    # CHECK-COUNT-1: memref.reinterpret_cast
    # CHECK-COUNT-4: amdgpu.raw_buffer_store


@run_test
def test_read_write_masked_shared():
constraints: list[tkw.Constraint] = [
Expand Down

0 comments on commit 670e275

Please sign in to comment.