diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index f4169d696..7dccfb353 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -1233,20 +1233,24 @@ class open_addressing_ref_impl { */ template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address, - value_type const& expected, - Value const& desired) noexcept + value_type expected, + Value desired) noexcept { - cuda::atomic_ref slot_ref(*address); - auto expected_slot = expected; + using packed_type = cuda::std::conditional_t; - auto const success = slot_ref.compare_exchange_strong( - expected_slot, this->native_value(desired), cuda::std::memory_order_relaxed); + auto* slot_ptr = reinterpret_cast(address); + auto* expected_ptr = reinterpret_cast(&expected); + auto* desired_ptr = reinterpret_cast(&desired); + + auto slot_ref = cuda::atomic_ref{*slot_ptr}; + + auto const success = + slot_ref.compare_exchange_strong(*expected_ptr, *desired_ptr, cuda::memory_order_relaxed); if (success) { return insert_result::SUCCESS; } else { - return this->predicate_.equal_to(this->extract_key(desired), - this->extract_key(expected_slot)) == + return this->predicate_.equal_to(this->extract_key(desired), this->extract_key(expected)) == detail::equal_result::EQUAL ? insert_result::DUPLICATE : insert_result::CONTINUE;