Skip to content

Commit

Permalink
Update attempt_insert to incorporate with erase
Browse files (browse the repository at this point in the history)
  • Loading branch information
PointKernel committed Oct 6, 2023
1 parent 9326a91 commit cb17023
Showing 1 changed file with 55 additions and 50 deletions.
105 changes: 55 additions & 50 deletions include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::AVAILABLE) {
auto const intra_window_index = thrust::distance(window_slots.begin(), &slot_content);
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
value)) {
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
Expand Down Expand Up @@ -341,6 +342,7 @@ class open_addressing_ref_impl {
auto const status =
(group.thread_rank() == src_lane)
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
window_slots[src_lane],
value)
: insert_result::CONTINUE;

Expand Down Expand Up @@ -389,9 +391,9 @@ class open_addressing_ref_impl {
if (eq_res == detail::equal_result::AVAILABLE) {
switch ([&]() {
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(window_ptr + i, value);
return packed_cas(window_ptr + i, window_slots[i], value);
} else {
return cas_dependent_write(window_ptr + i, value);
return cas_dependent_write(window_ptr + i, window_slots[i], value);
}
}()) {
case insert_result::SUCCESS: {
Expand Down Expand Up @@ -464,9 +466,9 @@ class open_addressing_ref_impl {
auto const status = [&]() {
if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; }
if constexpr (sizeof(value_type) <= 8) {
return packed_cas(slot_ptr, value);
return packed_cas(slot_ptr, window_slots[src_lane], value);
} else {
return cas_dependent_write(slot_ptr, value);
return cas_dependent_write(slot_ptr, window_slots[src_lane], value);
}
}();

Expand All @@ -485,25 +487,19 @@ class open_addressing_ref_impl {
}
}

template <typename Value, typename Predicate>
__device__ bool erase(Value const& value, Predicate const& predicate) noexcept
template <typename Value>
__device__ bool erase(Value const& value) noexcept
{
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");

auto const key = [&]() {
if constexpr (this->has_payload) {
return value.first;
} else {
return value;
}
}();
auto const key = this->extract_key(value);
auto probing_iter = probing_scheme_(key, storage_ref_.window_extent());

while (true) {
auto const window_slots = storage_ref_[*probing_iter];

for (auto& slot_content : window_slots) {
auto const eq_res = predicate(slot_content, key);
auto const eq_res = this->predicate_(this->extract_key(slot_content), key);

// Key doesn't exist, return false
if (eq_res == detail::equal_result::AVAILABLE) { return false; }
Expand All @@ -517,10 +513,9 @@ class open_addressing_ref_impl {
return this->erased_key_sentinel();
}
}();
switch (attempt_insert<this->has_payload>(
(storage_ref_.data() + *probing_iter)->data() + intra_window_index,
erased_slot,
predicate)) {
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_window_index,
slot_content,
erased_slot)) {
case insert_result::CONTINUE: continue;
case insert_result::SUCCESS: return true;
}
Expand Down Expand Up @@ -821,16 +816,19 @@ class open_addressing_ref_impl {
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param slot Pointer to the slot in memory
* @param value Element to insert
* @param address Pointer to the slot in memory
* @param expected Element to compare against
* @param desired Element to insert
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Value>
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* slot,
Value const& value) noexcept
[[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address,
value_type const& expected,
Value const& desired) noexcept
{
auto old = compare_and_swap(slot, this->empty_slot_sentinel_, static_cast<value_type>(value));
auto old =
compare_and_swap(address, this->empty_slot_sentinel_, static_cast<value_type>(desired));
auto* old_ptr = reinterpret_cast<value_type*>(&old);
auto const inserted = [&]() {
if constexpr (this->has_payload) {
Expand All @@ -848,10 +846,10 @@ class open_addressing_ref_impl {
auto const res = [&]() {
if constexpr (this->has_payload) {
// If it's a map implementation, compare keys only
return this->predicate_.equal_to(old_ptr->first, value.first);
return this->predicate_.equal_to(old_ptr->first, desired.first);
} else {
// If it's a set implementation, compare the whole slot content
return this->predicate_.equal_to(*old_ptr, value);
return this->predicate_.equal_to(*old_ptr, desired);
}
}();
return res == detail::equal_result::EQUAL ? insert_result::DUPLICATE
Expand All @@ -864,41 +862,44 @@ class open_addressing_ref_impl {
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param slot Pointer to the slot in memory
* @param value Element to insert
* @param address Pointer to the slot in memory
* @param expected Element to compare against
* @param desired Element to insert
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Value>
[[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* slot,
Value const& value) noexcept
[[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address,
value_type const& expected,
Value const& desired) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;
auto const expected_payload = this->empty_slot_sentinel_.second;

auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto old_payload =
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));
auto old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);
auto* old_payload_ptr = reinterpret_cast<mapped_type*>(&old_payload);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
while (not cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
old_payload =
compare_and_swap(&slot->second, expected_payload, static_cast<mapped_type>(value.second));
old_payload = compare_and_swap(
&address->second, expected_payload, static_cast<mapped_type>(desired.second));
}
return insert_result::SUCCESS;
} else if (cuco::detail::bitwise_compare(*old_payload_ptr, expected_payload)) {
atomic_store(&slot->second, expected_payload);
atomic_store(&address->second, expected_payload);
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -910,32 +911,34 @@ class open_addressing_ref_impl {
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param slot Pointer to the slot in memory
* @param value Element to insert
* @param address Pointer to the slot in memory
* @param expected Element to compare against
* @param desired Element to insert
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Value>
[[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* slot,
Value const& value) noexcept
[[nodiscard]] __device__ constexpr insert_result cas_dependent_write(
value_type* address, value_type const& expected, Value const& desired) noexcept
{
using mapped_type = decltype(this->empty_slot_sentinel_.second);

auto const expected_key = this->empty_slot_sentinel_.first;

auto old_key = compare_and_swap(&slot->first, expected_key, static_cast<key_type>(value.first));
auto old_key =
compare_and_swap(&address->first, expected_key, static_cast<key_type>(desired.first));

auto* old_key_ptr = reinterpret_cast<key_type*>(&old_key);

// if key success
if (cuco::detail::bitwise_compare(*old_key_ptr, expected_key)) {
atomic_store(&slot->second, static_cast<mapped_type>(value.second));
atomic_store(&address->second, static_cast<mapped_type>(desired.second));
return insert_result::SUCCESS;
}

// Our key was already present in the slot, so our key is a duplicate
// Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare
if (this->predicate_.equal_to(*old_key_ptr, value.first) == detail::equal_result::EQUAL) {
if (this->predicate_.equal_to(*old_key_ptr, desired.first) == detail::equal_result::EQUAL) {
return insert_result::DUPLICATE;
}

Expand All @@ -950,22 +953,24 @@ class open_addressing_ref_impl {
*
* @tparam Value Input type which is implicitly convertible to 'value_type'
*
* @param slot Pointer to the slot in memory
* @param value Element to insert
* @param address Pointer to the slot in memory
* @param expected Element to compare against
* @param desired Element to insert
*
* @return Result of this operation, i.e., success/continue/duplicate
*/
template <typename Value>
[[nodiscard]] __device__ insert_result attempt_insert(value_type* address,
                                                      value_type const& expected,
                                                      Value const& desired) noexcept
{
  // Dispatch on the slot size: slots of at most 8 bytes go through `packed_cas`, larger slots
  // go through one of the two key/payload strategies selected below.
  if constexpr (sizeof(value_type) <= 8) {
    return packed_cas(address, expected, desired);
  } else {
    // `__CUDA_ARCH__` (two leading underscores) is the macro nvcc defines while compiling
    // device code. The earlier spelling `_CUDA_ARCH__` is never defined, so the preprocessor
    // evaluated it as 0 and the `back_to_back_cas` branch could never be selected.
    // When the macro is undefined (e.g. the host compilation pass) the condition is still
    // `0 < 700`, so `cas_dependent_write` is chosen exactly as before.
#if (__CUDA_ARCH__ < 700)
    return cas_dependent_write(address, expected, desired);
#else
    return back_to_back_cas(address, expected, desired);
#endif
  }
}
Expand Down

0 comments on commit cb17023

Please sign in to comment.