Skip to content

Commit 5602381

Browse files
Improve rebind utilities for cuco hash tables (#598)
This PR renames all `with_*` member functions to `rebind_*` for improved clarity. The legacy `with_operators` will be removed once libcudf is migrated to use the new `rebind_operators`. --------- Co-authored-by: Daniel Jünger <[email protected]>
1 parent 9ef3535 commit 5602381

File tree

13 files changed

+280
-157
lines changed

13 files changed

+280
-157
lines changed

include/cuco/detail/probing_scheme/probing_scheme_impl.inl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ __host__ __device__ constexpr linear_probing<CGSize, Hash>::linear_probing(Hash
9595

9696
template <int32_t CGSize, typename Hash>
9797
template <typename NewHash>
98-
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::with_hash_function(
98+
__host__ __device__ constexpr auto linear_probing<CGSize, Hash>::rebind_hash_function(
9999
NewHash const& hash) const noexcept
100100
{
101101
return linear_probing<cg_size, NewHash>{hash};
@@ -143,28 +143,20 @@ __host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashi
143143

144144
template <int32_t CGSize, typename Hash1, typename Hash2>
145145
__host__ __device__ constexpr double_hashing<CGSize, Hash1, Hash2>::double_hashing(
146-
cuco::pair<Hash1, Hash2> const& hash)
146+
cuda::std::tuple<Hash1, Hash2> const& hash)
147147
: hash1_{hash.first}, hash2_{hash.second}
148148
{
149149
}
150150

151-
template <int32_t CGSize, typename Hash1, typename Hash2>
152-
template <typename NewHash1, typename NewHash2>
153-
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::with_hash_function(
154-
NewHash1 const& hash1, NewHash2 const& hash2) const noexcept
155-
{
156-
return double_hashing<cg_size, NewHash1, NewHash2>{hash1, hash2};
157-
}
158-
159151
template <int32_t CGSize, typename Hash1, typename Hash2>
160152
template <typename NewHash, typename Enable>
161-
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::with_hash_function(
153+
__host__ __device__ constexpr auto double_hashing<CGSize, Hash1, Hash2>::rebind_hash_function(
162154
NewHash const& hash) const
163155
{
164156
static_assert(cuco::is_tuple_like<NewHash>::value,
165157
"The given hasher must be a tuple-like object");
166158

167-
auto const [hash1, hash2] = cuco::pair{hash};
159+
auto const [hash1, hash2] = cuda::std::tuple{hash};
168160
using hash1_type = cuda::std::decay_t<decltype(hash1)>;
169161
using hash2_type = cuda::std::decay_t<decltype(hash2)>;
170162
return double_hashing<cg_size, hash1_type, hash2_type>{hash1, hash2};

include/cuco/detail/static_map/kernels.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
206206
ref.probing_scheme(),
207207
{},
208208
storage};
209-
auto shared_map_ref = std::move(shared_map).with(cuco::op::insert_or_apply);
209+
auto shared_map_ref = shared_map.rebind_operators(cuco::op::insert_or_apply);
210210
shared_map_ref.initialize(block);
211211
block.sync();
212212

@@ -262,4 +262,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
262262
}
263263
}
264264
}
265-
} // namespace cuco::static_map_ns::detail
265+
} // namespace cuco::static_map_ns::detail

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,17 @@ template <typename Key,
296296
typename StorageRef,
297297
typename... Operators>
298298
template <typename... NewOperators>
299-
auto static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
300-
NewOperators...) && noexcept
299+
__host__ __device__ constexpr auto
300+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
301+
NewOperators...) const noexcept
301302
{
302303
return static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
303-
std::move(*this)};
304+
cuco::empty_key<Key>{this->empty_key_sentinel()},
305+
cuco::empty_value<T>{this->empty_value_sentinel()},
306+
this->key_eq(),
307+
this->probing_scheme(),
308+
{},
309+
this->storage_ref()};
304310
}
305311

306312
template <typename Key,
@@ -311,22 +317,65 @@ template <typename Key,
311317
typename StorageRef,
312318
typename... Operators>
313319
template <typename... NewOperators>
314-
__host__ __device__ auto constexpr static_map_ref<Key,
315-
T,
316-
Scope,
317-
KeyEqual,
318-
ProbingScheme,
319-
StorageRef,
320-
Operators...>::with_operators(NewOperators...)
321-
const noexcept
320+
__host__ __device__ constexpr auto
321+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_operators(
322+
NewOperators...) const noexcept
322323
{
323324
return static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
324325
cuco::empty_key<Key>{this->empty_key_sentinel()},
325326
cuco::empty_value<T>{this->empty_value_sentinel()},
326327
this->key_eq(),
327-
this->impl_.probing_scheme(),
328+
this->probing_scheme(),
328329
{},
329-
this->impl_.storage_ref()};
330+
this->storage_ref()};
331+
}
332+
333+
template <typename Key,
334+
typename T,
335+
cuda::thread_scope Scope,
336+
typename KeyEqual,
337+
typename ProbingScheme,
338+
typename StorageRef,
339+
typename... Operators>
340+
template <typename NewKeyEqual>
341+
__host__ __device__ constexpr auto
342+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_key_eq(
343+
NewKeyEqual const& key_equal) const noexcept
344+
{
345+
return static_map_ref<Key, T, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
346+
cuco::empty_key<Key>{this->empty_key_sentinel()},
347+
cuco::empty_value<T>{this->empty_value_sentinel()},
348+
key_equal,
349+
this->probing_scheme(),
350+
{},
351+
this->storage_ref()};
352+
}
353+
354+
template <typename Key,
355+
typename T,
356+
cuda::thread_scope Scope,
357+
typename KeyEqual,
358+
typename ProbingScheme,
359+
typename StorageRef,
360+
typename... Operators>
361+
template <typename NewHash>
362+
__host__ __device__ constexpr auto
363+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
364+
rebind_hash_function(NewHash const& hash) const
365+
{
366+
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
367+
return static_map_ref<Key,
368+
T,
369+
Scope,
370+
KeyEqual,
371+
cuda::std::decay_t<decltype(probing_scheme)>,
372+
StorageRef,
373+
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
374+
cuco::empty_value<T>{this->empty_value_sentinel()},
375+
this->key_eq(),
376+
probing_scheme,
377+
{},
378+
this->storage_ref()};
330379
}
331380

332381
template <typename Key,
@@ -349,7 +398,7 @@ static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>
349398
cuco::empty_value<T>{this->empty_value_sentinel()},
350399
cuco::erased_key<Key>{this->erased_key_sentinel()},
351400
this->key_eq(),
352-
this->impl_.probing_scheme(),
401+
this->probing_scheme(),
353402
scope,
354403
storage_ref_type{this->window_extent(), memory_to_use}};
355404
}

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,11 +295,22 @@ template <typename Key,
295295
typename StorageRef,
296296
typename... Operators>
297297
template <typename... NewOperators>
298-
auto static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
299-
NewOperators...) && noexcept
298+
__host__ __device__ auto constexpr static_multimap_ref<
299+
Key,
300+
T,
301+
Scope,
302+
KeyEqual,
303+
ProbingScheme,
304+
StorageRef,
305+
Operators...>::with_operators(NewOperators...) const noexcept
300306
{
301307
return static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
302-
std::move(*this)};
308+
cuco::empty_key<Key>{this->empty_key_sentinel()},
309+
cuco::empty_value<T>{this->empty_value_sentinel()},
310+
this->key_eq(),
311+
this->probing_scheme(),
312+
{},
313+
impl_.storage_ref()};
303314
}
304315

305316
template <typename Key,
@@ -317,15 +328,63 @@ __host__ __device__ auto constexpr static_multimap_ref<
317328
KeyEqual,
318329
ProbingScheme,
319330
StorageRef,
320-
Operators...>::with_operators(NewOperators...) const noexcept
331+
Operators...>::rebind_operators(NewOperators...) const noexcept
321332
{
322333
return static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
323334
cuco::empty_key<Key>{this->empty_key_sentinel()},
324335
cuco::empty_value<T>{this->empty_value_sentinel()},
325336
this->key_eq(),
326337
impl_.probing_scheme(),
327338
{},
328-
impl_.storage_ref()};
339+
this->storage_ref()};
340+
}
341+
342+
template <typename Key,
343+
typename T,
344+
cuda::thread_scope Scope,
345+
typename KeyEqual,
346+
typename ProbingScheme,
347+
typename StorageRef,
348+
typename... Operators>
349+
template <typename NewKeyEqual>
350+
__host__ __device__ constexpr auto
351+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
352+
rebind_key_eq(NewKeyEqual const& key_equal) const noexcept
353+
{
354+
return static_multimap_ref<Key, T, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
355+
cuco::empty_key<Key>{this->empty_key_sentinel()},
356+
cuco::empty_value<T>{this->empty_value_sentinel()},
357+
key_equal,
358+
this->probing_scheme(),
359+
{},
360+
this->storage_ref()};
361+
}
362+
363+
template <typename Key,
364+
typename T,
365+
cuda::thread_scope Scope,
366+
typename KeyEqual,
367+
typename ProbingScheme,
368+
typename StorageRef,
369+
typename... Operators>
370+
template <typename NewHash>
371+
__host__ __device__ constexpr auto
372+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
373+
rebind_hash_function(NewHash const& hash) const
374+
{
375+
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
376+
return static_multimap_ref<Key,
377+
T,
378+
Scope,
379+
KeyEqual,
380+
cuda::std::decay_t<decltype(probing_scheme)>,
381+
StorageRef,
382+
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
383+
cuco::empty_value<T>{this->empty_value_sentinel()},
384+
this->key_eq(),
385+
probing_scheme,
386+
{},
387+
this->storage_ref()};
329388
}
330389

331390
template <typename Key,

include/cuco/detail/static_multiset/static_multiset.inl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,11 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
308308
ProbeHash const& probe_hash,
309309
cuda::stream_ref stream) const
310310
{
311-
return impl_->count(first,
312-
last,
313-
ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash),
314-
stream);
311+
return impl_->count(
312+
first,
313+
last,
314+
ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash),
315+
stream);
315316
}
316317

317318
template <class Key,
@@ -333,7 +334,7 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
333334
return impl_->count_outer(
334335
first,
335336
last,
336-
ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash),
337+
ref(op::count).rebind_key_eq(probe_key_equal).rebind_hash_function(probe_hash),
337338
stream);
338339
}
339340

include/cuco/detail/static_multiset/static_multiset_ref.inl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,16 @@ template <typename Key,
251251
typename StorageRef,
252252
typename... Operators>
253253
template <typename... NewOperators>
254-
auto static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with(
255-
NewOperators...) && noexcept
254+
__host__ __device__ constexpr auto
255+
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
256+
NewOperators...) const noexcept
256257
{
257258
return static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
258-
std::move(*this)};
259+
cuco::empty_key<Key>{this->empty_key_sentinel()},
260+
this->key_eq(),
261+
this->probing_scheme(),
262+
{},
263+
this->storage_ref()};
259264
}
260265

261266
template <typename Key,
@@ -266,15 +271,15 @@ template <typename Key,
266271
typename... Operators>
267272
template <typename... NewOperators>
268273
__host__ __device__ constexpr auto
269-
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_operators(
270-
NewOperators...) const noexcept
274+
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
275+
rebind_operators(NewOperators...) const noexcept
271276
{
272277
return static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, NewOperators...>{
273278
cuco::empty_key<Key>{this->empty_key_sentinel()},
274279
this->key_eq(),
275-
this->impl_.probing_scheme(),
280+
this->probing_scheme(),
276281
{},
277-
this->impl_.storage_ref()};
282+
this->storage_ref()};
278283
}
279284

280285
template <typename Key,
@@ -285,15 +290,15 @@ template <typename Key,
285290
typename... Operators>
286291
template <typename NewKeyEqual>
287292
__host__ __device__ constexpr auto
288-
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::with_key_eq(
293+
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::rebind_key_eq(
289294
NewKeyEqual const& key_equal) const noexcept
290295
{
291296
return static_multiset_ref<Key, Scope, NewKeyEqual, ProbingScheme, StorageRef, Operators...>{
292297
cuco::empty_key<Key>{this->empty_key_sentinel()},
293298
key_equal,
294-
this->impl_.probing_scheme(),
299+
this->probing_scheme(),
295300
{},
296-
this->impl_.storage_ref()};
301+
this->storage_ref()};
297302
}
298303

299304
template <typename Key,
@@ -305,19 +310,19 @@ template <typename Key,
305310
template <typename NewHash>
306311
__host__ __device__ constexpr auto
307312
static_multiset_ref<Key, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>::
308-
with_hash_function(NewHash const& hash) const
313+
rebind_hash_function(NewHash const& hash) const
309314
{
310-
auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash);
315+
auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash);
311316
return static_multiset_ref<Key,
312317
Scope,
313318
KeyEqual,
314319
cuda::std::decay_t<decltype(probing_scheme)>,
315320
StorageRef,
316321
Operators...>{cuco::empty_key<Key>{this->empty_key_sentinel()},
317-
this->impl_.key_eq(),
322+
this->key_eq(),
318323
probing_scheme,
319324
{},
320-
this->impl_.storage_ref()};
325+
this->storage_ref()};
321326
}
322327

323328
namespace detail {

0 commit comments

Comments
 (0)