diff --git a/include/cuco/detail/probing_scheme_base.cuh b/include/cuco/detail/probing_scheme_base.cuh index a3d7c148a..9ce06da92 100644 --- a/include/cuco/detail/probing_scheme_base.cuh +++ b/include/cuco/detail/probing_scheme_base.cuh @@ -16,6 +16,8 @@ #pragma once +#include + #include namespace cuco { @@ -30,6 +32,34 @@ namespace detail { */ template class probing_scheme_base { + private: + template + __host__ __device__ constexpr SizeType sanitize_hash_positive(HashType hash) const noexcept + { + if constexpr (cuda::std::is_signed_v) { + return cuda::std::abs(static_cast(hash)); + } else { + return static_cast(hash); + } + } + + protected: + template + __host__ __device__ constexpr SizeType sanitize_hash(HashType hash) const noexcept + { + if constexpr (cuda::std::is_same_v>) { +#if !defined(CUCO_HAS_INT128) + static_assert(false, + "CUCO_HAS_INT128 undefined. Need unsigned __int128 type when sanitizing " + "cuda::std::array"); +#endif + unsigned __int128 ret{}; + memcpy(&ret, &hash, sizeof(unsigned __int128)); + return sanitize_hash_positive(static_cast(ret)); + } else + return sanitize_hash_positive(hash); + } + public: /** * @brief The size of the CUDA cooperative thread group. diff --git a/include/cuco/detail/probing_scheme_impl.inl b/include/cuco/detail/probing_scheme_impl.inl index 50d7c4dcc..33998168e 100644 --- a/include/cuco/detail/probing_scheme_impl.inl +++ b/include/cuco/detail/probing_scheme_impl.inl @@ -107,7 +107,7 @@ __host__ __device__ constexpr auto linear_probing::operator()( { using size_type = typename Extent::value_type; return detail::probing_iterator{ - cuco::detail::sanitize_hash(hash_(probe_key)) % upper_bound, + probing_scheme_base_type::template sanitize_hash(hash_(probe_key)) % upper_bound, 1, // step size is 1 upper_bound}; } @@ -121,7 +121,10 @@ __host__ __device__ constexpr auto linear_probing::operator()( { using size_type = typename Extent::value_type; return detail::probing_iterator{ - cuco::detail::sanitize_hash(hash_(probe_key) + g.thread_rank()) % upper_bound, + probing_scheme_base_type::template sanitize_hash( + probing_scheme_base_type::template sanitize_hash(hash_(probe_key)) + + g.thread_rank()) % + upper_bound, cg_size, upper_bound}; } @@ -148,9 +151,9 @@ __host__ __device__ constexpr auto double_hashing::operato { using size_type = typename Extent::value_type; return detail::probing_iterator{ - cuco::detail::sanitize_hash(hash1_(probe_key)) % upper_bound, + probing_scheme_base_type::template sanitize_hash(hash1_(probe_key)) % upper_bound, max(size_type{1}, - cuco::detail::sanitize_hash(hash2_(probe_key)) % + probing_scheme_base_type::template sanitize_hash(hash2_(probe_key)) % upper_bound), // step size in range [1, prime - 1] upper_bound}; } @@ -164,9 +167,13 @@ __host__ __device__ constexpr auto double_hashing::operato { using size_type = typename Extent::value_type; return detail::probing_iterator{ - cuco::detail::sanitize_hash(hash1_(probe_key) + g.thread_rank()) % upper_bound, + probing_scheme_base_type::template sanitize_hash( + probing_scheme_base_type::template sanitize_hash(hash1_(probe_key)) + + g.thread_rank()) % + upper_bound, static_cast( - (cuco::detail::sanitize_hash(hash2_(probe_key)) % (upper_bound / cg_size - 1) + + (probing_scheme_base_type::template sanitize_hash(hash2_(probe_key)) % + (upper_bound / cg_size - 1) + 1) * cg_size), upper_bound}; // TODO use fast_int operator diff --git a/include/cuco/detail/utils.cuh b/include/cuco/detail/utils.cuh index 1cbe8fd26..f2aecc0ef 100644 --- a/include/cuco/detail/utils.cuh +++ b/include/cuco/detail/utils.cuh @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -81,23 +82,5 @@ struct slot_is_filled { } }; -/** - * @brief Converts a given hash value into a valid (positive) size type. - * - * @tparam SizeType The target type - * @tparam HashType The input type - * - * @return Converted hash value - */ -template -__host__ __device__ constexpr SizeType sanitize_hash(HashType hash) noexcept -{ - if constexpr (cuda::std::is_signed_v) { - return cuda::std::abs(static_cast(hash)); - } else { - return static_cast(hash); - } -} - } // namespace detail } // namespace cuco diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a37f2d4e2..9d75d7a0e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -76,6 +76,7 @@ ConfigureTest(STATIC_MAP_TEST static_map/custom_type_test.cu static_map/duplicate_keys_test.cu static_map/erase_test.cu + static_map/hash_test.cu static_map/heterogeneous_lookup_test.cu static_map/insert_and_find_test.cu static_map/insert_or_assign_test.cu diff --git a/tests/static_map/hash_test.cu b/tests/static_map/hash_test.cu new file mode 100644 index 000000000..c22eae998 --- /dev/null +++ b/tests/static_map/hash_test.cu @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +#include + +using size_type = std::size_t; + +template +void test_hash_function() +{ + using Value = int64_t; + + constexpr size_type num_keys{400}; + + auto map = cuco::static_map, + cuda::thread_scope_device, + thrust::equal_to, + cuco::linear_probing<1, Hash>, + cuco::cuda_allocator, + cuco::storage<2>>{ + num_keys, cuco::empty_key{-1}, cuco::empty_value{-1}}; + + auto keys_begin = thrust::counting_iterator(1); + + auto pairs_begin = thrust::make_transform_iterator( + keys_begin, cuda::proclaim_return_type>([] __device__(auto i) { + return cuco::pair(i, i); + })); + + thrust::device_vector d_keys_exist(num_keys); + + map.insert(pairs_begin, pairs_begin + num_keys); + + REQUIRE(map.size() == num_keys); + + map.contains(keys_begin, keys_begin + num_keys, d_keys_exist.begin()); + + REQUIRE(cuco::test::all_of(d_keys_exist.begin(), d_keys_exist.end(), thrust::identity{})); +} + +TEMPLATE_TEST_CASE_SIG("static_map hash tests", "", ((typename Key)), (int32_t), (int64_t)) +{ + test_hash_function>(); + test_hash_function>(); + test_hash_function>(); + test_hash_function>(); +} \ No newline at end of file