Skip to content

Commit

Permalink
Fixes #686: Can now specify which context to load a library kernel in…
Browse files Browse the repository at this point in the history
…to - both with the standalone `library::get_kernel()` function and with the `library_t::get_kernel()` method.
  • Loading branch information
eyalroz committed Nov 17, 2024
1 parent 1f10302 commit b59269c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/cuda/api/kernels/in_library.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ inline ::std::string identify(const kernel_t& library_kernel)

inline kernel_t get(const library_t& library, const char* name)
{
auto kernel_handle = library::detail_::get_kernel(library.handle(), name);
auto kernel_handle = cuda::library::detail_::get_kernel_in_current_context(library.handle(), name);
return kernel::detail_::wrap(library.handle(), kernel_handle);
}

Expand All @@ -209,6 +209,17 @@ inline library::kernel_t library_t::get_kernel(const ::std::string& name) const
return get_kernel(name.c_str());
}

inline library::kernel_t library_t::get_kernel(const context_t& context, const char* name) const
{
CUDA_CONTEXT_FOR_THIS_SCOPE(context);
return library::kernel::get(*this, name);
}

inline library::kernel_t library_t::get_kernel(const context_t& context, const ::std::string& name) const
{
return get_kernel(context, name.c_str());
}

} // namespace cuda

#endif // CUDA_VERSION >= 12000
Expand Down
15 changes: 12 additions & 3 deletions src/cuda/api/library.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,25 @@ library_t create(

namespace detail_ {

inline kernel::handle_t get_kernel(handle_t library_handle, const char* name)
inline kernel::handle_t get_kernel_in_current_context(handle_t library_handle, const char* name)
{
library::kernel::handle_t kernel_handle;
auto status = cuLibraryGetKernel(&kernel_handle, library_handle, name);
throw_if_error_lazy(status, ::std::string{"Failed obtaining kernel "} + name
+ "' from " + library::detail_::identify(library_handle));
throw_if_error_lazy(status, ::std::string{"Failed obtaining kernel "}
+ name + "' from " + library::detail_::identify(library_handle));
return kernel_handle;
}

inline kernel::handle_t get_kernel(context::handle_t context_handle, handle_t library_handle, const char* name)
{
CAW_SET_SCOPE_CONTEXT(context_handle);
return get_kernel_in_current_context(library_handle, name);
}

} // namespace detail_

inline kernel_t get_kernel(const library_t& library, const char* name);
inline kernel_t get_kernel(context_t& context, const library_t& library, const char* name);

} // namespace library

Expand Down Expand Up @@ -125,6 +132,8 @@ class library_t {
* @return An enqueable kernel proxy object for the requested kernel,
* in the current context.
*/
library::kernel_t get_kernel(const context_t& context, const char* name) const;
library::kernel_t get_kernel(const context_t& context, const ::std::string& name) const;
library::kernel_t get_kernel(const char* name) const;
library::kernel_t get_kernel(const ::std::string& name) const;

Expand Down

0 comments on commit b59269c

Please sign in to comment.