diff --git a/src/cuda/api/context.hpp b/src/cuda/api/context.hpp index f1381f16..8bf91e77 100644 --- a/src/cuda/api/context.hpp +++ b/src/cuda/api/context.hpp @@ -759,14 +759,16 @@ class context_t { /// @note: The comparison ignores whether or not the wrapper is owning ///@{ -inline bool operator==(const context_t& lhs, const context_t& rhs) +inline bool operator==(const context_t& lhs, const context_t& rhs) noexcept { - return lhs.handle() == rhs.handle(); + // Note: Contexts on different devices cannot have the same context handle, + // so this is redundant, but let's be extra safe: + return lhs.device_id() == rhs.device_id() and lhs.handle() == rhs.handle(); } -inline bool operator!=(const context_t& lhs, const context_t& rhs) +inline bool operator!=(const context_t& lhs, const context_t& rhs) noexcept { - return lhs.handle() != rhs.handle(); + return not (lhs == rhs); } ///@} diff --git a/src/cuda/api/kernel.hpp b/src/cuda/api/kernel.hpp index d974a6be..5575cce5 100644 --- a/src/cuda/api/kernel.hpp +++ b/src/cuda/api/kernel.hpp @@ -587,6 +587,19 @@ inline grid::dimension_t kernel_t::max_active_blocks_per_multiprocessor( dynamic_shared_memory_per_block, disable_caching_override); } +inline bool operator==(const kernel_t& lhs, const kernel_t& rhs) noexcept +{ + return + lhs.device_id() == rhs.device_id() + and lhs.context_handle() == rhs.context_handle() + and lhs.handle() == rhs.handle(); +} + +inline bool operator!=(const kernel_t& lhs, const kernel_t& rhs) noexcept +{ + return not (lhs == rhs); +} + } // namespace cuda #endif // CUDA_API_WRAPPERS_KERNEL_HPP_