From 6d7755bf481ebcc72cf5f3c36738ea9346aee3c7 Mon Sep 17 00:00:00 2001 From: Eyal Rozenberg Date: Sat, 19 Oct 2024 20:56:44 +0300 Subject: [PATCH] Fixes #686: Can now specify which context to load a library kernel into - both with the standalone `library::get_kernel()` function and with the `library_t::get_kernel()` method. --- src/cuda/api/kernels/in_library.hpp | 13 ++++++++++++- src/cuda/api/library.hpp | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/cuda/api/kernels/in_library.hpp b/src/cuda/api/kernels/in_library.hpp index b44a6d4e..a66c1ec0 100644 --- a/src/cuda/api/kernels/in_library.hpp +++ b/src/cuda/api/kernels/in_library.hpp @@ -191,7 +191,7 @@ inline ::std::string identify(const kernel_t& library_kernel) inline kernel_t get(const library_t& library, const char* name) { - auto kernel_handle = library::detail_::get_kernel(library.handle(), name); + auto kernel_handle = cuda::library::detail_::get_kernel_in_current_context(library.handle(), name); return kernel::detail_::wrap(library.handle(), kernel_handle); } @@ -209,6 +209,17 @@ inline library::kernel_t library_t::get_kernel(const ::std::string& name) const return get_kernel(name.c_str()); } +inline library::kernel_t library_t::get_kernel(const context_t& context, const char* name) const +{ + CUDA_CONTEXT_FOR_THIS_SCOPE(context); + return library::kernel::get(*this, name); +} + +inline library::kernel_t library_t::get_kernel(const context_t& context, const ::std::string& name) const +{ + return get_kernel(context, name.c_str()); +} + } // namespace cuda #endif // CUDA_VERSION >= 12000 diff --git a/src/cuda/api/library.hpp b/src/cuda/api/library.hpp index 0b8d93e6..17f10c89 100644 --- a/src/cuda/api/library.hpp +++ b/src/cuda/api/library.hpp @@ -77,18 +77,25 @@ library_t create( namespace detail_ { -inline kernel::handle_t get_kernel(handle_t library_handle, const char* name) +inline kernel::handle_t get_kernel_in_current_context(handle_t library_handle, const char* name) { library::kernel::handle_t kernel_handle; auto status = cuLibraryGetKernel(&kernel_handle, library_handle, name); - throw_if_error_lazy(status, ::std::string{"Failed obtaining kernel "} + name - + "' from " + library::detail_::identify(library_handle)); + throw_if_error_lazy(status, ::std::string{"Failed obtaining kernel "} + + name + "' from " + library::detail_::identify(library_handle)); return kernel_handle; } +inline kernel::handle_t get_kernel(context::handle_t context_handle, handle_t library_handle, const char* name) +{ + CAW_SET_SCOPE_CONTEXT(context_handle); + return get_kernel_in_current_context(library_handle, name); +} + } // namespace detail_ inline kernel_t get_kernel(const library_t& library, const char* name); +inline kernel_t get_kernel(context_t& context, const library_t& library, const char* name); } // namespace library @@ -125,6 +132,8 @@ class library_t { * @return An enqueable kernel proxy object for the requested kernel, * in the current context. */ + library::kernel_t get_kernel(const context_t& context, const char* name) const; + library::kernel_t get_kernel(const context_t& context, const ::std::string& name) const; library::kernel_t get_kernel(const char* name) const; library::kernel_t get_kernel(const ::std::string& name) const;