
[Offload] Allow CUDA Kernels to use arbitrarily large shared memory #145963


Open · wants to merge 3 commits into main

Conversation

@gvalson commented Jun 26, 2025

Previously, the user was not able to use more than 48 KB of shared memory on NVIDIA GPUs. Doing so requires setting the function attribute CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, which was not set anywhere in the code base. With this commit, we add the ability to set this attribute, allowing the user to utilize the full power of their GPU.

To avoid resetting the function attribute on every launch of the same kernel, we keep track of the current limit (in the variable MaxDynCGroupMemLimit) and only set the attribute when the requested amount exceeds that limit. By default, the limit is 48 KB.

Feedback is greatly appreciated, especially around making the new variable mutable. I did this because the launchImpl method is const and I am not able to modify the variable otherwise.
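
For readers unfamiliar with the driver API, here is a minimal standalone sketch of the opt-in pattern the patch implements; launchWithDynSharedMem and the 1x1x1 launch shape are illustrative, not taken from the patch:

```cpp
#include <cstdint>
#include <cuda.h>

// Cache the last limit we set and only call cuFuncSetAttribute when a
// launch asks for more dynamic shared memory than the cached limit
// (a high-water mark, mirroring MaxDynCGroupMemLimit in the patch).
static uint32_t CachedLimit = 48 * 1024; // CUDA's default per-block cap

CUresult launchWithDynSharedMem(CUfunction Func, CUstream Stream,
                                uint32_t DynSharedBytes) {
  if (DynSharedBytes > CachedLimit) {
    CUresult Res = cuFuncSetAttribute(
        Func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
        static_cast<int>(DynSharedBytes));
    if (Res != CUDA_SUCCESS)
      return Res;
    CachedLimit = DynSharedBytes; // keep the high-water mark
  }
  // 1x1x1 grid and block, no kernel arguments, for brevity.
  return cuLaunchKernel(Func, 1, 1, 1, 1, 1, 1, DynSharedBytes, Stream,
                        nullptr, nullptr);
}
```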


@llvmbot (Member) commented Jun 26, 2025

@llvm/pr-subscribers-offload

Author: Giorgi Gvalia (gvalson)

Changes (same as the PR description above)


Full diff: https://github.com/llvm/llvm-project/pull/145963.diff

3 Files Affected:

  • (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp (+1)
  • (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h (+2)
  • (modified) offload/plugins-nextgen/cuda/src/rtl.cpp (+14)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
index e5332686fcffb..361a781e8f9b6 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
@@ -31,6 +31,7 @@ DLWRAP(cuDeviceGet, 2)
 DLWRAP(cuDeviceGetAttribute, 3)
 DLWRAP(cuDeviceGetCount, 1)
 DLWRAP(cuFuncGetAttribute, 3)
+DLWRAP(cuFuncSetAttribute, 3)
 
 // Device info
 DLWRAP(cuDeviceGetName, 3)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
index 1c5b421768894..b6c022c8e7e8b 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
@@ -258,6 +258,7 @@ typedef enum CUdevice_attribute_enum {
 
 typedef enum CUfunction_attribute_enum {
   CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0,
+  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8,
 } CUfunction_attribute;
 
 typedef enum CUctx_flags_enum {
@@ -295,6 +296,7 @@ CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
 CUresult cuDeviceGetCount(int *);
 CUresult cuFuncGetAttribute(int *, CUfunction_attribute, CUfunction);
+CUresult cuFuncSetAttribute(CUfunction, CUfunction_attribute, int);
 
 // Device info
 CUresult cuDeviceGetName(char *, int, CUdevice);
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 0e662b038c363..fd9528061b55e 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -160,6 +160,9 @@ struct CUDAKernelTy : public GenericKernelTy {
 private:
   /// The CUDA kernel function to execute.
   CUfunction Func;
+  /// The maximum amount of dynamic shared memory per thread group. By default,
+  /// this is set to 48 KB.
+  mutable uint32_t MaxDynCGroupMemLimit = 49152;
 };
 
 /// Class wrapping a CUDA stream reference. These are the objects handled by the
@@ -1300,6 +1303,17 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (GenericDevice.getRPCServer())
     GenericDevice.Plugin.getRPCServer().Thread->notify();
 
+  // In case we require more memory than the current limit.
+  if (MaxDynCGroupMem >= MaxDynCGroupMemLimit) {
+    CUresult AttrResult = cuFuncSetAttribute(
+        Func,
+        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+        MaxDynCGroupMem);
+    Plugin::check(AttrResult,
+        "Error in cuLaunchKernel while setting the memory limits: %s");
+    MaxDynCGroupMemLimit = MaxDynCGroupMem;
+  }
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
                                 NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);
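
For context, a hedged sketch of how an OpenMP offload program might exercise this path. The ompx_dyn_cgroup_mem clause and llvm_omp_target_dynamic_shared_alloc are LLVM extensions; their exact spelling should be checked against the installed omp.h:

```cpp
#include <cstdio>
#include <omp.h>

int main() {
  int Ok = 0;
  // Request 64 KiB of dynamic shared memory, above the 48 KiB default,
  // which should hit the new cuFuncSetAttribute path in the plugin.
#pragma omp target map(tofrom : Ok) ompx_dyn_cgroup_mem(64 * 1024)
  {
    // LLVM extension: returns the dynamic shared-memory buffer.
    char *Scratch = (char *)llvm_omp_target_dynamic_shared_alloc();
    if (Scratch) {
      Scratch[0] = 1;
      Ok = Scratch[0];
    }
  }
  printf("Ok = %d\n", Ok);
  return 0;
}
```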

Review comment from a Contributor on the new member in offload/plugins-nextgen/cuda/src/rtl.cpp (mutable uint32_t MaxDynCGroupMemLimit = 49152;):

This shouldn't be mutable; we probably want to initialize it at kernel creation with the correct value from cuFuncGetAttribute. Alternatively, we could just check that value every time we launch a kernel, though I don't know how much overhead that would add.

Making it mutable keeps it up to date, I suppose, so it would avoid redundant work if we call the function multiple times with a different opt-in. However, I'd say that's required for correctness: theoretically, a user could use an API to modify it manually, so it's probably best to play it safe.
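
A minimal sketch of the suggested initialization, assuming it runs wherever the kernel object is created and that Func is already a valid CUfunction (error handling abbreviated):

```cpp
// Seed the cached limit from the function's actual attribute instead of
// hard-coding 48 KB; cuFuncGetAttribute is already wrapped by the
// plugin's dynamic_cuda layer.
int CurrentLimit = 0;
CUresult Res = cuFuncGetAttribute(
    &CurrentLimit, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, Func);
if (Res == CUDA_SUCCESS)
  MaxDynCGroupMemLimit = static_cast<uint32_t>(CurrentLimit);
```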

A Member replied:

We don't want to check every time, and we can't fight the user; they could also change the context or other stuff, and that's not our concern. The only real question is whether we want to go back to a lower value or keep the high-water mark. Going back might benefit performance for those launches, but I'd stick with the high-water mark for now. Wrt. mutable, we need some solution; open to suggestions, and mutable isn't the worst here, honestly.


⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h offload/plugins-nextgen/cuda/src/rtl.cpp
View the diff from clang-format here.
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index fd9528061..b899497bd 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -1306,10 +1306,9 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   // In case we require more memory than the current limit.
   if (MaxDynCGroupMem >= MaxDynCGroupMemLimit) {
     CUresult AttrResult = cuFuncSetAttribute(
-        Func,
-        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
-        MaxDynCGroupMem);
-    Plugin::check(AttrResult,
+        Func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, MaxDynCGroupMem);
+    Plugin::check(
+        AttrResult,
         "Error in cuLaunchKernel while setting the memory limits: %s");
     MaxDynCGroupMemLimit = MaxDynCGroupMem;
   }

@shiltian (Contributor) commented:
I wonder which attribute this corresponds to. Is it CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES or CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN?

@jhuber6 (Contributor) commented Jun 30, 2025

> I wonder which attribute this corresponds to. Is it CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES or CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN?

I believe the former is the standard maximum and the latter is the maximum you can opt in to. I think it's 48 KiB and 64 KiB respectively right now.

@shiltian (Contributor) replied:
Based on the CUDA documentation, the former can be set by cuFuncSetAttribute. If MaxDynCGroupMemLimit corresponds to the former, it needs to be mutable; otherwise, it should be set to the latter value (and then be immutable).
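
To make the distinction concrete, a hedged sketch that queries both values through the driver API; Func and Device are assumed to be a valid CUfunction and CUdevice, and error handling is elided:

```cpp
// Per-function limit on dynamic shared memory; cuFuncSetAttribute can
// raise it above the 48 KiB default.
int FuncMax = 0;
cuFuncGetAttribute(&FuncMax, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                   Func);

// Hard per-device ceiling that the per-function limit can be raised to.
int DeviceOptIn = 0;
cuDeviceGetAttribute(&DeviceOptIn,
                     CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                     Device);
```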

@gvalson (Author) commented Jun 30, 2025

> Based on the CUDA documentation, the former can be set by cuFuncSetAttribute. If MaxDynCGroupMemLimit corresponds to the former, it needs to be mutable; otherwise, it should be set to the latter value (and then be immutable).

It does correspond to CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES.
