Skip to content

Commit

Permalink
Remove HasPayload tparam from OA impl classes (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleeepyjack authored Oct 4, 2023
1 parent fd23a3d commit b4657fd
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class open_addressing_ref_impl {
static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size
static constexpr auto window_size =
storage_ref_type::window_size; ///< Number of elements handled per window
static constexpr auto has_payload =
not std::is_same_v<key_type, value_type>; ///< Determines if the container is a key/value or
///< key-only store

/**
* @brief Constructs open_addressing_ref_impl.
Expand Down Expand Up @@ -154,7 +157,6 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -163,13 +165,13 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
__device__ bool insert(Value const& value, Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
return value.first;
} else {
return value;
Expand All @@ -187,7 +189,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EQUAL) { return false; }
if (eq_res == detail::equal_result::EMPTY) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert<HasPayload>(
switch (attempt_insert(
(storage_ref_.data() + *probing_iter)->data() + intra_window_index, value, predicate)) {
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
Expand All @@ -202,7 +204,6 @@ class open_addressing_ref_impl {
/**
* @brief Inserts an element.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -212,13 +213,13 @@ class open_addressing_ref_impl {
*
* @return True if the given element is successfully inserted
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
return value.first;
} else {
return value;
Expand Down Expand Up @@ -252,10 +253,9 @@ class open_addressing_ref_impl {
auto const src_lane = __ffs(group_contains_empty) - 1;
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert<HasPayload>(
(storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value,
predicate)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
value,
predicate)
: insert_result::CONTINUE;

switch (group.shfl(status, src_lane)) {
Expand All @@ -276,7 +276,6 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -286,14 +285,14 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value,
Predicate const& predicate) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
return value.first;
} else {
return value;
Expand All @@ -313,7 +312,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::EMPTY) {
switch ([&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas<HasPayload>(window_ptr + i, value, predicate);
return packed_cas(window_ptr + i, value, predicate);
} else {
return cas_dependent_write(window_ptr + i, value, predicate);
}
Expand All @@ -339,7 +338,6 @@ class open_addressing_ref_impl {
* element that prevented the insertion) and a `bool` denoting whether the insertion took place or
* not.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -350,14 +348,14 @@ class open_addressing_ref_impl {
* @return a pair consisting of an iterator to the element and a bool indicating whether the
* insertion is successful or not.
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value,
Predicate const& predicate) noexcept
{
auto const key = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
return value.first;
} else {
return value;
Expand Down Expand Up @@ -399,7 +397,7 @@ class open_addressing_ref_impl {
auto const status = [&]() {
if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
if constexpr (sizeof(value_type) <= 8) {
return packed_cas<HasPayload>(slot_ptr, value, predicate);
return packed_cas(slot_ptr, value, predicate);
} else {
return cas_dependent_write(slot_ptr, value, predicate);
}
Expand Down Expand Up @@ -715,7 +713,6 @@ class open_addressing_ref_impl {
/**
* @brief Inserts the specified element with one single CAS operation.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -725,15 +722,15 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
Value const& value,
Predicate const& predicate) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
// If it's a map implementation, compare keys only
return cuco::detail::bitwise_compare(old_ptr->first, this->empty_slot_sentinel_.first);
} else {
Expand All @@ -746,7 +743,7 @@ class open_addressing_ref_impl {
} else {
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
auto const res = [&]() {
if constexpr (HasPayload) {
if constexpr (this->has_payload) {
// If it's a map implementation, compare keys only
return predicate.equal_to(old_ptr->first, value.first);
} else {
Expand Down Expand Up @@ -852,7 +849,6 @@ class open_addressing_ref_impl {
* @note Dispatches the correct implementation depending on the container
* type and presence of other operator mixins.
*
* @tparam HasPayload Boolean indicating it's a set or map implementation
* @tparam Value Input type which is implicitly convertible to 'value_type'
* @tparam Predicate Predicate type
*
Expand All @@ -862,13 +858,13 @@ class open_addressing_ref_impl {
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <bool HasPayload, typename Value, typename Predicate>
template <typename Value, typename Predicate>
[[nodiscard]] __device__ insert_result attempt_insert(value_type* slot,
Value const& value,
Predicate const& predicate) noexcept
{
if constexpr (sizeof(value_type) <= 8) {
return packed_cas<HasPayload>(slot, value, predicate);
return packed_cas(slot, value, predicate);
} else {
#if (_CUDA_ARCH__ < 700)
return cas_dependent_write(slot, value, predicate);
Expand Down
20 changes: 8 additions & 12 deletions include/cuco/detail/static_map/static_map_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,8 @@ class operator_impl<
*/
__device__ bool insert(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value, ref_.predicate_);
}

/**
Expand All @@ -263,9 +262,8 @@ class operator_impl<
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
value_type const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert<has_payload>(group, value, ref_.predicate_);
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -492,9 +490,8 @@ class operator_impl<
*/
__device__ thrust::pair<iterator, bool> insert_and_find(value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value, ref_.predicate_);
}

/**
Expand All @@ -513,9 +510,8 @@ class operator_impl<
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, value_type const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = true;
return ref_.impl_.insert_and_find<has_payload>(group, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value, ref_.predicate_);
}
};

Expand Down
20 changes: 8 additions & 12 deletions include/cuco/detail/static_set/static_set_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,8 @@ class operator_impl<op::insert_tag,
template <typename Value>
__device__ bool insert(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(value, ref_.predicate_);
}

/**
Expand All @@ -156,9 +155,8 @@ class operator_impl<op::insert_tag,
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
Value const& value) noexcept
{
auto& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert<has_payload>(group, value, ref_.predicate_);
auto& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert(group, value, ref_.predicate_);
}
};

Expand Down Expand Up @@ -224,9 +222,8 @@ class operator_impl<op::insert_and_find_tag,
template <typename Value>
__device__ thrust::pair<iterator, bool> insert_and_find(Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(value, ref_.predicate_);
}

/**
Expand All @@ -248,9 +245,8 @@ class operator_impl<op::insert_and_find_tag,
__device__ thrust::pair<iterator, bool> insert_and_find(
cooperative_groups::thread_block_tile<cg_size> const& group, Value const& value) noexcept
{
ref_type& ref_ = static_cast<ref_type&>(*this);
auto constexpr has_payload = false;
return ref_.impl_.insert_and_find<has_payload>(group, value, ref_.predicate_);
ref_type& ref_ = static_cast<ref_type&>(*this);
return ref_.impl_.insert_and_find(group, value, ref_.predicate_);
}
};

Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cuco/cuda_stream_ref.hpp>
#include <cuco/detail/__config>
#include <cuco/detail/open_addressing_impl.cuh>
#include <cuco/detail/open_addressing/open_addressing_impl.cuh>
#include <cuco/detail/static_map_kernels.cuh>
#include <cuco/hash_functions.cuh>
#include <cuco/pair.cuh>
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_map_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once

#include <cuco/detail/open_addressing_ref_impl.cuh>
#include <cuco/detail/open_addressing/open_addressing_ref_impl.cuh>
#include <cuco/hash_functions.cuh>
#include <cuco/operator.hpp>
#include <cuco/probing_scheme.cuh>
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_set.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include <cuco/cuda_stream_ref.hpp>
#include <cuco/detail/open_addressing_impl.cuh>
#include <cuco/detail/open_addressing/open_addressing_impl.cuh>
#include <cuco/extent.cuh>
#include <cuco/hash_functions.cuh>
#include <cuco/probing_scheme.cuh>
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/static_set_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include <cuco/detail/equal_wrapper.cuh>
#include <cuco/detail/open_addressing_ref_impl.cuh>
#include <cuco/detail/open_addressing/open_addressing_ref_impl.cuh>
#include <cuco/hash_functions.cuh>
#include <cuco/operator.hpp>
#include <cuco/probing_scheme.cuh>
Expand Down

0 comments on commit b4657fd

Please sign in to comment.