diff --git a/include/cuco/detail/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh similarity index 100% rename from include/cuco/detail/open_addressing_impl.cuh rename to include/cuco/detail/open_addressing/open_addressing_impl.cuh diff --git a/include/cuco/detail/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh similarity index 95% rename from include/cuco/detail/open_addressing_ref_impl.cuh rename to include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 3967cffa3..2432a81b0 100644 --- a/include/cuco/detail/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -106,6 +106,9 @@ class open_addressing_ref_impl { static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size static constexpr auto window_size = storage_ref_type::window_size; ///< Number of elements handled per window + static constexpr auto has_payload = + not std::is_same_v; ///< Determines if the container is a key/value or + ///< key-only store /** * @brief Constructs open_addressing_ref_impl. @@ -154,7 +157,6 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -163,13 +165,13 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -187,7 +189,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EQUAL) { return false; } if (eq_res == detail::equal_result::EMPTY) { auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content); - switch (attempt_insert( + switch (attempt_insert( (storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) { case insert_result::CONTINUE: continue; case insert_result::SUCCESS: return true; @@ -202,7 +204,6 @@ class open_addressing_ref_impl { /** * @brief Inserts an element. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -212,13 +213,13 @@ class open_addressing_ref_impl { * * @return True if the given element is successfully inserted */ - template + template __device__ bool insert(cooperative_groups::thread_block_tile const& group, Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -252,10 +253,9 @@ class open_addressing_ref_impl { auto const src_lane = __ffs(group_contains_empty) - 1; auto const status = (group.thread_rank() == src_lane) - ? attempt_insert( - (storage_ref_.data() + *probing_iter)->data() + intra_window_index, - value, - predicate) + ? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index, + value, + predicate) : insert_result::CONTINUE; switch (group.shfl(status, src_lane)) { @@ -276,7 +276,6 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -286,14 +285,14 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find(Value const& value, Predicate const& predicate) noexcept { static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -313,7 +312,7 @@ class open_addressing_ref_impl { if (eq_res == detail::equal_result::EMPTY) { switch ([&]() { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(window_ptr + i, value, predicate); + return packed_cas(window_ptr + i, value, predicate); } else { return cas_dependent_write(window_ptr + i, value, predicate); } @@ -339,7 +338,6 @@ class open_addressing_ref_impl { * element that prevented the insertion) and a `bool` denoting whether the insertion took place or * not. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -350,14 +348,14 @@ class open_addressing_ref_impl { * @return a pair consisting of an iterator to the element and a bool indicating whether the * insertion is successful or not. */ - template + template __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value, Predicate const& predicate) noexcept { auto const key = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { return value.first; } else { return value; @@ -399,7 +397,7 @@ class open_addressing_ref_impl { auto const status = [&]() { if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot_ptr, value, predicate); + return packed_cas(slot_ptr, value, predicate); } else { return cas_dependent_write(slot_ptr, value, predicate); } @@ -715,7 +713,6 @@ class open_addressing_ref_impl { /** * @brief Inserts the specified element with one single CAS operation. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -725,7 +722,7 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot, Value const& value, Predicate const& predicate) noexcept @@ -733,7 +730,7 @@ class open_addressing_ref_impl { auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast(value)); auto* old_ptr = reinterpret_cast(&old); auto const inserted = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { // If it's a map implementation, compare keys only return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first); } else { @@ -746,7 +743,7 @@ class open_addressing_ref_impl { } else { // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare auto const res = [&]() { - if constexpr (HasPayload) { + if constexpr (this->has_payload) { // If it's a map implementation, compare keys only return predicate.equal_to(old_ptr->first, value.first); } else { @@ -852,7 +849,6 @@ class open_addressing_ref_impl { * @note Dispatches the correct implementation depending on the container * type and presence of other operator mixins. * - * @tparam HasPayload Boolean indicating it's a set or map implementation * @tparam Value Input type which is implicitly convertible to 'value_type' * @tparam Predicate Predicate type * @@ -862,13 +858,13 @@ class open_addressing_ref_impl { * * @return Result of this operation, i.e., success/continue/duplicate */ - template + template [[nodiscard]] __device__ insert_result attempt_insert(value_type* slot, Value const& value, Predicate const& predicate) noexcept { if constexpr (sizeof(value_type) <= 8) { - return packed_cas(slot, value, predicate); + return packed_cas(slot, value, predicate); } else { #if (_CUDA_ARCH__ < 700) return cas_dependent_write(slot, value, predicate); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 28b3ffaf2..e85b77509 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -248,9 +248,8 @@ class operator_impl< */ __device__ bool insert(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -263,9 +262,8 @@ class operator_impl< __device__ bool insert(cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert(group, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -492,9 +490,8 @@ class operator_impl< */ __device__ thrust::pair insert_and_find(value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -513,9 +510,8 @@ class operator_impl< __device__ thrust::pair insert_and_find( cooperative_groups::thread_block_tile const& group, value_type const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = true; - return ref_.impl_.insert_and_find(group, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 3dbda9bbf..3b754d972 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -137,9 +137,8 @@ class operator_impl __device__ bool insert(Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert(value, ref_.predicate_); } /** @@ -156,9 +155,8 @@ class operator_impl const& group, Value const& value) noexcept { - auto& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert(group, value, ref_.predicate_); + auto& ref_ = static_cast(*this); + return ref_.impl_.insert(group, value, ref_.predicate_); } }; @@ -224,9 +222,8 @@ class operator_impl __device__ thrust::pair insert_and_find(Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(value, ref_.predicate_); } /** @@ -248,9 +245,8 @@ class operator_impl insert_and_find( cooperative_groups::thread_block_tile const& group, Value const& value) noexcept { - ref_type& ref_ = static_cast(*this); - auto constexpr has_payload = false; - return ref_.impl_.insert_and_find(group, value, ref_.predicate_); + ref_type& ref_ = static_cast(*this); + return ref_.impl_.insert_and_find(group, value, ref_.predicate_); } }; diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 4db0d43e7..34fcfc805 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index c41ed88f3..f65b4566b 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include #include diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh index 6d48d5dc8..979bdfead 100644 --- a/include/cuco/static_set.cuh +++ b/include/cuco/static_set.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index b2c8158e7..af34b134e 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include