[nvidia] Support passing TMA descriptors by-value #4498
@@ -42,7 +42,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
     custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH'))
     include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
-    cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+    # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
+    cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]
what is causing the extra warning?
GCC doesn't like the CUtensorMap struct. This is called out in the CUDA C++ Programming Guide as a false warning: "When passing the tensor map as a parameter, some versions of the GCC C++ compiler issue the warning 'the ABI for passing parameters with 64-byte alignment has changed in GCC 4.6'. This warning can be ignored."
I don't think it can be suppressed inline via pragma, it has to be suppressed on the command line: https://godbolt.org/z/f5n5crhjG
Thanks for the good work. LGTM in general. Left a couple of minor comments.
llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
                      mlir::UnitAttr::get(llvmFuncOp.getContext()));
llvmFuncOp.setArgAttr(i, "llvm.align",
                      mlir::IntegerAttr::get(i32_type, 64));
Is 64 a required alignment value?
Yes. Here is the definition of CUtensorMap in <cuda.h>:

typedef struct CUtensorMap_st {
    alignas(64) unsigned long long opaque[16];
} CUtensorMap;
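For illustration, here is a minimal Python sketch (standard library only) that mirrors this layout with ctypes and confirms the 128-byte size. Note that plain ctypes cannot portably express the alignas(64) requirement, so only the size is checked; the struct name is reused from the CUDA header for clarity.

```python
import ctypes

class CUtensorMap(ctypes.Structure):
    # Mirrors the <cuda.h> layout above: 16 x unsigned long long = 128 bytes.
    # ctypes cannot portably express the alignas(64) requirement, so this
    # sketch only verifies the size, not the 64-byte alignment.
    _fields_ = [("opaque", ctypes.c_ulonglong * 16)]

print(ctypes.sizeof(CUtensorMap))  # 128
```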
LGTM once the other comments are addressed
Summary: This PR follows [a recent PR in Triton](triton-lang/triton#4498) that supports passing TMA descriptors by-value using `__grid_constant__`. In this PR, we update the kernel `_attn_fwd_inner` to support the above new feature in Triton. To support auto-tune, we implement a helper class that wraps the TMA operations used during auto-tune and the computations in the kernel, respectively. In addition, the benchmark program now also checks whether the Triton version supports this new feature; if it doesn't, the helper class falls back to the old way of handling TMA. The change has been tested on the Triton from the standard installation of PyTorch on conda, as well as recent Triton including the above PR.

Command for testing and experiment results:
Before removing fences: P1541573348
After removing fences: P1541736645
1) CUDA_VISIBLE_DEVICES=5, old tma: 138.476
2) CUDA_VISIBLE_DEVICES=5, new tma, with fences: 152 - 164
3) CUDA_VISIBLE_DEVICES=5, new tma, after removing fences: 168.0
4) CUDA_VISIBLE_DEVICES=5, no tma: 187.881

The result is still behind no TMA, and we can investigate further.

Pull Request resolved: #2428
Reviewed By: embg
Differential Revision: D61668142
Pulled By: sfzhu93
fbshipit-source-id: d08bab147c6b2197f73447ee8f30ede877e712ca
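The fallback pattern described in the summary can be sketched as follows. This is a hypothetical illustration, not the benchmark's actual helper: the class name, `use_new_tma`, and the factory arguments are all made up; only the idea (pick the by-value TMA path when the installed Triton supports it, otherwise fall back to the old by-ref path) comes from the text above.

```python
class TmaAutoTuneHelper:
    """Hypothetical sketch: selects the by-value TMA path when the installed
    Triton supports it, and falls back to the old by-ref path otherwise."""

    def __init__(self, use_new_tma: bool):
        # In the real benchmark this flag would come from a Triton
        # feature/version check; here it is supplied by the caller.
        self.use_new_tma = use_new_tma

    def descriptor(self, make_new, make_old):
        # make_new / make_old are caller-supplied factories for the two paths.
        return make_new() if self.use_new_tma else make_old()
```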
Motivation
Currently, Triton passes TMA descriptors by-ref through global memory. This has a number of problems.

There are two possible solutions: create the descriptor on-device, or create it on the host and pass it by-value.

Because of the tricky memory model for TMA descriptors on H100, creating a descriptor on-device requires moving data back and forth from L2 cache. This is relatively expensive (100s of cycles at least) and requires the user or compiler to correctly insert release/acquire fences.

In some cases, there is no way to avoid creating the descriptor on-device. But for many use-cases, it's perfectly fine to set up the descriptor on the host and pass it by-value, avoiding both performance and correctness issues. This PR implements the by-value functionality.
User-level API
Whenever the user provides a kernel param which implements the method tma_desc_cpu_ptr(), Triton will lower that argument to a __grid_constant__ by-value param. The existing helper methods create_[1d/2d]_tma_descriptor were modified to return such a type, so existing code does not need any changes to take advantage of the new feature.

Implementation details
When a kernel param with tma_desc_cpu_ptr() is detected, we attach an attribute to that param at the TTIR level. The attribute is passed through to TTGIR. When lowering TTGIR to LLIR, we use code ported from Mosaic (jax-ml/jax#22175) to set up the correct LLVM attributes. The runtime is also modified to pass by-value TMA descriptors properly.

Limitations
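A hedged end-to-end sketch of the runtime side, using only the standard library: an argument exposing tma_desc_cpu_ptr() is detected by duck typing, and the 128-byte CUtensorMap it points to is copied out so it can be placed by value in the kernel's parameter buffer. Only the method name tma_desc_cpu_ptr() comes from this PR; every other name is illustrative (Triton's actual launcher is generated C code).

```python
import ctypes

class HostTmaDescriptor:
    # Hypothetical wrapper around a host-side CUtensorMap allocation.
    def __init__(self, cpu_ptr: int):
        self._cpu_ptr = cpu_ptr

    def tma_desc_cpu_ptr(self) -> int:
        # Presence of this method is what opts the argument into
        # __grid_constant__ by-value lowering.
        return self._cpu_ptr

def wants_by_value_tma(arg) -> bool:
    # Duck-typed detection of the protocol described above.
    return callable(getattr(arg, "tma_desc_cpu_ptr", None))

def pack_tma_arg(arg) -> bytes:
    # Copy the 128-byte descriptor out of host memory so it can be passed
    # by value (sketch only; the real runtime does this in generated C).
    return ctypes.string_at(arg.tma_desc_cpu_ptr(), 128)
```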
This feature is currently broken when compiling an IRSource directly (which is useful for editing IR and re-compiling). That would require updating some regexes which infer the function signature from the IR. IRSource compilation still works fine for kernels which do not use the new feature.

Once the approach I'm taking here is reviewed, I plan to fix that limitation, either in this PR or in a follow-up PR.