[Mosaic GPU] Pass in TMA descriptors through kernel parameters #22175

Merged

merged 1 commit into main from test_647649107 on Jul 2, 2024

Conversation


@copybara-service copybara-service bot commented Jun 28, 2024

[Mosaic GPU] Pass in TMA descriptors through kernel parameters

As we've established (sigh) we can't pass in TMA descriptors through global memory.
The current workaround was to use constant memory instead, but this raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overheads than both previous methods.
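As an illustration of the mechanism (not the actual Mosaic GPU lowering), here is a minimal CUDA sketch of a kernel that receives a TMA descriptor by value as a grid_constant parameter; the kernel name and signature are assumptions:

```cuda
#include <cuda.h>  // CUtensorMap is defined by the CUDA driver API

// Requires sm_90 and a recent CUDA toolkit. __grid_constant__ keeps the
// descriptor in the kernel parameter space, so it never round-trips through
// mutable global memory and needs no descriptor-cache fence on the device.
__global__ void copy_tile_kernel(const __grid_constant__ CUtensorMap desc,
                                 float *out) {
  // A real kernel would issue TMA (cp.async.bulk.tensor) copies that take
  // the address of `desc`; the body is omitted to keep the sketch minimal.
}
```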

@copybara-service copybara-service bot force-pushed the test_647649107 branch 3 times, most recently from 45015cd to 2eb5a48 on July 2, 2024 16:06

PiperOrigin-RevId: 648744363
@copybara-service copybara-service bot merged commit 265a54d into main Jul 2, 2024
@copybara-service copybara-service bot deleted the test_647649107 branch July 2, 2024 16:30
embg added a commit to triton-lang/triton that referenced this pull request Aug 19, 2024
## Motivation
Currently, Triton passes TMA descriptors by-ref through global memory.
This has a number of problems:
* Significant launch overhead (5-10 µs) for the host-to-device memcpy
* Users must insert fences for the TMA descriptor cache flush (see #4342). When users don't insert these fences correctly, they run into very strange bugs: #4332
* The memcpy makes it nearly impossible to use CUDA graphs

There are two possible solutions:
* [Pass the descriptor by-value](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#using-tma-to-transfer-multi-dimensional-arrays)
* [Create the descriptor on-device](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device)

Because of the tricky memory model for TMA descriptors on H100, creating a
descriptor on-device requires moving data back and forth through the L2
cache. This is relatively expensive (hundreds of cycles at least) and
requires the user or compiler to correctly insert release/acquire fences.

In some cases, there is no way to avoid creating the descriptor
on-device. But for many use-cases, it's perfectly fine to set up the
descriptor on the host and pass by-value, avoiding both performance and
correctness issues. This PR implements the by-value functionality.
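To make the by-value path concrete, here is a hedged host-side sketch using the public CUDA driver API (`cuTensorMapEncodeTiled`); the tensor shape, tile size, and helper name are illustrative assumptions rather than Triton's actual runtime code:

```cuda
#include <cuda.h>
#include <cstdio>

// Encode a 2D row-major float32 tensor map on the host. The returned
// CUtensorMap can then be passed by value as a __grid_constant__ kernel
// argument, so no host-to-device memcpy of the descriptor (and no fence)
// is needed at launch time.
// Assumes: cuInit() has been called, global_ptr is a 16-byte-aligned device
// pointer, and cols % 4 == 0 so the row stride is a multiple of 16 bytes.
static CUtensorMap make_2d_desc(void *global_ptr, cuuint64_t rows,
                                cuuint64_t cols) {
  CUtensorMap desc;
  cuuint64_t dims[2]         = {cols, rows};            // innermost dim first
  cuuint64_t strides[1]      = {cols * sizeof(float)};  // row stride in bytes
  cuuint32_t box[2]          = {64, 64};                // tile ("box") shape
  cuuint32_t elem_strides[2] = {1, 1};
  CUresult res = cuTensorMapEncodeTiled(
      &desc, CU_TENSOR_MAP_DATA_TYPE_FLOAT32, /*tensorRank=*/2, global_ptr,
      dims, strides, box, elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  if (res != CUDA_SUCCESS)
    std::fprintf(stderr, "tensor map encode failed: %d\n", (int)res);
  return desc;
}
```

The resulting descriptor is then placed directly in the kernel's argument list at launch, like any other by-value parameter.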

## User-level API
Whenever the user provides a kernel param which implements the method
`tma_desc_cpu_ptr()`, Triton will lower that argument to a
`__grid_constant__` by-value param. The existing helper methods
`create_[1d/2d]_tma_descriptor` were modified to return such a type, so
existing code does not need any changes to take advantage of the new
feature.

## Implementation details
When a kernel param with `tma_desc_cpu_ptr()` is detected, we attach an
attribute to that param at the TTIR level. The attribute is passed
through to TTGIR. When lowering TTGIR to LLIR, we use code ported from
Mosaic (jax-ml/jax#22175) to set up the correct
LLVM attributes. The runtime is also modified to pass by-value TMA
descriptors properly.
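For a sense of what passing a by-value TMA descriptor means at the driver level, a hedged CUDA sketch of the launch path is below; the kernel handle, launch geometry, and parameter list are placeholders, not the modified Triton runtime itself:

```cuda
#include <cuda.h>

// With a __grid_constant__ CUtensorMap parameter, the launch places the
// 128-byte descriptor itself into the kernel parameter buffer, rather than
// a device pointer to a copy living in global or constant memory.
CUresult launch(CUfunction kernel, CUtensorMap desc, CUdeviceptr out,
                CUstream stream) {
  // cuLaunchKernel takes pointers to the host-side values; the driver copies
  // each parameter by value into the kernel's parameter space.
  void *params[] = {&desc, &out};
  return cuLaunchKernel(kernel, /*grid=*/1, 1, 1, /*block=*/128, 1, 1,
                        /*sharedMemBytes=*/0, stream, params, nullptr);
}
```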

## Limitations
This feature is currently broken when compiling an `IRSource` directly
(which is useful for editing IR and re-compiling). Supporting that path would
require updating some
[regexes](https://github.com/triton-lang/triton/blob/edcc2bcb8dd2e9224c94b689df9cbb7d2986ebea/python/triton/compiler/compiler.py#L52-L53)
that infer the function signature from the IR. `IRSource` compilation
still works fine for kernels that do not use the new feature.

Once the approach I'm taking here is reviewed, I plan to fix that
limitation, either in this PR or in a follow-up PR.