Update the matmul tutorial
htyu committed Jan 8, 2025
1 parent 919e4e8 commit 0b0b964
Showing 1 changed file with 120 additions and 12 deletions:
python/tutorials/10-warp-specialized-matmul.py
@@ -96,7 +96,6 @@ def get_tma_descriptor_kernel_param(self, name):
        assert self.cuda_descriptors[name] is not None
        return self.cuda_descriptors[name]


@triton.autotune(
    configs=[
        triton.Config(
@@ -130,7 +129,7 @@ def get_tma_descriptor_kernel_param(self, name):
    use_cuda_graph=True,
)
@triton.jit
-def matmul_persistent_tma_ws_cooperative_kernel(
+def matmul_persistent_tma_ws_cooperative_annotated_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
@@ -191,12 +190,103 @@ def matmul_persistent_tma_ws_cooperative_kernel(
        tl._experimental_descriptor_store(c_ptr, c, [offs_am, offs_bn])


@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "NUM_CONSUMER_GROUPS": 2,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 8,
                "NUM_CONSUMER_GROUPS": 1,
            },
            num_stages=3,
            num_warps=4,
            num_consumer_groups=0,  # disable warp specialization
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
    use_cuda_graph=True,
)
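# Note: `num_consumer_groups` and `num_buffers_warp_spec` above are
# warp-specialization options exposed by this Triton build. Per the second
# config, num_consumer_groups=0 falls back to a non-specialized kernel;
# num_buffers_warp_spec presumably sets how many shared-memory buffers are
# used to pipeline tiles from the producer to the consumer warp groups.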
@triton.jit
def matmul_persistent_tma_ws_cooperative_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_CONSUMER_GROUPS: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """

    num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N)
    for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
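        # Grouped tile ordering for L2 reuse: with GROUP_SIZE_M = 8 (and at
        # least 8 block-rows), pids 0..7 map to (pid_m, pid_n) = (0..7, 0)
        # and pids 8..15 to (0..7, 1), so nearby programs load the same
        # block-column of B and neighboring block-rows of A.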

        # ----------------------------------------------------------
        # Compute the starting offsets of the A and B tiles for this
        # output block. With TMA there is no manual pointer arithmetic:
        # `tl._experimental_descriptor_load` takes [row, col] offsets
        # directly, and we advance `offs_k` as we move along the K axis.
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N
        offs_k = 0

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl._experimental_descriptor_load(
                a_ptr,
                [offs_am, offs_k],
                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                tl.float16,
            )
            b = tl._experimental_descriptor_load(
                b_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16
            )

            accumulator += tl.dot(a, b)
            offs_k += BLOCK_SIZE_K

        c = accumulator.to(tl.float16)
        tl._experimental_descriptor_store(c_ptr, c, [offs_am, offs_bn])


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.


-def matmul_persistent_tma_ws_cooperative(a, b):
+def matmul_persistent_tma_ws_cooperative(a, b, use_annotation=False):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
@@ -253,24 +343,42 @@ def grid(META):
    desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
    desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
    desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
-    matmul_persistent_tma_ws_cooperative_kernel[grid](
-        desc_a,
-        desc_b,
-        desc_c,  #
-        M,
-        N,
-        K,  #
-    )
+
+    if use_annotation:
+        matmul_persistent_tma_ws_cooperative_annotated_kernel[grid](
+            desc_a,
+            desc_b,
+            desc_c,  #
+            M,
+            N,
+            K,  #
+        )
+    else:
+        matmul_persistent_tma_ws_cooperative_kernel[grid](
+            desc_a,
+            desc_b,
+            desc_c,  #
+            M,
+            N,
+            K,  #
+        )
    return c


def aten_matmul(a, b):
    return a.mm(b)

def matmul_ws_annotated(a, b):
    return matmul_persistent_tma_ws_cooperative(a, b, use_annotation=True)

def matmul_ws_automatic(a, b):
    return matmul_persistent_tma_ws_cooperative(a, b, use_annotation=False)


test_impls = [
    aten_matmul,
-    matmul_persistent_tma_ws_cooperative,
+    matmul_ws_annotated,
+    matmul_ws_automatic,
]
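
For reference, a quick numerical sanity check of the two warp-specialized
variants against `aten_matmul` might look like the sketch below. It is not
part of this commit: it assumes fp16 inputs on a TMA-capable (sm90+) GPU,
and `check_impls` is an illustrative name.

import torch

def check_impls(M=4096, N=4096, K=4096):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    ref = aten_matmul(a, b)
    for impl in (matmul_ws_annotated, matmul_ws_automatic):
        out = impl(a, b)
        # both kernels accumulate in fp32, so results should track cuBLAS
        # to within fp16 rounding
        torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)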


