Skip to content

Commit 3ae1afe

Browse files
fix sanitize_hash edge case
1 parent 56efe03 commit 3ae1afe

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

include/cuco/detail/probing_scheme_impl.inl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ __host__ __device__ constexpr auto linear_probing<CGSize, Hash>::operator()(
121121
{
122122
using size_type = typename Extent::value_type;
123123
return detail::probing_iterator<Extent>{
124-
cuco::detail::sanitize_hash<size_type>(hash_(probe_key), g.thread_rank()) % upper_bound,
124+
cuco::detail::sanitize_hash<size_type>(hash_(probe_key), g) % upper_bound,
125125
cg_size,
126126
upper_bound};
127127
}
@@ -164,7 +164,7 @@ __host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::operato
164164
{
165165
using size_type = typename Extent::value_type;
166166
return detail::probing_iterator<Extent>{
167-
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key), g.thread_rank()) % upper_bound,
167+
cuco::detail::sanitize_hash<size_type>(hash1_(probe_key), g) % upper_bound,
168168
static_cast<size_type>(
169169
(cuco::detail::sanitize_hash<size_type>(hash2_(probe_key)) % (upper_bound / cg_size - 1) +
170170
1) *

include/cuco/detail/utils.cuh

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
#include <cuda/std/array>
2121
#include <cuda/std/bit>
2222
#include <cuda/std/cmath>
23+
#include <cuda/std/limits>
2324
#include <cuda/std/type_traits>
2425
#include <thrust/tuple.h>
2526

27+
#include <cstddef>
28+
2629
namespace cuco {
2730
namespace detail {
2831

@@ -122,13 +125,19 @@ __host__ __device__ constexpr SizeType sanitize_hash(HashType hash) noexcept
122125
*
123126
* @tparam SizeType The target type
124127
* @tparam HashType The input type
128+
* @tparam CG Cooperative group type
125129
*
126130
* @return Converted hash value
127131
*/
128-
template <typename SizeType, typename HashType>
129-
__host__ __device__ constexpr SizeType sanitize_hash(HashType hash, std::uint32_t cg_rank) noexcept
132+
template <typename SizeType, typename HashType, typename CG>
133+
__host__ __device__ constexpr SizeType sanitize_hash(HashType hash, CG group) noexcept
130134
{
131-
return sanitize_hash<SizeType>(sanitize_hash<SizeType>(hash) + cg_rank);
135+
auto const base_hash = sanitize_hash<SizeType>(hash);
136+
auto const max_size = cuda::std::numeric_limits<SizeType>::max();
137+
auto const cg_rank = static_cast<SizeType>(group.thread_rank());
138+
139+
if (base_hash > (max_size - cg_rank)) return cg_rank - (max_size - base_hash);
140+
return base_hash + cg_rank;
132141
}
133142

134143
} // namespace detail

0 commit comments

Comments
 (0)