[nvidia] Support passing TMA descriptors by-value #4498

Merged (15 commits) on Aug 19, 2024
40 changes: 40 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -67,6 +67,42 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
return amendedFuncOp;
}

// Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM
// attributes.
static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) {
const bool isKernel = LLVM::isKernel(llvmFuncOp);
for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) {
const auto attrs = llvmFuncOp.getArgAttrDict(i);
if (!attrs) {
continue;
}

for (const auto &attr : attrs) {
if (attr.getName() == "tt.nv_tma_desc") {
const auto i32_type =
mlir::IntegerType::get(llvmFuncOp.getContext(), 32);
assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1));
assert(isKernel &&
"tt.nv_tma_desc is not supported for device functions");

// See
// https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc
mlir::BlockArgument arg = llvmFuncOp.getArgument(i);
const auto byteType =
mlir::IntegerType::get(llvmFuncOp.getContext(), 8);
const auto arrayType = mlir::LLVM::LLVMArrayType::get(
llvmFuncOp.getContext(), byteType, 128);
llvmFuncOp.setArgAttr(i, "llvm.byval",
mlir::TypeAttr::get(arrayType));
llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
mlir::UnitAttr::get(llvmFuncOp.getContext()));
llvmFuncOp.setArgAttr(i, "llvm.align",
mlir::IntegerAttr::get(i32_type, 64));
Review comment (Collaborator): Is 64 a required alignment value?

Reply (Contributor Author): Yes. Here is the definition of CUtensorMap in <cuda.h>:

typedef struct CUtensorMap_st {
    alignas(64)
    unsigned long long opaque[16];
} CUtensorMap;

}
}
}
}

LogicalResult
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -105,6 +141,10 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
newFuncOp->setAttr("nvvm.reqntid",
rewriter.getDenseI32ArrayAttr(32 * numWarps));
rewriter.eraseOp(funcOp);

// Add attributes for by-value TMA descriptor args (nvidia)
handleByvalTmaDescArgs(newFuncOp);

return success();
}

100 changes: 42 additions & 58 deletions python/test/unit/hopper/test_experimental_tma.py
@@ -1,86 +1,62 @@
import pytest
import torch
import tempfile

import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor


def test_descriptor_load_ttgir():
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
device = "cuda"
SIZE = 128
def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size):
cpu_desc = torch.empty(128, device="cpu")
if len(dims) == 1:
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
cpu_desc.data_ptr())
else:
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1],
element_size, cpu_desc.data_ptr())
return cpu_desc.cuda()

x = torch.randn(SIZE, dtype=torch.float32, device=device)
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
size_in_bytes = SIZE * x.element_size()

ir = f"""
#blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}>
#shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}>
module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i8> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked>
%1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable>
%2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
%true = arith.constant 1 : i1
triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : <i8>, <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
%3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<{SIZE}x!tt.ptr<f32>, #blocked>
%5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>, tensor<{SIZE}xi32, #blocked>
tt.store %5, %3 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>
tt.return
}}
}}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

z_tri = torch.empty_like(x)
kernel[(1, 1, 1)](z_tri, desc)
assert torch.equal(x, z_tri)
TMA_FENCE_ASM: tl.constexpr = "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg"


def test_experimetal_descriptor_load():
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimetal_descriptor_load(byval_tma):
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
device = "cuda"
SIZE = 128

@triton.jit
def kernel(Z, desc, SIZE: tl.constexpr):
def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
if not BYVAL_TMA:
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [desc], dtype=tl.int32, is_pure=False, pack=1)
off_desc = 0
off = tl.arange(0, SIZE)
x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty)
tl.store(Z + off, x)

x = torch.randn(SIZE, dtype=torch.float32, device=device)
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
if byval_tma:
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
else:
desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size())
z_tri = torch.empty_like(x)
kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4)
compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4)
assert torch.equal(x, z_tri)
if byval_tma:
assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"]


@triton.jit
def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
# TODO(embg) remove TMA fence after __grid_constant__ lands
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
BYVAL_TMA: tl.constexpr):
if not BYVAL_TMA:
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -101,7 +77,8 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #

@pytest.mark.parametrize("num_stages", [1, 4])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K):
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
@@ -111,13 +88,20 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K):
A = torch.randn((M, K), dtype=torch.float16, device=device)
B = torch.randn((K, N), dtype=torch.float16, device=device)
C = torch.empty((M, N), dtype=torch.float16, device=device)
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
if byval_tma:
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
else:
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8,
num_stages=num_stages)
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
num_warps=8, num_stages=num_stages)
ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
if BLOCK_M >= 64 and BLOCK_N >= 64:
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
if byval_tma:
assert ".param .align 64 .b8" in kernel.asm["ptx"]
7 changes: 6 additions & 1 deletion python/triton/compiler/code_generator.py
@@ -9,7 +9,7 @@
from .. import language
from .._C.libtriton import ir
from ..language import constexpr, tensor, str_to_ty
from ..language.core import _unwrap_if_constexpr
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
@@ -409,6 +409,11 @@ def visit_FunctionDef(self, node):
if i in self.attributes:
for name, value in self.attributes[i]:
self.fn.set_arg_attr(idx, name, value)

# Mark this argument as a pass-by-value TMA descriptor (nvidia)
if isinstance(self.prototype.param_types[idx], nv_tma_desc_type):
self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1)

arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx]))
idx += 1

6 changes: 6 additions & 0 deletions python/triton/language/__init__.py
@@ -84,6 +84,7 @@
permute,
pi32_t,
pointer_type,
nv_tma_desc_type,
program_id,
range,
reduce,
@@ -207,6 +208,7 @@
"philox_impl",
"pi32_t",
"pointer_type",
"nv_tma_desc_type",
"program_id",
"rand",
"rand4x",
@@ -259,6 +261,10 @@ def str_to_ty(name):
const = True
ty = str_to_ty(name)
return pointer_type(element_ty=ty, const=const)

if name == "nvTmaDesc":
return nv_tma_desc_type()

tys = {
"fp8e4nv": float8e4nv,
"fp8e4b8": float8e4b8,
9 changes: 8 additions & 1 deletion python/triton/language/core.py
@@ -568,7 +568,7 @@ def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False):
self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>'

def to_ir(self, builder: ir.builder) -> ir.pointer_type:
return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1)
return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space)

def __str__(self):
return self.name
@@ -595,6 +595,13 @@ def scalar(self):
return self


class nv_tma_desc_type(pointer_type):

def __init__(self):
super().__init__(uint8, const=True, address_space=0)
self.name = 'nv_tma_desc_type'


class block_type(dtype):

def __init__(self, element_ty: dtype, shape: List):
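For context on the new type: a minimal sketch, assuming str_to_ty and the pointer_type attributes shown above, of what nv_tma_desc_type encodes (a const u8 pointer in address space 0, i.e. a by-value kernel parameter rather than a global-memory pointer). Illustrative only, not part of the PR:

import triton.language as tl

# The mangled name "nvTmaDesc" (see jit.py below) maps to nv_tma_desc_type.
ty = tl.str_to_ty("nvTmaDesc")
assert isinstance(ty, tl.nv_tma_desc_type)
assert ty.name == "nv_tma_desc_type"
assert ty.address_space == 0  # not address space 1, which denotes global memory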
3 changes: 2 additions & 1 deletion python/triton/runtime/build.py
@@ -42,7 +42,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH'))
include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]
Review comment (Collaborator): What is causing the extra warning?

Reply (Contributor Author): GCC doesn't like the CUtensorMap struct. This is called out in the CUDA C++ Programming Guide as a false warning:

When passing the tensor map as a parameter, some versions of the GCC C++ compiler issue the warning “the ABI for passing parameters with 64-byte alignment has changed in GCC 4.6”. This warning can be ignored.

I don't think it can be suppressed inline via a pragma; it has to be suppressed on the command line: https://godbolt.org/z/f5n5crhjG

cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
2 changes: 2 additions & 0 deletions python/triton/runtime/jit.py
@@ -306,6 +306,8 @@ def mangle_type(arg, is_const=False):
return "i64"
elif isinstance(arg, float):
return "fp32"
elif hasattr(arg, "tma_desc_cpu_ptr"):
return "nvTmaDesc"
else:
# dtypes are hashable so we can memoize this mapping:
dsk = (arg.dtype, is_const)
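For context, a short sketch of the duck-typing contract mangle_type relies on: any launch argument that exposes a tma_desc_cpu_ptr() method is mangled as "nvTmaDesc" and lowered as a by-value descriptor. The class below is a hypothetical stand-in for TmaDescKernelParam (next file) and does not actually fill the descriptor bytes:

import torch

class HostTmaDesc:  # hypothetical illustration; TmaDescKernelParam is the real implementation
    TMA_DESC_SIZE = 128  # sizeof(CUtensorMap)

    def __init__(self):
        # Host-side staging buffer for the CUtensorMap bytes (left unfilled in this sketch).
        self.buf = torch.empty(self.TMA_DESC_SIZE, dtype=torch.int8, device="cpu")
        assert self.buf.data_ptr() % 64 == 0  # CUtensorMap is alignas(64)

    def tma_desc_cpu_ptr(self):
        # mangle_type() only checks for this attribute; the launcher copies the
        # 128 descriptor bytes from the returned host pointer into the kernel params.
        return self.buf.data_ptr()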
42 changes: 23 additions & 19 deletions python/triton/tools/experimental_descriptor.py
@@ -3,26 +3,30 @@
import triton


# Constructs a 1D TMA descriptor in mutable GPU memory.
#
# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's
# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu.
class TmaDescKernelParam:
TMA_DESC_SIZE = 128

def __init__(self, ptr, dims, block_dims, element_size):
self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.int8, device="cpu")
assert len(dims) == len(block_dims)
assert 1 <= len(dims) <= 2
assert self.desc.data_ptr() % 64 == 0

if len(dims) == 1:
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
self.desc.data_ptr())
else:
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0],
block_dims[1], element_size, self.desc.data_ptr())

# Return a CUtensorMap* pointer in host memory
def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()


def create_1d_tma_descriptor(ptr, dim, block_dim, element_size):
TMA_SIZE = 128
desc = torch.empty(TMA_SIZE, dtype=torch.int8)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc.data_ptr())
gpu_desc = desc.cuda()
return gpu_desc
return TmaDescKernelParam(ptr, [dim], [block_dim], element_size)


# Constructs a 2D TMA descriptor in mutable GPU memory.
#
# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's
# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu.
def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size):
TMA_SIZE = 128
desc = torch.empty(TMA_SIZE, dtype=torch.int8)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size,
desc.data_ptr())
gpu_desc = desc.cuda()
return gpu_desc
return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size)
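
A minimal end-to-end usage sketch of the updated helpers, closely following the unit tests above (assumes a Hopper GPU; the kernel body is illustrative):

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


@triton.jit
def copy_kernel(Z, desc, SIZE: tl.constexpr):
    off = tl.arange(0, SIZE)
    # Load one SIZE-element block through the TMA descriptor passed by value.
    x = tl._experimental_descriptor_load(desc, [0], [SIZE], Z.dtype.element_ty)
    tl.store(Z + off, x)


SIZE = 128
x = torch.randn(SIZE, dtype=torch.float32, device="cuda")
z = torch.empty_like(x)
# TmaDescKernelParam exposes tma_desc_cpu_ptr(), so it is mangled as nvTmaDesc.
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
compiled = copy_kernel[(1, )](z, desc, SIZE=SIZE, num_warps=4)
assert torch.equal(x, z)
assert ".param .align 64 .b8" in compiled.asm["ptx"]  # descriptor lands in the param space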
8 changes: 0 additions & 8 deletions python/tutorials/09-persistent-matmul.py
@@ -259,14 +259,6 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
NUM_SMS: tl.constexpr): #
# TODO(embg) remove TMA fence after __grid_constant__ lands
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)

dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
12 changes: 12 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
@@ -79,3 +79,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: byval_tma_desc
// CHECK: llvm.align = 64
// CHECK: llvm.byval = !llvm.array<128 x i8>
// CHECK: nvvm.grid_constant
tt.func @byval_tma_desc(%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) {
tt.return
}
}