From 8bb3c1dc8a1bf5968ad76be8ac8b0044aab86305 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 2 Aug 2024 13:54:08 -0700 Subject: [PATCH 01/14] byval tma desc working prototype --- .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 15 +++++ python/triton/compiler/code_generator.py | 7 ++- python/triton/language/__init__.py | 8 +++ python/triton/language/core.py | 14 ++++- python/triton/runtime/build.py | 3 +- python/triton/runtime/jit.py | 3 +- test.py | 31 ++++++++++ third_party/nvidia/backend/driver.py | 58 ++++++++++++++++++- 8 files changed, 133 insertions(+), 6 deletions(-) create mode 100644 test.py diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 6172c614aac9..2114c2daf9b1 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -87,6 +87,21 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { auto ctx = funcOp->getContext(); if (LLVM::isKernel(funcOp)) { + for (unsigned i = 0; i < newFuncOp.getNumArguments(); ++i) { + const auto attrs = newFuncOp.getArgAttrDict(i); + for (const auto& attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + mlir::BlockArgument arg = newFuncOp.getArgument(i); + const auto byteType = mlir::IntegerType::get(newFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get(newFuncOp.getContext(), byteType, 128); + newFuncOp.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arrayType)); + newFuncOp.setArgAttr(i, "nvvm.grid_constant", mlir::UnitAttr::get(newFuncOp.getContext())); + newFuncOp.setArgAttr(i, "llvm.align", mlir::IntegerAttr::get( + mlir::IntegerType::get(newFuncOp.getContext(), 32), 64)); + } + } + } + // Set an attribute to indicate this function is a kernel entry. newFuncOp->setAttr("nvvm.kernel", rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 0ae66922f133..959035ccda80 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -9,7 +9,7 @@ from .. import language from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty -from ..language.core import _unwrap_if_constexpr +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction @@ -409,6 +409,11 @@ def visit_FunctionDef(self, node): if i in self.attributes: for name, value in self.attributes[i]: self.fn.set_arg_attr(idx, name, value) + + # TMA + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6e8803638e84..79e8d81029f1 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -84,6 +84,7 @@ permute, pi32_t, pointer_type, + nv_tma_desc_type, program_id, range, reduce, @@ -102,6 +103,7 @@ view, void, where, + NvTmaDesc, ) from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, ceil) @@ -207,6 +209,7 @@ "philox_impl", "pi32_t", "pointer_type", + "nv_tma_desc_type", "program_id", "rand", "rand4x", @@ -247,6 +250,7 @@ "xor_sum", "zeros", "zeros_like", + "NvTmaDesc", ] @@ -259,6 +263,10 @@ def str_to_ty(name): const = True ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + + if name == "nvTmaDesc": + return nv_tma_desc_type() + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 971d8a9f3a37..8e3aee8a228f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -568,7 +568,7 @@ def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = Fals self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) def __str__(self): return self.name @@ -593,6 +593,10 @@ def __ne__(self, other: pointer_type) -> bool: @property def scalar(self): return self + +class nv_tma_desc_type(pointer_type): + def __init__(self): + super().__init__(uint8, const = True, address_space = 0) class block_type(dtype): @@ -2661,3 +2665,11 @@ def binary_op_type_legalization(lhs, rhs, builder): def extern(fn): """A decorator for external functions.""" return builtin(fn) + +class NvTmaDesc: + def __init__(self, data): + self.data = data + self.dtype = uint8 + + def tma_desc_ptr(self): + return self.data.data_ptr() diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 95d6d524fd69..20da2bc25790 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -42,7 +42,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index afdb41d0b3f7..9a79b47a67d8 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -292,7 +292,6 @@ def compute_spec_key(v, align): def mangle_type(arg, is_const=False): - if arg is None: return "none" elif isinstance(arg, bool): @@ -306,6 +305,8 @@ def mangle_type(arg, is_const=False): return "i64" elif isinstance(arg, float): return "fp32" + elif "NvTmaDesc" in type(arg).__name__: + return "nvTmaDesc" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) diff --git a/test.py b/test.py new file mode 100644 index 000000000000..ff1b20d27726 --- /dev/null +++ b/test.py @@ -0,0 +1,31 @@ +import torch +import triton +from triton import language as tl +from triton import cdiv + +BLOCK_M : tl.constexpr = 128 +BLOCK_N : tl.constexpr = 128 + +@triton.jit +def test_kernel(desc): + off_n = tl.program_id(0) * BLOCK_N + off_m = tl.program_id(1) * BLOCK_M + tile = tl._experimental_descriptor_load(desc, [off_m, off_n], [BLOCK_M, BLOCK_N], tl.float32) + tile += 1.0 + tl._experimental_descriptor_store(desc, tile, [off_m, off_n]) + +M = 256 +N = 512 +tensor = torch.zeros((M, N), device='cuda', dtype=torch.float32) +cpu_desc = torch.empty(128, device="cpu") +triton.runtime.driver.active.utils.fill_2d_tma_descriptor( + tensor.data_ptr(), + M, N, BLOCK_M, BLOCK_N, + tensor.element_size(), cpu_desc.data_ptr() +) + +val = torch.clone(tensor) + 1.0 +test_kernel[(cdiv(N, BLOCK_N), cdiv(M, BLOCK_M))](tl.NvTmaDesc(cpu_desc), num_warps=1) +assert torch.allclose(val, tensor) + +print("byval tma desc passed!") diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 90f71138bcd9..a6020f4e993e 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -110,6 +110,7 @@ def ty_to_cpp(ty): "fp32": "float", "f32": "float", "fp64": "double", + "nvTmaDesc": "CUtensorMap", }[ty] @@ -121,6 +122,9 @@ def make_launcher(constants, signature, ids): def _extracted_type(ty): if ty[0] == '*': return "PyObject*" + if ty == "nvTmaDesc": + return "PyObject*" + return ty_to_cpp(ty) def format_of(ty): @@ -143,6 +147,16 @@ def format_of(ty): format = "iiiKKOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty == "nvTmaDesc": + # Note: we have to dereference the pointer + internal_args_list.append(f"*tma_ptr{i}") + else: + internal_args_list.append(f"_arg{i}") + # generate glue code params = [i for i in signature.keys() if i not in constants] src = f""" @@ -271,6 +285,45 @@ def format_of(ty): return ptr_info; }} +static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ + if (sizeof(CUtensorMap*) != 8) {{ + PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); + return NULL; + }} + + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_ptr"); + if (!method_handle) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_ptr() method does not exist"); + return NULL; + }} + + PyObject *empty_tuple = PyTuple_New(0); + if (!empty_tuple) goto python_internal_error; + PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(method_handle); + if (!method_ret) goto python_internal_error; + + if (!PyLong_Check(method_ret)) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_ptr() must return 64-bit int"); + Py_DECREF(method_ret); + return NULL; + }} + + uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + if (ptr_as_uint % 64 != 0) {{ + PyErr_SetString(PyExc_ValueError, "tma_desc_ptr() must be 64-byte aligned"); + Py_DECREF(method_ret); + return NULL; + }} + + return (CUtensorMap*)(ptr_as_uint); + +python_internal_error: + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; +}} + static PyObject* launch(PyObject* self, PyObject* args) {{ int gridX, gridY, gridZ; uint64_t _stream; @@ -302,9 +355,10 @@ def format_of(ty): }} // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL; From f14b9cb8caab9d792f474fbbd2833ddd9bbd46ef Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Sat, 3 Aug 2024 15:11:40 -0700 Subject: [PATCH 02/14] nits for driver code --- third_party/nvidia/backend/driver.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index a6020f4e993e..b62673f0cacf 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -311,9 +311,13 @@ def format_of(ty): }} uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + Py_DECREF(method_ret); + if (!ptr_as_uint) {{ + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_ptr()"); + return NULL; + }} if (ptr_as_uint % 64 != 0) {{ PyErr_SetString(PyExc_ValueError, "tma_desc_ptr() must be 64-byte aligned"); - Py_DECREF(method_ret); return NULL; }} From 1e010f77e7501b656696e0cfcc1284ff57523152 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Mon, 5 Aug 2024 09:03:11 -0700 Subject: [PATCH 03/14] TmaDescKernelParam class --- python/setup.py | 2 + python/triton/language/__init__.py | 2 - python/triton/language/core.py | 8 ---- python/triton/runtime/jit.py | 3 +- .../triton/tools/experimental_descriptor.py | 39 ++++++++++--------- test.py | 9 ++--- 6 files changed, 28 insertions(+), 35 deletions(-) diff --git a/python/setup.py b/python/setup.py index 9abf85285cb8..5158840ac9d1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -95,8 +95,10 @@ def check_env_flag(name: str, default: str = "") -> bool: def get_build_type(): if check_env_flag("DEBUG"): + print("DEBUG BUILD!") return "Debug" elif check_env_flag("REL_WITH_DEB_INFO"): + print("REL_WITH_DEB_INFO!") return "RelWithDebInfo" elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"): return "TritonRelBuildWithAsserts" diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 79e8d81029f1..ffd2a0b8a150 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -103,7 +103,6 @@ view, void, where, - NvTmaDesc, ) from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, ceil) @@ -250,7 +249,6 @@ "xor_sum", "zeros", "zeros_like", - "NvTmaDesc", ] diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8e3aee8a228f..b8811c9e26c2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2665,11 +2665,3 @@ def binary_op_type_legalization(lhs, rhs, builder): def extern(fn): """A decorator for external functions.""" return builtin(fn) - -class NvTmaDesc: - def __init__(self, data): - self.data = data - self.dtype = uint8 - - def tma_desc_ptr(self): - return self.data.data_ptr() diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 9a79b47a67d8..e41191599dcd 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -292,6 +292,7 @@ def compute_spec_key(v, align): def mangle_type(arg, is_const=False): + if arg is None: return "none" elif isinstance(arg, bool): @@ -305,7 +306,7 @@ def mangle_type(arg, is_const=False): return "i64" elif isinstance(arg, float): return "fp32" - elif "NvTmaDesc" in type(arg).__name__: + elif hasattr(arg, "tma_desc_ptr"): return "nvTmaDesc" else: # dtypes are hashable so we can memoize this mapping: diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index c1265ba04bbd..dadbc06ab822 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -2,27 +2,28 @@ import triton +class TmaDescKernelParam: + TMA_DESC_SIZE = 128 + + def __init__(self, ptr, dims, block_dims, element_size): + self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.int8, device="cpu") + assert len(dims) == len(block_dims) + assert 1 <= len(dims) <= 2 + assert self.desc.data_ptr() % 64 == 0 + + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, self.desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], element_size, + self.desc.data_ptr()) + + def tma_desc_ptr(self): + return self.desc.data_ptr() + -# Constructs a 1D TMA descriptor in mutable GPU memory. -# -# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's -# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): - TMA_SIZE = 128 - desc = torch.empty(TMA_SIZE, dtype=torch.int8) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc.data_ptr()) - gpu_desc = desc.cuda() - return gpu_desc + return TmaDescKernelParam(ptr, [dim], [block_dim], element_size) -# Constructs a 2D TMA descriptor in mutable GPU memory. -# -# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's -# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): - TMA_SIZE = 128 - desc = torch.empty(TMA_SIZE, dtype=torch.int8) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, - desc.data_ptr()) - gpu_desc = desc.cuda() - return gpu_desc + return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size) diff --git a/test.py b/test.py index ff1b20d27726..b986b42f003e 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,7 @@ import torch import triton from triton import language as tl +from triton.tools.experimental_descriptor import create_2d_tma_descriptor from triton import cdiv BLOCK_M : tl.constexpr = 128 @@ -17,15 +18,13 @@ def test_kernel(desc): M = 256 N = 512 tensor = torch.zeros((M, N), device='cuda', dtype=torch.float32) -cpu_desc = torch.empty(128, device="cpu") -triton.runtime.driver.active.utils.fill_2d_tma_descriptor( +tma_desc = create_2d_tma_descriptor( tensor.data_ptr(), M, N, BLOCK_M, BLOCK_N, - tensor.element_size(), cpu_desc.data_ptr() -) + tensor.element_size()) val = torch.clone(tensor) + 1.0 -test_kernel[(cdiv(N, BLOCK_N), cdiv(M, BLOCK_M))](tl.NvTmaDesc(cpu_desc), num_warps=1) +test_kernel[(cdiv(N, BLOCK_N), cdiv(M, BLOCK_M))](tma_desc, num_warps=1) assert torch.allclose(val, tensor) print("byval tma desc passed!") From 1ed44a3f808b9f46ee19745c031b78fe09b8428a Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Tue, 6 Aug 2024 14:13:08 -0700 Subject: [PATCH 04/14] refactor FuncOpConversion --- .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 42 ++++++++++++------- python/setup.py | 2 - python/triton/language/core.py | 1 + 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 2114c2daf9b1..f5e8e6e57018 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -67,6 +67,29 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { return amendedFuncOp; } + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM attributes. + static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = LLVM::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + for (const auto& attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && "tt.nv_tma_desc is not supported for device functions"); + + // See https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get(llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, "llvm.align", mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + LogicalResult matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -87,21 +110,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { auto ctx = funcOp->getContext(); if (LLVM::isKernel(funcOp)) { - for (unsigned i = 0; i < newFuncOp.getNumArguments(); ++i) { - const auto attrs = newFuncOp.getArgAttrDict(i); - for (const auto& attr : attrs) { - if (attr.getName() == "tt.nv_tma_desc") { - mlir::BlockArgument arg = newFuncOp.getArgument(i); - const auto byteType = mlir::IntegerType::get(newFuncOp.getContext(), 8); - const auto arrayType = mlir::LLVM::LLVMArrayType::get(newFuncOp.getContext(), byteType, 128); - newFuncOp.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arrayType)); - newFuncOp.setArgAttr(i, "nvvm.grid_constant", mlir::UnitAttr::get(newFuncOp.getContext())); - newFuncOp.setArgAttr(i, "llvm.align", mlir::IntegerAttr::get( - mlir::IntegerType::get(newFuncOp.getContext(), 32), 64)); - } - } - } - // Set an attribute to indicate this function is a kernel entry. newFuncOp->setAttr("nvvm.kernel", rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); @@ -120,6 +128,10 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { newFuncOp->setAttr("nvvm.reqntid", rewriter.getDenseI32ArrayAttr(32 * numWarps)); rewriter.eraseOp(funcOp); + + // Add attributes for by-value TMA descriptor args (nvidia) + handleByvalTmaDescArgs(newFuncOp); + return success(); } diff --git a/python/setup.py b/python/setup.py index 5158840ac9d1..9abf85285cb8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -95,10 +95,8 @@ def check_env_flag(name: str, default: str = "") -> bool: def get_build_type(): if check_env_flag("DEBUG"): - print("DEBUG BUILD!") return "Debug" elif check_env_flag("REL_WITH_DEB_INFO"): - print("REL_WITH_DEB_INFO!") return "RelWithDebInfo" elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"): return "TritonRelBuildWithAsserts" diff --git a/python/triton/language/core.py b/python/triton/language/core.py index b8811c9e26c2..9a7ff5b6522e 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -597,6 +597,7 @@ def scalar(self): class nv_tma_desc_type(pointer_type): def __init__(self): super().__init__(uint8, const = True, address_space = 0) + self.name = 'nv_tma_desc_type' class block_type(dtype): From 98d7191a765a31f048acf80bf30dd2b99d85289b Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Tue, 6 Aug 2024 16:21:10 -0700 Subject: [PATCH 05/14] bugfix for null argAttrDict --- lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index f5e8e6e57018..09736426dc54 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -72,6 +72,8 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { const bool isKernel = LLVM::isKernel(llvmFuncOp); for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { continue; } + for (const auto& attr : attrs) { if (attr.getName() == "tt.nv_tma_desc") { const auto i32_type = mlir::IntegerType::get(llvmFuncOp.getContext(), 32); From 4a8e66b7cf7ce3b77052822529dd9cb6bb8de460 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Wed, 7 Aug 2024 16:11:58 -0700 Subject: [PATCH 06/14] add lit test --- python/triton/compiler/code_generator.py | 2 +- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 959035ccda80..3197bab3f4fd 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -410,7 +410,7 @@ def visit_FunctionDef(self, node): for name, value in self.attributes[i]: self.fn.set_arg_attr(idx, name, value) - # TMA + # Mark this argument as a pass-by-value TMA descriptor (nvidia) if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index c5dc86abd851..511f72bcd907 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -79,3 +79,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: byval_tma_desc + // CHECK: llvm.align = 64 + // CHECK: llvm.byval = !llvm.array<128 x i8> + // CHECK: nvvm.grid_constant + tt.func @byval_tma_desc(%desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) { + tt.return + } +} From e3d403216afd5a8705ae99fe725715e673d70b97 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Thu, 8 Aug 2024 11:45:36 -0700 Subject: [PATCH 07/14] format --- .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 33 ++++++++++++------- python/triton/compiler/code_generator.py | 2 +- python/triton/language/__init__.py | 2 +- python/triton/language/core.py | 6 ++-- .../triton/tools/experimental_descriptor.py | 8 +++-- test.py | 13 ++++---- third_party/nvidia/backend/driver.py | 18 +++++----- 7 files changed, 48 insertions(+), 34 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 09736426dc54..ae3aa63b851d 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -67,26 +67,37 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { return amendedFuncOp; } - // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM attributes. + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { const bool isKernel = LLVM::isKernel(llvmFuncOp); for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { const auto attrs = llvmFuncOp.getArgAttrDict(i); - if (!attrs) { continue; } + if (!attrs) { + continue; + } - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (attr.getName() == "tt.nv_tma_desc") { - const auto i32_type = mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); - assert(isKernel && "tt.nv_tma_desc is not supported for device functions"); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); - // See https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc mlir::BlockArgument arg = llvmFuncOp.getArgument(i); - const auto byteType = mlir::IntegerType::get(llvmFuncOp.getContext(), 8); - const auto arrayType = mlir::LLVM::LLVMArrayType::get(llvmFuncOp.getContext(), byteType, 128); - llvmFuncOp.setArgAttr(i, "llvm.byval", mlir::TypeAttr::get(arrayType)); - llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", mlir::UnitAttr::get(llvmFuncOp.getContext())); - llvmFuncOp.setArgAttr(i, "llvm.align", mlir::IntegerAttr::get(i32_type, 64)); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, "llvm.byval", + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, "llvm.align", + mlir::IntegerAttr::get(i32_type, 64)); } } } diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 3197bab3f4fd..96b7346ac554 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -409,7 +409,7 @@ def visit_FunctionDef(self, node): if i in self.attributes: for name, value in self.attributes[i]: self.fn.set_arg_attr(idx, name, value) - + # Mark this argument as a pass-by-value TMA descriptor (nvidia) if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index ffd2a0b8a150..0a84bd86a5a1 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -261,7 +261,7 @@ def str_to_ty(name): const = True ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) - + if name == "nvTmaDesc": return nv_tma_desc_type() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 9a7ff5b6522e..cf86e9296a6a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -593,10 +593,12 @@ def __ne__(self, other: pointer_type) -> bool: @property def scalar(self): return self - + + class nv_tma_desc_type(pointer_type): + def __init__(self): - super().__init__(uint8, const = True, address_space = 0) + super().__init__(uint8, const=True, address_space=0) self.name = 'nv_tma_desc_type' diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index dadbc06ab822..7742b30f72a0 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -2,6 +2,7 @@ import triton + class TmaDescKernelParam: TMA_DESC_SIZE = 128 @@ -12,10 +13,11 @@ def __init__(self, ptr, dims, block_dims, element_size): assert self.desc.data_ptr() % 64 == 0 if len(dims) == 1: - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, self.desc.data_ptr()) + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + self.desc.data_ptr()) else: - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], element_size, - self.desc.data_ptr()) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], + block_dims[1], element_size, self.desc.data_ptr()) def tma_desc_ptr(self): return self.desc.data_ptr() diff --git a/test.py b/test.py index b986b42f003e..e8944efd6626 100644 --- a/test.py +++ b/test.py @@ -4,8 +4,9 @@ from triton.tools.experimental_descriptor import create_2d_tma_descriptor from triton import cdiv -BLOCK_M : tl.constexpr = 128 -BLOCK_N : tl.constexpr = 128 +BLOCK_M: tl.constexpr = 128 +BLOCK_N: tl.constexpr = 128 + @triton.jit def test_kernel(desc): @@ -15,13 +16,11 @@ def test_kernel(desc): tile += 1.0 tl._experimental_descriptor_store(desc, tile, [off_m, off_n]) + M = 256 N = 512 -tensor = torch.zeros((M, N), device='cuda', dtype=torch.float32) -tma_desc = create_2d_tma_descriptor( - tensor.data_ptr(), - M, N, BLOCK_M, BLOCK_N, - tensor.element_size()) +tensor = torch.zeros((M, N), device='cuda', dtype=torch.float32) +tma_desc = create_2d_tma_descriptor(tensor.data_ptr(), M, N, BLOCK_M, BLOCK_N, tensor.element_size()) val = torch.clone(tensor) + 1.0 test_kernel[(cdiv(N, BLOCK_N), cdiv(M, BLOCK_M))](tma_desc, num_warps=1) diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index b62673f0cacf..820bf68f2dcb 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -149,13 +149,13 @@ def format_of(ty): internal_args_list = [] for i, ty in signature.items(): - if ty[0] == "*": - internal_args_list.append(f"ptr_info{i}.dev_ptr") - elif ty == "nvTmaDesc": - # Note: we have to dereference the pointer - internal_args_list.append(f"*tma_ptr{i}") - else: - internal_args_list.append(f"_arg{i}") + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty == "nvTmaDesc": + # Note: we have to dereference the pointer + internal_args_list.append(f"*tma_ptr{i}") + else: + internal_args_list.append(f"_arg{i}") # generate glue code params = [i for i in signature.keys() if i not in constants] @@ -290,7 +290,7 @@ def format_of(ty): PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); return NULL; }} - + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_ptr"); if (!method_handle) {{ PyErr_SetString(PyExc_TypeError, "tma_desc_ptr() method does not exist"); @@ -322,7 +322,7 @@ def format_of(ty): }} return (CUtensorMap*)(ptr_as_uint); - + python_internal_error: PyErr_SetString(PyExc_SystemError, "Internal Python error!"); return NULL; From 20d63a8e026116d11928598901c33cdad540307e Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 9 Aug 2024 14:22:35 -0700 Subject: [PATCH 08/14] update unit tests for byval tma --- .../test/unit/hopper/test_experimental_tma.py | 78 ++++++------------- python/tutorials/09-persistent-matmul.py | 8 -- 2 files changed, 23 insertions(+), 63 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index c0228fb54a69..df3ba76eee7f 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -7,50 +7,17 @@ from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor -def test_descriptor_load_ttgir(): - if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: - pytest.skip("Test requires Hopper target.") - return - device = "cuda" - SIZE = 128 - - x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) - size_in_bytes = SIZE * x.element_size() - - ir = f""" - #blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}> - #shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}> - module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ - %c0_i32 = arith.constant 0 : i32 - %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - %true = arith.constant 1 : i1 - triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> - %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> - tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - z_tri = torch.empty_like(x) - kernel[(1, 1, 1)](z_tri, desc) - assert torch.equal(x, z_tri) +def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): + cpu_desc = torch.empty(128, device="cpu") + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, cpu_desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], element_size, cpu_desc.data_ptr()) + return cpu_desc.cuda() -def test_experimetal_descriptor_load(): +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimetal_descriptor_load(byval_tma): if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: pytest.skip("Test requires Hopper target.") return @@ -65,7 +32,10 @@ def kernel(Z, desc, SIZE: tl.constexpr): tl.store(Z + off, x) x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + if byval_tma: + desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + else: + desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) z_tri = torch.empty_like(x) kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) assert torch.equal(x, z_tri) @@ -74,14 +44,6 @@ def kernel(Z, desc, SIZE: tl.constexpr): @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): - # TODO(embg) remove TMA fence after __grid_constant__ lands - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) pid_m = pid % num_pid_m @@ -101,7 +63,8 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # @pytest.mark.parametrize("num_stages", [1, 4]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) -def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: pytest.skip("Test requires Hopper target.") return @@ -111,9 +74,14 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): A = torch.randn((M, K), dtype=torch.float16, device=device) B = torch.randn((K, N), dtype=torch.float16, device=device) C = torch.empty((M, N), dtype=torch.float16, device=device) - desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) - desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) - desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + if byval_tma: + desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) + desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) + desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + else: + desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) + desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) + desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8, num_stages=num_stages) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 460c374d7fb2..fdbdbfecfb86 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -259,14 +259,6 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # GROUP_SIZE_M: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # NUM_SMS: tl.constexpr): # - # TODO(embg) remove TMA fence after __grid_constant__ lands - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) From 8a3b707e43fae424ab76678a89538c387f267e6e Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 9 Aug 2024 14:23:27 -0700 Subject: [PATCH 09/14] format --- python/test/unit/hopper/test_experimental_tma.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index df3ba76eee7f..f6885bd64ad5 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -1,6 +1,5 @@ import pytest import torch -import tempfile import triton import triton.language as tl @@ -10,9 +9,11 @@ def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): cpu_desc = torch.empty(128, device="cpu") if len(dims) == 1: - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, cpu_desc.data_ptr()) + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + cpu_desc.data_ptr()) else: - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], element_size, cpu_desc.data_ptr()) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], + element_size, cpu_desc.data_ptr()) return cpu_desc.cuda() From b5b420f71611c5e5057c5aa54cdc6deda915e93c Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 9 Aug 2024 14:44:22 -0700 Subject: [PATCH 10/14] check PTX in unit tests --- python/test/unit/hopper/test_experimental_tma.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index f6885bd64ad5..4140a36a22f2 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -38,8 +38,10 @@ def kernel(Z, desc, SIZE: tl.constexpr): else: desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) z_tri = torch.empty_like(x) - kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) + compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) assert torch.equal(x, z_tri) + if byval_tma: + assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"] @triton.jit @@ -90,3 +92,5 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) if BLOCK_M >= 64 and BLOCK_N >= 64: assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + if byval_tma: + assert ".param .align 64 .b8" in kernel.asm["ptx"] From cd03b7a7cfc843bd1b49e5427aae44a0116da596 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 9 Aug 2024 14:54:11 -0700 Subject: [PATCH 11/14] remove local test script --- test.py | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index e8944efd6626..000000000000 --- a/test.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import triton -from triton import language as tl -from triton.tools.experimental_descriptor import create_2d_tma_descriptor -from triton import cdiv - -BLOCK_M: tl.constexpr = 128 -BLOCK_N: tl.constexpr = 128 - - -@triton.jit -def test_kernel(desc): - off_n = tl.program_id(0) * BLOCK_N - off_m = tl.program_id(1) * BLOCK_M - tile = tl._experimental_descriptor_load(desc, [off_m, off_n], [BLOCK_M, BLOCK_N], tl.float32) - tile += 1.0 - tl._experimental_descriptor_store(desc, tile, [off_m, off_n]) - - -M = 256 -N = 512 -tensor = torch.zeros((M, N), device='cuda', dtype=torch.float32) -tma_desc = create_2d_tma_descriptor(tensor.data_ptr(), M, N, BLOCK_M, BLOCK_N, tensor.element_size()) - -val = torch.clone(tensor) + 1.0 -test_kernel[(cdiv(N, BLOCK_N), cdiv(M, BLOCK_M))](tma_desc, num_warps=1) -assert torch.allclose(val, tensor) - -print("byval tma desc passed!") From 35dde67da1b4564385866d278369d69012156911 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Fri, 9 Aug 2024 16:17:07 -0700 Subject: [PATCH 12/14] small bugfix --- third_party/nvidia/backend/driver.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 820bf68f2dcb..9dd9e919639e 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -298,7 +298,10 @@ def format_of(ty): }} PyObject *empty_tuple = PyTuple_New(0); - if (!empty_tuple) goto python_internal_error; + if (!empty_tuple) {{ + Py_DECREF(method_handle); + goto python_internal_error; + }} PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); Py_DECREF(empty_tuple); Py_DECREF(method_handle); From 84d0173836df387fbd8bc59fe99ef2900f455929 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Tue, 13 Aug 2024 09:49:03 -0700 Subject: [PATCH 13/14] fence when byval_tma is false --- .../test/unit/hopper/test_experimental_tma.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 4140a36a22f2..6c687d0891d0 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -17,6 +17,9 @@ def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): return cpu_desc.cuda() +TMA_FENCE_ASM: tl.constexpr = "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg" + + @pytest.mark.parametrize("byval_tma", [True, False]) def test_experimetal_descriptor_load(byval_tma): if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: @@ -26,7 +29,9 @@ def test_experimetal_descriptor_load(byval_tma): SIZE = 128 @triton.jit - def kernel(Z, desc, SIZE: tl.constexpr): + def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr): + if not BYVAL_TMA: + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [desc], dtype=tl.int32, is_pure=False, pack=1) off_desc = 0 off = tl.arange(0, SIZE) x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty) @@ -38,7 +43,7 @@ def kernel(Z, desc, SIZE: tl.constexpr): else: desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) z_tri = torch.empty_like(x) - compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) + compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4) assert torch.equal(x, z_tri) if byval_tma: assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"] @@ -46,7 +51,13 @@ def kernel(Z, desc, SIZE: tl.constexpr): @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # - M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BYVAL_TMA: tl.constexpr): + if not BYVAL_TMA: + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) pid_m = pid % num_pid_m @@ -86,8 +97,8 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, - 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8, - num_stages=num_stages) + 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, + num_warps=8, num_stages=num_stages) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) if BLOCK_M >= 64 and BLOCK_N >= 64: From 24079aec3a2466787d3422781960eadc49895ea9 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Thu, 15 Aug 2024 12:19:30 -0700 Subject: [PATCH 14/14] nits --- python/triton/runtime/jit.py | 2 +- .../triton/tools/experimental_descriptor.py | 3 ++- third_party/nvidia/backend/driver.py | 22 +++++++++---------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index e41191599dcd..821abd150138 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -306,7 +306,7 @@ def mangle_type(arg, is_const=False): return "i64" elif isinstance(arg, float): return "fp32" - elif hasattr(arg, "tma_desc_ptr"): + elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" else: # dtypes are hashable so we can memoize this mapping: diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index 7742b30f72a0..fba3366c0ca6 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -19,7 +19,8 @@ def __init__(self, ptr, dims, block_dims, element_size): triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], element_size, self.desc.data_ptr()) - def tma_desc_ptr(self): + # Return a CUtensorMap* pointer in host memory + def tma_desc_cpu_ptr(self): return self.desc.data_ptr() diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 9dd9e919639e..bf1f066d5537 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -291,24 +291,28 @@ def format_of(ty): return NULL; }} - PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_ptr"); + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); if (!method_handle) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_ptr() method does not exist"); + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); return NULL; }} PyObject *empty_tuple = PyTuple_New(0); if (!empty_tuple) {{ Py_DECREF(method_handle); - goto python_internal_error; + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; }} PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); Py_DECREF(empty_tuple); Py_DECREF(method_handle); - if (!method_ret) goto python_internal_error; + if (!method_ret) {{ + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} if (!PyLong_Check(method_ret)) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_ptr() must return 64-bit int"); + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); Py_DECREF(method_ret); return NULL; }} @@ -316,19 +320,15 @@ def format_of(ty): uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); Py_DECREF(method_ret); if (!ptr_as_uint) {{ - PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_ptr()"); + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); return NULL; }} if (ptr_as_uint % 64 != 0) {{ - PyErr_SetString(PyExc_ValueError, "tma_desc_ptr() must be 64-byte aligned"); + PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); return NULL; }} return (CUtensorMap*)(ptr_as_uint); - -python_internal_error: - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; }} static PyObject* launch(PyObject* self, PyObject* args) {{