diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index be4bc7a051..77f43c8210 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -105,6 +105,7 @@ add_library( src/preprocessing/quantize/pq.cpp src/preprocessing/quantize/scalar.cpp src/distance/pairwise_distance.cpp + src/selection/select_k.cpp ) add_library(cuvs::c_api ALIAS cuvs_c) set_target_properties( diff --git a/c/include/cuvs/core/c_api.h b/c/include/cuvs/core/c_api.h index 00d4729481..a0da143c67 100644 --- a/c/include/cuvs/core/c_api.h +++ b/c/include/cuvs/core/c_api.h @@ -131,6 +131,23 @@ CUVS_EXPORT cuvsError_t cuvsStreamSync(cuvsResources_t res); */ CUVS_EXPORT cuvsError_t cuvsDeviceIdGet(cuvsResources_t res, int* device_id); +/** + * @brief Configure the temporary workspace on this resources object as an uncapped pool, backed + * by the current device memory resource. After the initial reservation is allocated on + * first use, subsequent calls to cuvsRMMAlloc / cuvsRMMFree on the same resources handle + * hit the pool cache rather than calling cudaMallocAsync / cudaFreeAsync, reducing CUDA + * context lock contention under concurrent query threads. The pool grows without shrinking: + * freed allocations are returned to the pool rather than to the device, so the pool's + * high-water mark only increases until the resources object is destroyed. + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] initial_size_bytes initial pool reservation in bytes; size to cover the + * steady-state working set to avoid growth after warmup + * @return cuvsError_t + */ +CUVS_EXPORT cuvsError_t cuvsResourcesSetWorkspacePool(cuvsResources_t res, + size_t initial_size_bytes); + /** * @brief Create an Initialized opaque C handle for C++ type `raft::device_resources_snmg` * for multi-GPU operations @@ -212,6 +229,19 @@ CUVS_EXPORT cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes CUVS_EXPORT cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, int max_pool_size_percent, bool managed); +/** + * @brief Switches the working memory resource to use stream-ordered asynchronous allocation + * (cudaMallocAsync / cudaFreeAsync). Unlike the pool resource, this resource returns memory to + * the stream immediately without blocking the CPU, eliminating device-wide synchronization on + * deallocation. This is especially beneficial when multiple CAGRA searches run concurrently on + * separate CUDA streams, because the internal workspace allocations no longer serialize kernel + * launches. Be aware that this function will change the memory resource for the whole process + * and the new memory resource will be used until explicitly changed. + * + * @return cuvsError_t + */ +CUVS_EXPORT cuvsError_t cuvsRMMAsyncMemoryResourceEnable(); + /** * @brief Resets the memory resource to use the default memory resource (cuda_memory_resource) * @return cuvsError_t diff --git a/c/include/cuvs/neighbors/cagra.h b/c/include/cuvs/neighbors/cagra.h index 22809da37e..bcab94419c 100644 --- a/c/include/cuvs/neighbors/cagra.h +++ b/c/include/cuvs/neighbors/cagra.h @@ -714,6 +714,44 @@ CUVS_EXPORT cuvsError_t cuvsCagraSearch(cuvsResources_t res, DLManagedTensor* distances, cuvsFilter filter); +/** + * @brief Search multiple CAGRA index partitions concurrently and return the global top-k per + * query. + * + * For each query row, the function searches all partitions in parallel into an internal + * intermediate buffer, applies per-partition distance post-processing, runs a batched top-k + * merge across partitions, and writes the final outputs to the caller-supplied device tensors. + * All work is submitted to the CUDA stream associated with @p res; use @c cuvsStreamSync to + * wait for completion. + * + * Only float32 datasets are currently supported. + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] params search parameters (shared across partitions) + * @param[in] num_partitions number of index partitions + * @param[in] indices array of num_partitions cuvsCagraIndex_t pointers + * @param[in] queries DLManagedTensor* (device, float32, [n_queries, dim]); the queries + * matrix is searched against every partition + * @param[out] partition_ids DLManagedTensor* (device, uint32, [n_queries, k]); which partition + * each returned neighbor came from + * @param[out] neighbors DLManagedTensor* (device, uint32 or int64, [n_queries, k]); ordinal + * in the corresponding partition's dataset + * @param[out] distances DLManagedTensor* (device, float32, [n_queries, k]); post-processed + * distance for each (query, neighbor) + * @param[in] filter filter to apply during search; use {.type=NO_FILTER, .addr=0} for + * unfiltered search, or {.type=MULTI_PARTITION_BITSET, .addr=ptr} where + * ptr is a uintptr_t-cast cuvsMultiPartitionBitsetFilter* + */ +CUVS_EXPORT cuvsError_t cuvsCagraSearchMultiPartition(cuvsResources_t res, + cuvsCagraSearchParams_t params, + uint32_t num_partitions, + cuvsCagraIndex_t* indices, + DLManagedTensor* queries, + DLManagedTensor* partition_ids, + DLManagedTensor* neighbors, + DLManagedTensor* distances, + cuvsFilter filter); + /** * @} */ diff --git a/c/include/cuvs/neighbors/common.h b/c/include/cuvs/neighbors/common.h index d4a124b45c..d30218c01a 100644 --- a/c/include/cuvs/neighbors/common.h +++ b/c/include/cuvs/neighbors/common.h @@ -5,6 +5,7 @@ #pragma once +#include #include #include @@ -28,9 +29,28 @@ enum cuvsFilterType { /* Filter an index with a bitset */ BITSET = 1, /* Filter an index with a bitmap */ - BITMAP = 2 + BITMAP = 2, + /* Filter multiple index partitions with a single concatenated bitset plus per-partition offsets */ + MULTI_PARTITION_BITSET = 3 }; +/** + * @brief Filter parameters for multi-partition search. + * + * Holds a single device bitset that is the concatenation of per-partition bitsets, + * together with a device array of per-partition bit offsets. Pass a pointer to + * this struct (cast to uintptr_t) in cuvsFilter::addr with + * cuvsFilter::type == MULTI_PARTITION_BITSET. + */ +typedef struct { + /** Device tensor (uint32, flat) of packed bitset words for all partitions concatenated. */ + DLManagedTensor* combined_bitset; + /** Total number of logical bits in combined_bitset. */ + int64_t total_bitset_bits; + /** Device tensor (int64, [num_partitions]) of per-partition bit offsets into combined_bitset. */ + DLManagedTensor* partition_offsets; +} cuvsMultiPartitionBitsetFilter; + /** * @brief Struct to hold address of cuvs::neighbors::prefilter and its type * diff --git a/c/include/cuvs/selection/select_k.h b/c/include/cuvs/selection/select_k.h new file mode 100644 index 0000000000..6f133d1627 --- /dev/null +++ b/c/include/cuvs/selection/select_k.h @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Select the k smallest values from a flat device array of n candidates. + * + * Treats `in_val` as a matrix of shape [1, n] and selects the `k` smallest + * float values. `out_idx` receives the int64 column positions of the selected + * values in [0, n), so the caller can recover per-segment identity as: + * + * segment_index = out_idx[j] / segment_k + * position_in_segment = out_idx[j] % segment_k + * + * @param[in] res cuvsResources_t handle + * @param[in] in_val DLManagedTensor* shape [1, n], float32, device memory + * @param[out] out_val DLManagedTensor* shape [1, k], float32, device memory + * @param[out] out_idx DLManagedTensor* shape [1, k], int64, device memory + * @return cuvsError_t + */ +CUVS_EXPORT cuvsError_t cuvsSelectK(cuvsResources_t res, + DLManagedTensor* in_val, + DLManagedTensor* out_val, + DLManagedTensor* out_idx); + +#ifdef __cplusplus +} +#endif diff --git a/c/src/core/c_api.cpp b/c/src/core/c_api.cpp index f4e3664482..05e3856da1 100644 --- a/c/src/core/c_api.cpp +++ b/c/src/core/c_api.cpp @@ -9,11 +9,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -35,6 +37,19 @@ extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res) }); } +extern "C" cuvsError_t cuvsResourcesSetWorkspacePool(cuvsResources_t res, size_t initial_size_bytes) +{ + return cuvs::core::translate_exceptions([=] { + auto res_ptr = reinterpret_cast(res); + // Create an uncapped pool: pre-warms with initial_size_bytes to avoid cudaMalloc on every + // query, but can grow beyond that if an allocation exceeds the initial reservation. + raft::resource::set_workspace_resource( + *res_ptr, + rmm::mr::pool_memory_resource{rmm::mr::get_current_device_resource_ref(), + initial_size_bytes}); + }); +} + extern "C" cuvsError_t cuvsResourcesDestroy(cuvsResources_t res) { return cuvs::core::translate_exceptions([=] { @@ -132,8 +147,8 @@ extern "C" cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t byte { return cuvs::core::translate_exceptions([=] { auto res_ptr = reinterpret_cast(res); - auto mr = rmm::mr::get_current_device_resource_ref(); - *ptr = mr.allocate(raft::resource::get_cuda_stream(*res_ptr), bytes); + auto stream = raft::resource::get_cuda_stream(*res_ptr); + *ptr = raft::resource::get_workspace_resource_ref(*res_ptr).allocate(stream, bytes); }); } @@ -141,11 +156,13 @@ extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes) { return cuvs::core::translate_exceptions([=] { auto res_ptr = reinterpret_cast(res); - auto mr = rmm::mr::get_current_device_resource_ref(); - mr.deallocate(raft::resource::get_cuda_stream(*res_ptr), ptr, bytes); + auto stream = raft::resource::get_cuda_stream(*res_ptr); + raft::resource::get_workspace_resource_ref(*res_ptr).deallocate(stream, ptr, bytes); }); } +thread_local std::shared_ptr async_mr; + extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, int max_pool_size_percent, bool managed) @@ -164,9 +181,20 @@ extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_per }); } +extern "C" cuvsError_t cuvsRMMAsyncMemoryResourceEnable() +{ + return cuvs::core::translate_exceptions([=] { + async_mr = std::make_shared(); + rmm::mr::set_current_device_resource(*async_mr); + }); +} + extern "C" cuvsError_t cuvsRMMMemoryResourceReset() { - return cuvs::core::translate_exceptions([=] { rmm::mr::reset_current_device_resource(); }); + return cuvs::core::translate_exceptions([=] { + rmm::mr::reset_current_device_resource(); + async_mr.reset(); + }); } thread_local std::unique_ptr pinned_mr; diff --git a/c/src/neighbors/cagra.cpp b/c/src/neighbors/cagra.cpp index 081179ca46..57a7723556 100644 --- a/c/src/neighbors/cagra.cpp +++ b/c/src/neighbors/cagra.cpp @@ -689,6 +689,90 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, }); } +extern "C" cuvsError_t cuvsCagraSearchMultiPartition(cuvsResources_t res, + cuvsCagraSearchParams_t params, + uint32_t num_partitions, + cuvsCagraIndex_t* indices, + DLManagedTensor* queries, + DLManagedTensor* partition_ids, + DLManagedTensor* neighbors, + DLManagedTensor* distances, + cuvsFilter filter) +{ + return cuvs::core::translate_exceptions([=] { + RAFT_EXPECTS(num_partitions > 0, "num_partitions must be > 0"); + RAFT_EXPECTS(indices != nullptr && queries != nullptr && partition_ids != nullptr && + neighbors != nullptr && distances != nullptr, + "All pointer arguments must be non-null"); + + auto res_ptr = reinterpret_cast(res); + auto search_params = cuvs::neighbors::cagra::search_params(); + convert_c_search_params(*params, &search_params); + + // Only float32 is supported for multi-partition search. + RAFT_EXPECTS( + indices[0]->dtype.code == kDLFloat && indices[0]->dtype.bits == 32, + "Multi-partition search only supports float32 indices"); + + using T = float; + using IdxT = uint32_t; + using OutIdxT = uint32_t; + using DistanceT = float; + using IndexT = cuvs::neighbors::cagra::index; + + std::vector idx_vec(num_partitions); + for (uint32_t i = 0; i < num_partitions; i++) { + RAFT_EXPECTS(indices[i] != nullptr && indices[i]->addr != 0, + "Index at position %u is null or not built", i); + idx_vec[i] = reinterpret_cast(indices[i]->addr); + } + + using queries_view_t = raft::device_matrix_view; + using pid_view_t = raft::device_matrix_view; + using nbrs_view_t = raft::device_matrix_view; + using dist_view_t = raft::device_matrix_view; + + auto queries_view = cuvs::core::from_dlpack(queries); + auto partition_ids_view = cuvs::core::from_dlpack(partition_ids); + auto neighbors_view = cuvs::core::from_dlpack(neighbors); + auto distances_view = cuvs::core::from_dlpack(distances); + + if (filter.type == NO_FILTER) { + cuvs::neighbors::cagra::search_multi_partition(*res_ptr, + search_params, + idx_vec, + queries_view, + partition_ids_view, + neighbors_view, + distances_view); + } else if (filter.type == MULTI_PARTITION_BITSET) { + auto* f = reinterpret_cast(filter.addr); + RAFT_EXPECTS(f != nullptr, "MULTI_PARTITION_BITSET filter addr must be non-null"); + + using bitset_mdspan_t = raft::device_vector_view; + using offsets_mdspan_t = raft::device_vector_view; + auto bitset_mds = cuvs::core::from_dlpack(f->combined_bitset); + auto offsets_mds = cuvs::core::from_dlpack(f->partition_offsets); + + cuvs::core::bitset_view combined_bitset_view( + bitset_mds, f->total_bitset_bits); + cuvs::neighbors::filtering::multi_partition_bitset_filter mp_filter( + combined_bitset_view, offsets_mds.data_handle()); + + cuvs::neighbors::cagra::search_multi_partition(*res_ptr, + search_params, + idx_vec, + queries_view, + partition_ids_view, + neighbors_view, + distances_view, + mp_filter); + } else { + RAFT_FAIL("Unsupported filter type for multi-partition search: %d", (int)filter.type); + } + }); +} + extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res, cuvsCagraIndexParams_t params, cuvsCagraIndex_t* indices, diff --git a/c/src/selection/select_k.cpp b/c/src/selection/select_k.cpp new file mode 100644 index 0000000000..f68416454a --- /dev/null +++ b/c/src/selection/select_k.cpp @@ -0,0 +1,42 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include "../core/exceptions.hpp" +#include +#include + +#include +#include + +extern "C" cuvsError_t cuvsSelectK(cuvsResources_t res, + DLManagedTensor* in_val, + DLManagedTensor* out_val, + DLManagedTensor* out_idx) +{ + return cuvs::core::translate_exceptions([=] { + auto* res_ptr = reinterpret_cast(res); + + int64_t n = in_val->dl_tensor.shape[1]; + int64_t k = out_val->dl_tensor.shape[1]; + + auto in_view = raft::make_device_matrix_view( + static_cast(in_val->dl_tensor.data), 1, n); + + auto out_val_view = raft::make_device_matrix_view( + static_cast(out_val->dl_tensor.data), 1, k); + + auto out_idx_view = raft::make_device_matrix_view( + static_cast(out_idx->dl_tensor.data), 1, k); + + cuvs::selection::select_k( + *res_ptr, + in_view, + std::nullopt, // implicit positions [0, n) as in_idx + out_val_view, + out_idx_view, + true); // select_min = true (smallest distance = nearest neighbor) + }); +} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..a182f0ff0f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -791,6 +791,35 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_multi_cta" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT + "cagra_search_single_cta_mp_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_single_cta_mp<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, @topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta_mp" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_files + NAME_FORMAT "cagra_search_multi_cta_mp_data_@data_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_kernel.cu.in" + FRAGMENT_TAG_FORMAT + "${cagra_ns}::fragment_tag_search_multi_cta_mp<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@source_index_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@>" + FRAGMENT_TAG_HEADER_FILES "" + "" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_multi_cta_mp" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT "cagra_random_pickup_data_@data_abbrev@" diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp index 0b42d79379..326b30eade 100644 --- a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -75,9 +75,20 @@ template struct fragment_tag_search_single_cta_p {}; +template +struct fragment_tag_search_single_cta_mp {}; + template struct fragment_tag_search_multi_cta {}; +template +struct fragment_tag_search_multi_cta_mp {}; + template struct fragment_tag_random_pickup {}; diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index cbd3f72730..15dd2bdf20 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -14,6 +14,7 @@ struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; struct tag_filter_bitset {}; +struct tag_filter_mp_bitset {}; struct tag_bitset_u32 {}; diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8edbcab8fa..5d95f795e5 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace CUVS_EXPORT cuvs { namespace neighbors { @@ -1730,6 +1731,115 @@ void search(raft::resources const& res, const cuvs::neighbors::filtering::base_filter& sample_filter = cuvs::neighbors::filtering::none_sample_filter{}); +/** + * @brief Search multiple CAGRA index partitions concurrently and return the global top-k per + * query. + * + * For each query row in @p queries, the kernel searches all partitions in parallel into an + * internal intermediate buffer, applies per-partition distance post-processing, runs a batched + * top-k merge across partitions, and writes the final outputs. The call returns when all work + * has been submitted to the stream associated with @p res (not necessarily completed); call + * @c raft::resource::sync_stream on @p res to wait for completion. + * + * @param[in] res raft resources + * @param[in] params search parameters (shared across partitions) + * @param[in] indices CAGRA index objects, one per partition + * @param[in] queries queries matrix, shape [n_queries, dim]; searched against every + * partition + * @param[out] partition_ids which partition each neighbor came from, shape [n_queries, k] + * @param[out] neighbors ordinal in the corresponding partition's dataset, shape + * [n_queries, k] + * @param[out] distances post-processed distance for each (query, neighbor), shape + * [n_queries, k] + */ +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + +void search_multi_partition( + raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter = + cuvs::neighbors::filtering::none_sample_filter{}); + /** * @} */ diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 887593c23b..4240505bc5 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -495,7 +495,7 @@ namespace filtering { * @{ */ -enum class FilterType { None, Bitmap, Bitset }; +enum class FilterType { None, Bitmap, Bitset, MultiPartitionBitset }; struct base_filter { ~base_filter() = default; @@ -615,6 +615,45 @@ struct bitset_filter : public base_filter { void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; +/** + * @brief Filter for multi-partition CAGRA search backed by a single concatenated bitset. + * + * All per-partition bitsets are packed into one contiguous device buffer. The offset + * for partition i (in bits) is partition_offsets_[i]. Inside a multi-partition kernel, + * blockIdx.z carries the partition index, so each thread block automatically selects + * the correct portion of the bitset. + * + * @tparam bitset_t Word type of the bitset (e.g. uint32_t) + * @tparam index_t Index type (e.g. int64_t) + */ +template +struct multi_partition_bitset_filter : public base_filter { + using view_t = cuvs::core::bitset_view; + + const view_t combined_bitset_; + const index_t* partition_offsets_; // device pointer to [num_partitions] bit offsets + + _RAFT_HOST_DEVICE multi_partition_bitset_filter(const view_t combined_bitset, + const index_t* partition_offsets) + : combined_bitset_(combined_bitset), partition_offsets_(partition_offsets) + { + } + + /** \cond */ + constexpr __forceinline__ _RAFT_HOST_DEVICE bool operator()(const uint32_t query_ix, + const uint32_t sample_ix) const + { +#ifdef __CUDA_ARCH__ + return combined_bitset_.test(partition_offsets_[blockIdx.z] + static_cast(sample_ix)); +#else + return true; // unreachable on host; blockIdx.z is device-only +#endif + } + /** \endcond */ + + FilterType get_filter_type() const override { return FilterType::MultiPartitionBitset; } +}; + /** @} */ // end group neighbors_filtering /** diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index 73c3794d39..ac77b84e28 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -406,6 +406,54 @@ index merge(raft::resources const& handle, return cagra::detail::merge(handle, params, indices, row_filter); } +template +void search_multi_partition( + raft::resources const& res, + search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT{}) +{ + cagra::detail::search_multi_partition( + res, params, indices, queries, partition_ids, neighbors, distances, sample_filter); +} + +template +void search_multi_partition( + raft::resources const& res, + search_params const& params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const cuvs::neighbors::filtering::base_filter& sample_filter_ref) +{ + try { + using none_filter_t = cuvs::neighbors::filtering::none_sample_filter; + auto& f = dynamic_cast(sample_filter_ref); + return search_multi_partition( + res, params, indices, queries, partition_ids, neighbors, distances, f); + } catch (const std::bad_cast&) { + } + + try { + using mp_filter_t = + cuvs::neighbors::filtering::multi_partition_bitset_filter; + auto& f = dynamic_cast(sample_filter_ref); + return search_multi_partition( + res, params, indices, queries, partition_ids, neighbors, distances, f); + } catch (const std::bad_cast&) { + RAFT_FAIL("Unsupported sample filter type for multi-partition search"); + } +} + /** @} */ // end group cagra } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/cagra_search_inst.cu.in b/cpp/src/neighbors/cagra_search_inst.cu.in index dfef630798..96fd1ccc77 100644 --- a/cpp/src/neighbors/cagra_search_inst.cu.in +++ b/cpp/src/neighbors/cagra_search_inst.cu.in @@ -32,4 +32,24 @@ CUVS_INST_CAGRA_SEARCH(data_t, uint32_t, int64_t); #undef CUVS_INST_CAGRA_SEARCH +#define CUVS_INST_CAGRA_SEARCH_MULTI_PARTITION(T, IdxT, OutputIdxT) \ + void search_multi_partition( \ + raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const std::vector*>& indices, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view partition_ids, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + const cuvs::neighbors::filtering::base_filter& sample_filter) \ + { \ + cuvs::neighbors::cagra::search_multi_partition( \ + handle, params, indices, queries, partition_ids, neighbors, distances, sample_filter); \ + } + +CUVS_INST_CAGRA_SEARCH_MULTI_PARTITION(data_t, uint32_t, uint32_t); +CUVS_INST_CAGRA_SEARCH_MULTI_PARTITION(data_t, uint32_t, int64_t); + +#undef CUVS_INST_CAGRA_SEARCH_MULTI_PARTITION + } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index bca8d3314d..f2b6b30ad5 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -8,7 +8,9 @@ #include "../../../core/nvtx.hpp" #include "factory.cuh" #include "sample_filter_utils.cuh" +#include "search_multi_cta.cuh" #include "search_plan.cuh" +#include "search_single_cta.cuh" #include #include @@ -28,9 +30,11 @@ // TODO: This shouldn't be calling spatial/knn apis #include "../ann_utils.cuh" +#include #include #include #include +#include // All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { @@ -252,4 +256,381 @@ void search_main(raft::resources const& res, } /** @} */ // end group cagra +/** + * @brief Search all partitions concurrently and return the global top-k per query. + * + * For each query row in @p queries, the kernel searches all partitions in parallel + * (blockIdx.z = partition_id, blockIdx.y = query_id) into an internal intermediate buffer. + * Per-partition distance post-processing is applied, then a batched select_k merges across + * partitions and a small decode pass writes the final outputs. + * + * @param indices CAGRA index objects, one per partition (strided datasets only) + * @param queries queries matrix [n_queries, dim]; searched against every partition + * @param partition_ids output: which partition each neighbor came from, shape [n_queries, k] + * @param neighbors output: ordinal in partition[i]'s dataset, shape [n_queries, k] + * @param distances output: post-processed distance, shape [n_queries, k] + */ +template +void search_multi_partition( + raft::resources const& res, + search_params params, + const std::vector*>& indices, + raft::device_matrix_view queries, + raft::device_matrix_view partition_ids, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT{}) +{ + static_assert(std::is_same_v, "Only uint32_t graph index type is supported"); + static_assert(std::is_same_v, "Only float distances are supported"); + + const uint32_t num_partitions = static_cast(indices.size()); + RAFT_EXPECTS(num_partitions > 0, "At least one partition is required"); + + const uint32_t n_queries = static_cast(queries.extent(0)); + const int64_t dim = queries.extent(1); + const uint32_t topk = static_cast(neighbors.extent(1)); + + RAFT_EXPECTS(partition_ids.extent(0) == static_cast(n_queries) && + partition_ids.extent(1) == static_cast(topk), + "partition_ids shape must be [n_queries, k]"); + RAFT_EXPECTS(neighbors.extent(0) == static_cast(n_queries), + "neighbors and queries must have the same number of rows"); + RAFT_EXPECTS(distances.extent(0) == static_cast(n_queries) && + distances.extent(1) == static_cast(topk), + "distances shape must be [n_queries, k]"); + + // Find the max graph_degree across all partitions (needed for the shared kernel plan). + int64_t max_graph_degree = 0; + int64_t max_dataset_size = 0; + for (uint32_t i = 0; i < num_partitions; i++) { + RAFT_EXPECTS(!indices[i]->dataset_fd().has_value(), + "Disk-based datasets are not supported for multi-partition search"); + max_graph_degree = std::max(max_graph_degree, indices[i]->graph().extent(1)); + max_dataset_size = std::max(max_dataset_size, indices[i]->data().n_rows()); + } + + if (params.max_queries == 0) { + cudaDeviceProp deviceProp = raft::resource::get_device_properties(res); + params.max_queries = + std::min(static_cast(n_queries), deviceProp.maxGridSize[1]); + } + + // Persistent kernels are not used in multi-partition search regardless of which algo runs. + params.persistent = false; + + // MULTI_KERNEL is a reference implementation and is substantially slower than SINGLE_CTA / + // MULTI_CTA in practice; multi-partition deliberately does not route to it. + if (params.algo == search_algo::MULTI_KERNEL) { + RAFT_FAIL("MULTI_KERNEL is not supported for multi-partition search"); + } + + // AUTO resolution. Mirrors single-partition's heuristic in search_plan_impl_base, with the + // occupancy gate scaled by num_partitions (multi-partition grids already have a partition + // axis, so each query produces num_partitions CTAs on SINGLE_CTA). SINGLE_CTA's + // itopk_size <= 512 hard cap is enforced in its plan constructor (search_single_cta.cuh); + // above that, AUTO must route to MULTI_CTA. Below the cap, SINGLE_CTA wins only if there + // are enough (query, partition) CTAs to fill the GPU; otherwise MULTI_CTA's + // ceildiv(itopk_size, 32) CTAs per query recover occupancy. + if (params.algo == search_algo::AUTO) { + const size_t num_sm = raft::getMultiProcessorCount(); + if (params.itopk_size <= 512 && + static_cast(params.max_queries) * num_partitions >= num_sm * 2lu) { + params.algo = search_algo::SINGLE_CTA; + } else { + params.algo = search_algo::MULTI_CTA; + } + } + + // Build a single plan_desc sized for the maximum graph_degree across all partitions. The + // smem layout in the descriptor is type-dependent only, so any partition's descriptor (we + // pick indices[0]) is representative for the plan's smem/sizing calculations. + using graph_idx_type = uint32_t; + auto* strided_dset0 = dynamic_cast*>(&indices[0]->data()); + RAFT_EXPECTS(strided_dset0 != nullptr, + "Multi-partition search only supports strided (non-compressed) datasets"); + + RAFT_EXPECTS(indices[0]->metric() != cuvs::distance::DistanceType::CosineExpanded || + indices[0]->dataset_norms().has_value(), + "Dataset norms must be provided for CosineExpanded metric"); + const float* dataset_norms_ptr0 = nullptr; + if (indices[0]->metric() == cuvs::distance::DistanceType::CosineExpanded) { + dataset_norms_ptr0 = indices[0]->dataset_norms().value().data_handle(); + } + auto plan_desc = dataset_descriptor_init_with_cache( + res, params, *strided_dset0, indices[0]->metric(), dataset_norms_ptr0); + + cudaStream_t stream = raft::resource::get_cuda_stream(res); + + // Number of candidates each partition contributes to the cross-partition merge below. + // SINGLE_CTA's kernel produces exactly `topk` per partition; MULTI_CTA's kernel emits + // `num_cta_per_query * itopk_size` per partition (no per-partition merge — rely on the + // cross-partition select_k below to pick the final global top-k). + uint32_t per_partition_topk = 0; + + // Intermediate buffers shared between algos and post-processing; sized below per-algo. + size_t partition_stride = 0; + size_t intermediate_size = 0; + lightweight_uvector intermediate_neighbors(res); + lightweight_uvector intermediate_distances(res); + + if (params.algo == search_algo::SINGLE_CTA) { + single_cta_search:: + search + plan(res, params, plan_desc, dim, max_dataset_size, max_graph_degree, topk); + + RAFT_EXPECTS(topk <= plan.itopk_size, + "topk = %u must be smaller than itopk_size = %lu", + topk, + plan.itopk_size); + + per_partition_topk = topk; + partition_stride = static_cast(n_queries) * per_partition_topk; + intermediate_size = static_cast(num_partitions) * partition_stride; + intermediate_neighbors.resize(intermediate_size, stream); + intermediate_distances.resize(intermediate_size, stream); + + // Build per-partition descriptors on the host. Queries and result buffers are shared + // across partitions and are passed to the kernel as separate parameters. + using part_desc_t = single_cta_search::multi_partition_desc_t; + std::vector host_part_descs(num_partitions); + + // Collect per-partition dataset descriptors (may trigger lazy device init on `stream`). + std::vector> part_dataset_descs; + part_dataset_descs.reserve(num_partitions); + + for (uint32_t i = 0; i < num_partitions; i++) { + auto* strided_dset = dynamic_cast*>(&indices[i]->data()); + RAFT_EXPECTS(strided_dset != nullptr, + "All partitions must have strided (non-compressed) datasets"); + const float* norms_ptr = nullptr; + if (indices[i]->metric() == cuvs::distance::DistanceType::CosineExpanded) { + RAFT_EXPECTS(indices[i]->dataset_norms().has_value(), + "Dataset norms required for CosineExpanded metric (partition %u)", + i); + norms_ptr = indices[i]->dataset_norms().value().data_handle(); + } + part_dataset_descs.push_back(dataset_descriptor_init_with_cache( + res, params, *strided_dset, indices[i]->metric(), norms_ptr)); + + host_part_descs[i].dataset_desc = part_dataset_descs.back().dev_ptr(stream); + host_part_descs[i].graph = indices[i]->graph().data_handle(); + host_part_descs[i].graph_degree = static_cast(indices[i]->graph().extent(1)); + } + + lightweight_uvector dev_part_descs_buf(res); + dev_part_descs_buf.resize(num_partitions, stream); + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_part_descs_buf.data(), + host_part_descs.data(), + num_partitions * sizeof(part_desc_t), + cudaMemcpyHostToDevice, + stream)); + + plan.run_multi_partition(res, + dev_part_descs_buf.data(), + num_partitions, + queries.data_handle(), + n_queries, + intermediate_neighbors.data(), + intermediate_distances.data(), + per_partition_topk, + sample_filter); + } else /* MULTI_CTA */ { + multi_cta_search:: + search + plan(res, params, plan_desc, dim, max_dataset_size, max_graph_degree, topk); + + // MULTI_CTA splits the global itopk pool across num_cta_per_query CTAs of 32 candidates + // each. The kernel emits all num_cta_per_query * itopk_size candidates per (query, + // partition) and lets the cross-partition select_k below pick the final global top-k. + per_partition_topk = + static_cast(plan.num_cta_per_query) * static_cast(plan.itopk_size); + partition_stride = static_cast(n_queries) * per_partition_topk; + intermediate_size = static_cast(num_partitions) * partition_stride; + intermediate_neighbors.resize(intermediate_size, stream); + intermediate_distances.resize(intermediate_size, stream); + + using part_desc_t = multi_cta_search::multi_partition_desc_t; + std::vector host_part_descs(num_partitions); + + std::vector> part_dataset_descs; + part_dataset_descs.reserve(num_partitions); + + for (uint32_t i = 0; i < num_partitions; i++) { + auto* strided_dset = dynamic_cast*>(&indices[i]->data()); + RAFT_EXPECTS(strided_dset != nullptr, + "All partitions must have strided (non-compressed) datasets"); + const float* norms_ptr = nullptr; + if (indices[i]->metric() == cuvs::distance::DistanceType::CosineExpanded) { + RAFT_EXPECTS(indices[i]->dataset_norms().has_value(), + "Dataset norms required for CosineExpanded metric (partition %u)", + i); + norms_ptr = indices[i]->dataset_norms().value().data_handle(); + } + part_dataset_descs.push_back(dataset_descriptor_init_with_cache( + res, params, *strided_dset, indices[i]->metric(), norms_ptr)); + + host_part_descs[i].dataset_desc = part_dataset_descs.back().dev_ptr(stream); + host_part_descs[i].graph = indices[i]->graph().data_handle(); + host_part_descs[i].graph_degree = static_cast(indices[i]->graph().extent(1)); + } + + lightweight_uvector dev_part_descs_buf(res); + dev_part_descs_buf.resize(num_partitions, stream); + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_part_descs_buf.data(), + host_part_descs.data(), + num_partitions * sizeof(part_desc_t), + cudaMemcpyHostToDevice, + stream)); + + plan.run_multi_partition(res, + dev_part_descs_buf.data(), + num_partitions, + static_cast(max_graph_degree), + queries.data_handle(), + n_queries, + intermediate_neighbors.data(), + intermediate_distances.data(), + sample_filter); + } + + // Per-partition distance post-processing (scale + metric transform). Each partition's slice in + // intermediate_distances has shape [n_queries, per_partition_topk] and is contiguous row-major. + constexpr float kScale = cuvs::spatial::knn::detail::utils::config::kDivisor / + cuvs::spatial::knn::detail::utils::config::kDivisor; + + // Query norms (used only by CosineExpanded). Queries are shared across partitions, so compute + // once. The unconditional allocation is small (n_queries floats) relative to the search. + auto query_norms = raft::make_device_vector(res, n_queries); + { + auto scaled_sq_op = raft::compose_op( + raft::sq_op{}, raft::div_const_op{DistanceT(kScale)}, raft::cast_op()); + raft::linalg::reduce( + res, + raft::make_device_matrix_view( + queries.data_handle(), n_queries, dim), + query_norms.view(), + (DistanceT)0, + false, + scaled_sq_op, + raft::add_op(), + raft::sqrt_op{}); + } + + for (uint32_t i = 0; i < num_partitions; i++) { + DistanceT* slice_ptr = + intermediate_distances.data() + static_cast(i) * partition_stride; + if (indices[i]->metric() == cuvs::distance::DistanceType::CosineExpanded) { + auto slice_view = raft::make_device_matrix_view( + slice_ptr, n_queries, per_partition_topk); + raft::linalg::matrix_vector_op( + res, + raft::make_const_mdspan(slice_view), + raft::make_const_mdspan(query_norms.view()), + slice_view, + raft::compose_op(raft::add_const_op{DistanceT(1)}, raft::div_checkzero_op{})); + } else { + cuvs::neighbors::ivf::detail::postprocess_distances(res, + slice_ptr, + slice_ptr, + indices[i]->metric(), + n_queries, + per_partition_topk, + kScale, + true); + } + } + + // Transpose intermediate_distances from [num_partitions, n_queries, per_partition_topk] to + // [n_queries, num_partitions * per_partition_topk] so batched select_k can pick global top-k + // per query. (raft::matrix::select_k requires row-major contiguous input; a strided view + // won't suffice.) + lightweight_uvector transposed_distances(res); + transposed_distances.resize(intermediate_size, stream); + { + const DistanceT* src = intermediate_distances.data(); + const int64_t row_stride = static_cast(num_partitions) * per_partition_topk; + const int64_t partition_stride_i64 = static_cast(partition_stride); + const int64_t per_partition_topk_i64 = per_partition_topk; + auto transposed_view = raft::make_device_matrix_view( + transposed_distances.data(), static_cast(n_queries), row_stride); + raft::linalg::map_offset( + res, + transposed_view, + [src, row_stride, partition_stride_i64, per_partition_topk_i64] __device__(int64_t idx) { + const int64_t q = idx / row_stride; + const int64_t rem = idx % row_stride; + const int64_t p = rem / per_partition_topk_i64; + const int64_t j = rem % per_partition_topk_i64; + return src[p * partition_stride_i64 + q * per_partition_topk_i64 + j]; + }); + } + + // Batched select_k: for each query row, find the global top-k across all partition slots. + // Writes the final `distances` directly; writes positions in + // [0, num_partitions * per_partition_topk) into `positions_buf` for decoding into + // partition_ids and neighbors below. + lightweight_uvector positions_buf(res); + positions_buf.resize(static_cast(n_queries) * topk, stream); + auto positions_view = raft::make_device_matrix_view( + positions_buf.data(), n_queries, topk); + + raft::matrix::select_k( + res, + raft::make_device_matrix_view( + transposed_distances.data(), + static_cast(n_queries), + static_cast(num_partitions) * per_partition_topk), + std::nullopt, + distances, + positions_view, + /*select_min=*/true); + + // Decode positions into partition_ids and neighbors. + // positions[q, j_out] ∈ [0, num_partitions * per_partition_topk) encodes + // (partition, slot_in_partition): + // partition_ids[q, j_out] = pos / per_partition_topk + // neighbors[q, j_out] = intermediate_neighbors[ + // (pos / per_partition_topk) * partition_stride + // + q * per_partition_topk + (pos % per_partition_topk)] + // The output buffers (partition_ids, neighbors) have stride `topk` (caller-owned shape); + // the intermediate buffer has per-partition stride `per_partition_topk`. The two strides + // differ when the kernel emits more than `topk` candidates per partition (e.g. MULTI_CTA mp). + { + const uint32_t per_partition_topk_u32 = per_partition_topk; + raft::linalg::map( + res, + partition_ids, + [per_partition_topk_u32] __device__(uint32_t pos) { return pos / per_partition_topk_u32; }, + raft::make_const_mdspan(positions_view)); + } + { + const graph_idx_type* intermediate_neighbors_ptr = intermediate_neighbors.data(); + const uint32_t* positions_ptr = positions_buf.data(); + const int64_t partition_stride_i64 = static_cast(partition_stride); + const int64_t per_partition_topk_i64 = per_partition_topk; + const int64_t topk_i64 = topk; + raft::linalg::map_offset( + res, + neighbors, + [intermediate_neighbors_ptr, + positions_ptr, + partition_stride_i64, + per_partition_topk_i64, + topk_i64] __device__(int64_t idx) { + const int64_t q = idx / topk_i64; + const int64_t j_out = idx % topk_i64; + const uint32_t pos = positions_ptr[q * topk_i64 + j_out]; + const int64_t p = pos / static_cast(per_partition_topk_i64); + const int64_t j_in = pos % static_cast(per_partition_topk_i64); + return static_cast( + intermediate_neighbors_ptr[p * partition_stride_i64 + q * per_partition_topk_i64 + j_in]); + }); + } +} + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh index ea01c2ce78..5573ec8447 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh @@ -15,6 +15,9 @@ namespace cuvs::neighbors::cagra::detail { template using cagra_bitset = cuvs::neighbors::detail::bitset_filter_data_t; +template +using mp_cagra_bitset = cuvs::neighbors::detail::mp_bitset_filter_data_t; + /// Host: bitset payload for kernels plus query offset for wrapped CAGRA filters. template struct cagra_sample_filter { @@ -22,6 +25,13 @@ struct cagra_sample_filter { std::uint32_t query_id_offset{0}; }; +/// Multi-partition variant. +template +struct mp_cagra_sample_filter { + mp_cagra_bitset bitset{}; + std::uint32_t query_id_offset{0}; +}; + template struct is_bitset_filter : std::false_type {}; @@ -29,6 +39,13 @@ template struct is_bitset_filter> : std::true_type {}; +template +struct is_mp_bitset_filter : std::false_type {}; + +template +struct is_mp_bitset_filter< + cuvs::neighbors::filtering::multi_partition_bitset_filter> : std::true_type {}; + /// Host: fill @ref cagra_sample_filter from a CAGRA filter object (used by JIT LTO launchers). template cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter) @@ -50,4 +67,20 @@ cagra_sample_filter extract_cagra_sample_filter(const SampleFilter return out; } +/// Host: fill @ref mp_cagra_sample_filter from a multi-partition CAGRA filter (mp JIT launchers). +template +mp_cagra_sample_filter extract_cagra_mp_sample_filter( + const SampleFilterT& sample_filter) +{ + mp_cagra_sample_filter out; + if constexpr (is_mp_bitset_filter>::value) { + const auto& combined = sample_filter.combined_bitset_; + out.bitset.bitset_ptr = const_cast(combined.data()); + out.bitset.bitset_len = static_cast(combined.size()); + out.bitset.original_nbits = static_cast(combined.get_original_nbits()); + out.bitset.partition_offsets = sample_filter.partition_offsets_; + } + return out; +} + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 60d17c5128..a91df1da28 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -117,6 +117,98 @@ std::shared_ptr build_multi_cta_launcher( return planner.get_launcher(); } +template +std::shared_ptr build_single_cta_mp_launcher( + const dataset_descriptor_host& dataset_desc, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps) +{ + single_cta_search::CagraSingleCtaMpSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + if constexpr (std::is_same_v) { + planner.add_setup_workspace_device_function( + dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_compute_distance_device_function( + dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + } else { + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim); + planner.add_compute_distance_device_function( + dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); + } + planner.add_search_kernel_fragment(topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps); + planner.add_sample_filter_device_function(); + return planner.get_launcher(); +} + +template +std::shared_ptr build_multi_cta_mp_launcher( + const dataset_descriptor_host& dataset_desc) +{ + multi_cta_search::CagraMultiCtaMpSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + if constexpr (std::is_same_v) { + planner.add_setup_workspace_device_function( + dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_compute_distance_device_function( + dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + } else { + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim); + planner.add_compute_distance_device_function( + dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); + } + planner.add_search_multi_cta_kernel_fragment(); + planner.add_sample_filter_device_function(); + return planner.get_launcher(); +} + template make_cagra_multi_cta_jit_launcher( SourceIndexT>(dataset_desc); } +/// Build a JIT AlgorithmLauncher for the multi-partition single-CTA CAGRA search. +template +std::shared_ptr make_cagra_single_cta_mp_jit_launcher( + const dataset_descriptor_host& dataset_desc, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + return cagra_jit_launcher_factory_detail::build_single_cta_mp_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_single_cta_mp_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps); + } + using QueryTag = query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_single_cta_mp_launcher( + dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps); +} + +/// Build a JIT AlgorithmLauncher for the multi-partition multi-CTA CAGRA search. +template +std::shared_ptr make_cagra_multi_cta_mp_jit_launcher( + const dataset_descriptor_host& dataset_desc) +{ + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_mp_launcher( + dataset_desc); + } + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_mp_launcher( + dataset_desc); + } + using QueryTag = query_type_tag_standard_t; + return cagra_jit_launcher_factory_detail::build_multi_cta_mp_launcher(dataset_desc); +} + /// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers that need `setup_workspace` and /// `compute_distance` linked (e.g. `random_pickup`, `compute_distance_to_child_nodes`). For /// `apply_filter_kernel` only, use `make_cagra_apply_filter_jit_launcher` instead. Use diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp index 370dbd33d8..6dc2c2d313 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -12,6 +12,7 @@ #include #include "../compute_distance.hpp" // dataset_descriptor_base_t +#include "../multi_partition_desc.hpp" #include "cagra_bitset.cuh" #include "search_single_cta_device_helpers.cuh" @@ -51,6 +52,30 @@ using search_single_cta_kernel_func_t = namespace single_cta_search { +template +using search_single_cta_mp_kernel_func_t = + void(const multi_partition_desc_t*, + const DataT* const, + IndexT* const, + DistanceT* const, + const std::uint32_t, + const unsigned, + const uint64_t, + const std::uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + std::uint32_t* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + mp_cagra_bitset); + template using search_single_cta_p_kernel_func_t = void(worker_handle_t*, @@ -82,6 +107,25 @@ using search_single_cta_p_kernel_func_t = namespace multi_cta_search { +template +using search_multi_cta_mp_kernel_func_t = + void(const multi_partition_desc_t*, + IndexT* const, + DistanceT* const, + const DataT* const, + const std::uint32_t, + const std::uint32_t, + const unsigned, + const uint64_t, + const std::uint32_t, + IndexT* const, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + const std::uint32_t, + mp_cagra_bitset); + template using search_multi_cta_kernel_func_t = void(IndexT* const, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh index d01f58166d..f9c29bde96 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_impl.cuh @@ -38,4 +38,21 @@ __device__ bool sample_filter_bitset_impl(uint32_t /*query_id*/, return view.test(node_id); } +template +__device__ bool sample_filter_mp_bitset_impl(uint32_t /*query_id*/, + SourceIndexT node_id, + void* filter_data) +{ + if (filter_data == nullptr) { return true; } + + auto* data = static_cast*>(filter_data); + if (data->bitset_ptr == nullptr) { return true; } + if (data->partition_offsets == nullptr) { return true; } + + raft::core::bitset_view const view{ + data->bitset_ptr, data->bitset_len, data->original_nbits}; + const auto offset = static_cast(data->partition_offsets[blockIdx.z]); + return view.test(offset + node_id); +} + } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json index 0136587b48..95554cc0ee 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_matrix.json @@ -1,5 +1,5 @@ { - "filter_name": ["none", "bitset"], + "filter_name": ["none", "bitset", "mp_bitset"], "_bitset": [ { "bitset_type": "uint32_t", diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh index 492195f359..9d086b1bb5 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -8,6 +8,7 @@ #include "../hashmap.hpp" #include "../utils.hpp" +#include #include #include @@ -362,4 +363,254 @@ __device__ void search_kernel_jit( #endif } +// Multi-partition variant of search_kernel_jit. Grid is (num_cta_per_query, num_queries, +// num_partitions); per-partition data (dataset_desc, graph, graph_degree) is read from +// partition_descs[blockIdx.z]. Cross-CTA traversed_hashmap is per-(query, partition), indexed +// by row = partition_id * num_queries + query_id. Outputs land in +// [num_partitions, num_queries, num_cta_per_query, itopk_size] partition-major. +template +__device__ void search_multi_cta_mp_jit( + const multi_partition_desc_t* partition_descs, + IndexT* const result_indices_ptr, + DistanceT* const result_distances_ptr, + const DataT* const queries_ptr, + const uint32_t max_elements, + const uint32_t max_graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const uint32_t visited_hash_bitlen, + IndexT* const traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, + const uint32_t itopk_size, + const uint32_t min_iteration, + const uint32_t max_iteration, + const uint32_t query_id_offset, + mp_cagra_bitset bitset) +{ + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + const auto num_queries = gridDim.y; + const auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; + const auto partition_id = blockIdx.z; + const auto row = partition_id * num_queries + query_id; + + const auto& part = partition_descs[partition_id]; + const auto* part_desc = part.dataset_desc; + const auto* knn_graph = part.graph; + const auto graph_degree = part.graph_degree; + + extern __shared__ uint8_t smem[]; + + const auto result_buffer_size = itopk_size + max_graph_degree; + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + assert(result_buffer_size_32 <= max_elements); + + uint32_t smem_ws_size_in_bytes = part_desc->smem_ws_size_in_bytes(); + + auto smem_desc = + setup_workspace(part_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ local_visited_hashmap_ptr = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); + auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * row); + + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + __syncthreads(); + + uint32_t block_id = cta_id + (num_cta_per_query * row); + uint32_t num_blocks = num_cta_per_query * num_queries * gridDim.z; + + compute_distance_to_random_nodes_jit( + result_indices_buffer, + result_distances_buffer, + smem_desc, + graph_degree, + num_distilation, + rand_xor_mask, + static_cast(nullptr), + 0u, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + block_id, + num_blocks, + static_cast(0)); + __syncthreads(); + + uint32_t iter = 0; + while (1) { + if (threadIdx.x < 32) { + if constexpr (std::is_same_v) { + if (max_elements <= 64) { + topk_by_bitonic_sort_wrapper_64( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort_wrapper_128( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort_wrapper_256( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } else { + if (max_elements <= 64) { + topk_by_bitonic_sort<64, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort<128, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort<256, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } + } + __syncthreads(); + + if (iter + 1 >= max_iteration) { break; } + + if (threadIdx.x < 32) { + pickup_next_parent(parent_indices_buffer, + result_indices_buffer, + result_distances_buffer, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } else { + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } + __syncthreads(); + + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } else { + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } + if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } + __syncthreads(); + + compute_distance_to_child_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + 1, + result_position, + result_buffer_size_32); + + for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + + for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + if (!sample_filter(query_id + query_id_offset, + static_cast(parent_id), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + + iter++; + } + + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + if (!sample_filter(query_id + query_id_offset, + static_cast(index), + bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if ((offset < itopk_size) && + hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * row))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = result_distances_buffer[i]; + } + } else { + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + } + for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * row))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } + } + } +} + } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_kernel.cu.in new file mode 100644 index 0000000000..657232934e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_kernel.cu.in @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using mp_desc_t = cuvs::neighbors::cagra::detail::multi_cta_search:: + multi_partition_desc_t; +using mp_cagra_bitset_t = cuvs::neighbors::cagra::detail::mp_cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta_mp( + const mp_desc_t* partition_descs, + index_t* const result_indices_ptr, + distance_t* const result_distances_ptr, + const data_t* const queries_ptr, + const std::uint32_t max_elements, + const std::uint32_t max_graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const std::uint32_t visited_hash_bitlen, + index_t* const traversed_hashmap_ptr, + const std::uint32_t traversed_hash_bitlen, + const std::uint32_t itopk_size, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + const std::uint32_t query_id_offset, + mp_cagra_bitset_t bitset) +{ + search_multi_cta_mp_jit(partition_descs, + result_indices_ptr, + result_distances_ptr, + queries_ptr, + max_elements, + max_graph_degree, + num_distilation, + rand_xor_mask, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen, + itopk_size, + min_iteration, + max_iteration, + query_id_offset, + bitset); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_matrix.json new file mode 100644 index 0000000000..adfdf1e78b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_mp_matrix.json @@ -0,0 +1,13 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp index 5e6ea43130..1cda3799dc 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp @@ -40,4 +40,33 @@ struct CagraMultiCtaSearchPlanner } }; +template +struct CagraMultiCtaMpSearchPlanner + : CagraPlannerBase { + static inline LauncherJitCache launcher_jit_cache{}; + + CagraMultiCtaMpSearchPlanner(cuvs::distance::DistanceType /*metric*/, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/) + : CagraPlannerBase( + "search_multi_cta_mp", launcher_jit_cache) + { + } + + void add_search_multi_cta_kernel_fragment() + { + this->template add_static_fragment< + fragment_tag_search_multi_cta_mp>(); + } +}; + } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh index 282490559a..b4fedbb549 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -10,6 +10,7 @@ // neighbors_device_intrinsics / memory_ops come via search_single_cta_device_helpers.cuh #include "../hashmap.hpp" +#include "../multi_partition_desc.hpp" #include "../topk_by_radix.cuh" #include "../utils.hpp" @@ -51,7 +52,8 @@ template + typename SourceIndexT, + typename BitsetT = cagra_bitset> RAFT_DEVICE_INLINE_FUNCTION void search_core( uintptr_t result_indices_ptr, DistanceT* const result_distances_ptr, @@ -78,7 +80,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const std::uint32_t query_id, const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter const dataset_descriptor_base_t* dataset_desc, - cagra_bitset bitset, + BitsetT bitset, const IndexT graph_size = 0) // Original number of bits { using LOAD_T = device::LOAD_128BIT_T; @@ -642,4 +644,86 @@ __device__ void search_single_cta_p_impl( } } +// Multi-partition variant of search_kernel_jit. Grid is (1, num_queries, num_partitions); each +// CTA handles one (query, partition) pair via search_core with that partition's dataset_desc, +// graph, and graph_degree. Hashmap is partition-major [num_partitions, num_queries, hash_size]; +// result buffers are partition-major [num_partitions, num_queries, top_k]. +template +__device__ void search_single_cta_mp_impl( + const multi_partition_desc_t* partitions, + const DataT* queries_ptr, + IndexT* intermediate_neighbors_ptr, + DistanceT* intermediate_distances_ptr, + const std::uint32_t top_k, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const std::uint32_t num_seeds, + IndexT* visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + mp_cagra_bitset bitset) +{ + const uint32_t query_id = blockIdx.y; + const uint32_t part_id = blockIdx.z; + const auto& part = partitions[part_id]; + + IndexT* part_hashmap_ptr = (visited_hashmap_ptr != nullptr) + ? visited_hashmap_ptr + part_id * static_cast(gridDim.y) * + hashmap::get_size(hash_bitlen) + : nullptr; + + const size_t partition_offset = static_cast(part_id) * gridDim.y * top_k; + IndexT* part_result_indices = intermediate_neighbors_ptr + partition_offset; + DistanceT* part_result_distances = intermediate_distances_ptr + partition_offset; + + constexpr uintptr_t kTag = raft::Pow2::Log2; + const uintptr_t tagged_indices_ptr = reinterpret_cast(part_result_indices) | kTag; + + search_core(tagged_indices_ptr, + part_result_distances, + top_k, + queries_ptr, + part.graph, + part.graph_degree, + static_cast(nullptr), + num_distilation, + rand_xor_mask, + static_cast(nullptr), + num_seeds, + part_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + part.dataset_desc, + bitset); +} + } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_kernel.cu.in new file mode 100644 index 0000000000..f40ba3f3d4 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_kernel.cu.in @@ -0,0 +1,81 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace { + +constexpr bool k_topk_by_bitonic_sort = @topk_by_bitonic_sort@; +constexpr bool k_bitonic_sort_and_merge_multi_warps = @bitonic_sort_and_merge_multi_warps@; + +using data_t = @data_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using mp_desc_t = cuvs::neighbors::cagra::detail::single_cta_search:: + multi_partition_desc_t; +using mp_cagra_bitset_t = cuvs::neighbors::cagra::detail::mp_cagra_bitset; + +} // namespace + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta_mp( + const mp_desc_t* partitions, + const data_t* const queries_ptr, + index_t* const intermediate_neighbors_ptr, + distance_t* const intermediate_distances_ptr, + const std::uint32_t top_k, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const std::uint32_t num_seeds, + index_t* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, + mp_cagra_bitset_t bitset) +{ + search_single_cta_mp_impl(partitions, + queries_ptr, + intermediate_neighbors_ptr, + intermediate_distances_ptr, + top_k, + num_distilation, + rand_xor_mask, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id_offset, + bitset); +} + +static_assert( + std::is_same_v>); + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_matrix.json new file mode 100644 index 0000000000..ba1334f11d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_mp_matrix.json @@ -0,0 +1,21 @@ +[ + { + "_data": [ + {"data_type": "float", "data_abbrev": "f"}, + {"data_type": "__half", "data_abbrev": "h"}, + {"data_type": "uint8_t", "data_abbrev": "u8"}, + {"data_type": "int8_t", "data_abbrev": "i8"} + ], + "_source_index": [{"source_index_type": "uint32_t", "source_index_abbrev": "u32"}], + "_index": [{"index_type": "uint32_t", "index_abbrev": "u32"}], + "_distance": [{"distance_type": "float", "distance_abbrev": "f"}], + "_topk_by_bitonic": [ + {"topk_by_bitonic_sort": "true", "topk_by_bitonic_sort_str": "topk_by_bitonic_sort"}, + {"topk_by_bitonic_sort": "false", "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort"} + ], + "_bitonic_sort_and_merge_multi_warps": [ + {"bitonic_sort_and_merge_multi_warps": "true", "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps"}, + {"bitonic_sort_and_merge_multi_warps": "false", "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps"} + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp index a1cbe39dfe..26491c4795 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp @@ -105,4 +105,63 @@ struct CagraSingleCtaSearchPlanner } }; +template +struct CagraSingleCtaMpSearchPlanner + : CagraPlannerBase { + static inline LauncherJitCache launcher_jit_cache{}; + + CagraSingleCtaMpSearchPlanner(cuvs::distance::DistanceType /*metric*/, + bool /*topk_by_bitonic_sort*/, + bool /*bitonic_sort_and_merge_multi_warps*/, + uint32_t /*team_size*/, + uint32_t /*dataset_block_dim*/, + bool /*is_vpq*/, + uint32_t /*pq_bits*/, + uint32_t /*pq_len*/) + : CagraPlannerBase( + "search_single_cta_mp", launcher_jit_cache) + { + } + + void add_search_kernel_fragment(bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps) + { + if (topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (topk_by_bitonic_sort && !bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else if (!topk_by_bitonic_sort && bitonic_sort_and_merge_multi_warps) { + this->template add_static_fragment>(); + } else { + this->template add_static_fragment>(); + } + } +}; + } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/multi_partition_desc.hpp b/cpp/src/neighbors/detail/cagra/multi_partition_desc.hpp new file mode 100644 index 0000000000..7221424878 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/multi_partition_desc.hpp @@ -0,0 +1,34 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template +struct alignas(16) multi_partition_desc_t { + const dataset_descriptor_base_t* dataset_desc; + const IndexT* graph; + uint32_t graph_degree; + uint32_t _pad; +}; + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +struct alignas(16) multi_partition_desc_t { + const dataset_descriptor_base_t* dataset_desc; + const IndexT* graph; + uint32_t graph_degree; + uint32_t _pad; +}; + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index baf9336e6d..5917d64234 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -281,6 +281,58 @@ struct search output_indices_ptr, num_queries, topk)); } } + + /** + * Multi-partition search. Drives `search_kernel_mp` across all partitions in one fused + * launch. Each partition's data (dataset_desc, graph, graph_degree) is read by the kernel + * from partition_descs[blockIdx.z]; smem and the result buffer are sized for the max + * graph_degree across partitions. + * + * No per-partition top-k merge is performed here — the kernel emits + * num_cta_per_query * itopk_size candidates per (query, partition) into the caller's + * intermediate buffer (laid out [num_partitions, num_queries, num_cta_per_query * itopk_size] + * partition-major). The cross-partition select_k in cagra::detail::search_multi_partition + * consolidates everything into the final global top-k in one shot. + */ + template + void run_multi_partition( + raft::resources const& res, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + uint32_t max_graph_degree, + const DATA_T* queries_ptr, + uint32_t num_queries, + INDEX_T* + intermediate_indices_ptr, // [num_partitions, num_queries, num_cta_per_query * itopk_size] + DISTANCE_T* intermediate_distances_ptr, + SampleFilterT_ sample_filter) + { + cudaStream_t stream = raft::resource::get_cuda_stream(res); + + // Scale the cross-CTA traversed hashmap to (num_queries * num_partitions) rows. + const size_t traversed_hash_size = hashmap::get_size(hash_bitlen); + hashmap.resize(traversed_hash_size * static_cast(num_queries) * num_partitions, stream); + + select_and_run_mp( + dataset_desc, + partition_descs, + num_partitions, + max_graph_degree, + intermediate_indices_ptr, + intermediate_distances_ptr, + queries_ptr, + num_queries, + /* search_params */ *this, + thread_block_size, + result_buffer_size, + smem_size, + small_hash_bitlen, + hash_bitlen, + hashmap.data(), + num_cta_per_query, + sample_filter, + stream); + } }; } // namespace multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in index f350a2c9ef..6d38b7948f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using mp_bitset_filter_t = + cuvs::neighbors::filtering::multi_partition_bitset_filter; } // namespace @@ -20,5 +22,11 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, mp_bitset_filter_t); +instantiate_kernel_selection_mp(data_t, + uint32_t, + float, + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection_mp(data_t, uint32_t, float, mp_bitset_filter_t); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh index 72c40c6973..f374053615 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh @@ -37,4 +37,25 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search { SampleFilterT sample_filter, \ cudaStream_t stream); +#define instantiate_kernel_selection_mp(DataT, IndexT, DistanceT, SampleFilterT) \ + template void select_and_run_mp( \ + const dataset_descriptor_host& ref_dataset_desc, \ + const multi_partition_desc_t* partition_descs, \ + uint32_t num_partitions, \ + uint32_t max_graph_degree, \ + IndexT* intermediate_indices_ptr, \ + DistanceT* intermediate_distances_ptr, \ + const DataT* queries_ptr, \ + uint32_t num_queries, \ + const search_params& ps, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + uint32_t visited_hash_bitlen, \ + int64_t traversed_hash_bitlen, \ + IndexT* traversed_hashmap_ptr, \ + uint32_t num_cta_per_query, \ + SampleFilterT sample_filter, \ + cudaStream_t stream); + } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh index 4fce5a2a12..8751a7bc0b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include @@ -38,4 +39,36 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream); -} +/** + * Multi-partition launcher. Drives `search_kernel_mp` with a 3D grid + * (num_cta_per_query, num_queries, num_partitions). Per-(query, partition) outputs are written + * into the intermediate buffer in partition-major layout + * [num_partitions, num_queries, num_cta_per_query * itopk_size]. Each partition's data + * (dataset_desc, graph, graph_degree) is read by the kernel from partition_descs[blockIdx.z]; + * smem and the result buffer are sized for the max graph_degree across partitions. + */ +template +void select_and_run_mp(const dataset_descriptor_host& ref_dataset_desc, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + uint32_t max_graph_degree, + IndexT* intermediate_indices_ptr, + DistanceT* intermediate_distances_ptr, + const DataT* queries_ptr, + uint32_t num_queries, + const search_params& ps, + uint32_t block_size, + uint32_t result_buffer_size, + uint32_t smem_size, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, + uint32_t num_cta_per_query, + SampleFilterT sample_filter, + cudaStream_t stream); + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh index 663f8a4559..7bd63d0886 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -13,6 +13,7 @@ #include "compute_distance.hpp" // For dataset_descriptor_host #include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" #include "jit_lto_kernels/kernel_def.hpp" +#include "multi_partition_desc.hpp" #include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset #include "search_plan.cuh" // For search_params #include "set_value_batch.cuh" // For set_value_batch @@ -151,4 +152,105 @@ void select_and_run(const dataset_descriptor_host& dat RAFT_CUDA_TRY(cudaPeekAtLastError()); } +// Multi-partition launcher. Drives `search_multi_cta_mp` with a 3D grid +// (num_cta_per_query, num_queries, num_partitions). `ref_dataset_desc` is used only for JIT tag +// dispatch (metric / vpq / team_size / block_dim) and must be representative of every +// partition's descriptor. Per-partition device descriptors are read from `partition_descs` by +// the kernel itself. +template +void select_and_run_mp(const dataset_descriptor_host& ref_dataset_desc, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + uint32_t max_graph_degree, + IndexT* intermediate_indices_ptr, + DistanceT* intermediate_distances_ptr, + const DataT* queries_ptr, + uint32_t num_queries, + const search_params& ps, + uint32_t block_size, + uint32_t result_buffer_size, + uint32_t smem_size, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, + uint32_t num_cta_per_query, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const auto bf = extract_cagra_mp_sample_filter(sample_filter); + const uint32_t query_id_offset = bf.query_id_offset; + + std::shared_ptr launcher = + make_cagra_multi_cta_mp_jit_launcher>(ref_dataset_desc); + + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } + + uint32_t max_elements{}; + if (result_buffer_size <= 64) { + max_elements = 64; + } else if (result_buffer_size <= 128) { + max_elements = 128; + } else if (result_buffer_size <= 256) { + max_elements = 256; + } else { + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); + } + + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + static_cast(num_queries) * num_partitions, + stream); + + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(num_cta_per_query, num_queries, num_partitions); + + const uint32_t max_graph_degree_u32 = static_cast(max_graph_degree); + const uint32_t traversed_hash_bitlen_u32 = static_cast(traversed_hash_bitlen); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + auto kernel_launcher = [&]() -> void { + launcher->dispatch< + multi_cta_search::search_multi_cta_mp_kernel_func_t>( + stream, + grid_dims, + block_dims, + smem_size, + partition_descs, + intermediate_indices_ptr, + intermediate_distances_ptr, + queries_ptr, + max_elements, + max_graph_degree_u32, + num_random_samplings_u, + ps.rand_xor_mask, + visited_hash_bitlen, + traversed_hashmap_ptr, + traversed_hash_bitlen_u32, + itopk_size_u32, + min_iterations_u32, + max_iterations_u32, + query_id_offset, + bf.bitset); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< + multi_cta_search::search_multi_cta_mp_kernel_func_t>( + smem_size, kernel_launcher, launcher->get_kernel()); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 74e34e0a14..1b1e4796f6 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -89,6 +89,13 @@ struct search uint32_t num_itopk_candidates; + /** Number of elements in a hashmap covering @p n_queries queries across @p n_segments segments. + */ + static size_t hashmap_element_count(size_t n_segments, size_t n_queries, size_t h_bitlen) + { + return n_segments * n_queries * hashmap::get_size(h_bitlen); + } + search(raft::resources const& res, search_params params, const dataset_descriptor_host& dataset_desc, @@ -199,12 +206,76 @@ struct search RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; if (small_hash_bitlen == 0 && !this->persistent) { - hashmap_size = max_queries * hashmap::get_size(hash_bitlen); + hashmap_size = hashmap_element_count(1, max_queries, hash_bitlen); hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res)); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); } + /** + * @brief Search all partitions concurrently in a single kernel launch. + * + * Queries and intermediate result buffers are shared across partitions. The intermediate + * buffers must be laid out as [num_partitions][num_queries][topk]; the caller is responsible + * for the post-merge reduction into the user-facing top-k. + * + * @param res RAFT resources (stream is extracted from here) + * @param partition_descs device pointer to [num_partitions] descriptors + * @param num_partitions number of partitions (gridDim.z) + * @param queries_ptr device pointer to [num_queries, dim] queries + * @param num_queries queries per partition (gridDim.y) + * @param intermediate_neighbors device buffer [num_partitions, num_queries, topk] + * @param intermediate_distances device buffer [num_partitions, num_queries, topk] + * @param topk neighbors to return per (query, partition) + */ + template + void run_multi_partition( + raft::resources const& res, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + const DATA_T* queries_ptr, + uint32_t num_queries, + INDEX_T* intermediate_neighbors, + DISTANCE_T* intermediate_distances, + uint32_t topk, + SampleFilterT sample_filter) + { + cudaStream_t stream = raft::resource::get_cuda_stream(res); + + // Allocate global hashmap when small-hash is disabled via the workspace pool + // (no cudaMallocAsync/cudaFreeAsync after pool warmup). + // Layout: [num_partitions][num_queries][hash_size]. + lightweight_uvector mp_hashmap_buf(res); + INDEX_T* mp_hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + const size_t mp_hashmap_elems = + hashmap_element_count(num_partitions, num_queries, hash_bitlen); + mp_hashmap_buf.resize(mp_hashmap_elems, stream); + mp_hashmap_ptr = mp_hashmap_buf.data(); + } + + select_and_run_multi_partition( + dataset_desc, + partition_descs, + num_partitions, + queries_ptr, + num_queries, + intermediate_neighbors, + intermediate_distances, + *this, + topk, + num_itopk_candidates, + static_cast(thread_block_size), + smem_size, + hash_bitlen, + mp_hashmap_ptr, + small_hash_bitlen, + small_hash_reset_interval, + sample_filter, + stream); + // mp_hashmap_buf destructor returns memory to workspace pool (stream-ordered). + } + void operator()( raft::resources const& res, raft::device_matrix_view graph, diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in index 85342e7093..4577eae843 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using mp_bitset_filter_t = + cuvs::neighbors::filtering::multi_partition_bitset_filter; } // namespace @@ -20,5 +22,11 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, mp_bitset_filter_t); +instantiate_kernel_selection_mp(data_t, + uint32_t, + float, + cuvs::neighbors::filtering::none_sample_filter); +instantiate_kernel_selection_mp(data_t, uint32_t, float, mp_bitset_filter_t); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index d242e13b95..37bc85aab8 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -37,4 +37,25 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { SampleFilterT sample_filter, \ cudaStream_t stream); +#define instantiate_kernel_selection_mp(DataT, IndexT, DistanceT, SampleFilterT) \ + template void select_and_run_multi_partition( \ + const dataset_descriptor_host& ref_dataset_desc, \ + const multi_partition_desc_t* partition_descs, \ + uint32_t num_partitions, \ + const DataT* queries_ptr, \ + uint32_t num_queries, \ + IndexT* intermediate_neighbors_ptr, \ + DistanceT* intermediate_distances_ptr, \ + const search_params& ps, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + IndexT* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + SampleFilterT sample_filter, \ + cudaStream_t stream); + } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh index ba308db98b..8b6559cd9c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel.cuh @@ -5,6 +5,7 @@ #pragma once #include +#include #include @@ -38,4 +39,29 @@ void select_and_run( SampleFilterT sample_filter, cudaStream_t stream); -} +template +void select_and_run_multi_partition( + const dataset_descriptor_host& ref_dataset_desc, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + const DataT* queries_ptr, + uint32_t num_queries, + IndexT* intermediate_neighbors_ptr, + DistanceT* intermediate_distances_ptr, + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + SampleFilterT sample_filter, + cudaStream_t stream); + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 96db9a743b..6c537a16c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -14,6 +14,7 @@ #include "hashmap.hpp" #include "jit_lto_kernels/cagra_jit_launcher_factory.hpp" #include "jit_lto_kernels/kernel_def.hpp" +#include "multi_partition_desc.hpp" #include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset #include "search_plan.cuh" // For search_params #include "search_single_cta_kernel_launcher_common.cuh" @@ -882,6 +883,108 @@ void select_and_run( } } +// Multi-partition launcher. Drives `search_single_cta_mp` with a 3D grid +// (1, num_queries, num_partitions). `ref_dataset_desc` is used only for JIT tag dispatch and +// must be representative of every partition's descriptor. Per-partition device descriptors and +// graphs are read from `partition_descs` by the kernel itself. +template +void select_and_run_multi_partition( + const dataset_descriptor_host& ref_dataset_desc, + const multi_partition_desc_t* partition_descs, + uint32_t num_partitions, + const DataT* queries_ptr, + uint32_t num_queries, + IndexT* intermediate_neighbors_ptr, + DistanceT* intermediate_distances_ptr, + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const auto bf = extract_cagra_mp_sample_filter(sample_filter); + const mp_cagra_bitset bitset = bf.bitset; + const uint32_t query_id_offset = bf.query_id_offset; + + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + bool topk_by_bitonic_sort = config.topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps = config.bitonic_sort_and_merge_multi_warps; + + std::shared_ptr launcher = + make_cagra_single_cta_mp_jit_launcher>( + ref_dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps); + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA mp search kernel"); } + + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t search_width_u32 = static_cast(ps.search_width); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + dim3 grid(1, num_queries, num_partitions); + dim3 block(block_size, 1, 1); + + RAFT_LOG_DEBUG("Launching mp JIT kernel: %u threads, %u queries, %u partitions, %u smem", + block_size, + num_queries, + num_partitions, + smem_size); + + auto kernel_launcher = [&]() -> void { + launcher->dispatch>( + stream, + grid, + block, + static_cast(smem_size), + partition_descs, + queries_ptr, + intermediate_neighbors_ptr, + intermediate_distances_ptr, + topk, + num_random_samplings_u, + ps.rand_xor_mask, + 0u, // num_seeds + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, + search_width_u32, + min_iterations_u32, + max_iterations_u32, + static_cast(nullptr), // num_executed_iterations + hash_bitlen_u32, + small_hash_bitlen_u32, + small_hash_reset_interval_u32, + query_id_offset, + bitset); + }; + + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< + search_single_cta_mp_kernel_func_t>( + smem_size, kernel_launcher, launcher->get_kernel()); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + // get_runner for JIT persistent runners (similar to non-JIT version) template auto get_runner_jit(Args... args) -> std::shared_ptr diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index a16c810409..2f4674774a 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -100,6 +100,8 @@ struct sample_filter_jit_tag { using namespace cuvs::neighbors::filtering; if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (is_mp_bitset_filter::value) { + return cuvs::neighbors::detail::tag_filter_mp_bitset{}; } else if constexpr (requires { std::declval().filter; }) { using InnerFilter = decltype(std::declval().filter); if constexpr (is_bitset_filter::value || diff --git a/cpp/src/neighbors/detail/sample_filter_data.cuh b/cpp/src/neighbors/detail/sample_filter_data.cuh index 4c99ca1e3a..40f5f43e4f 100644 --- a/cpp/src/neighbors/detail/sample_filter_data.cuh +++ b/cpp/src/neighbors/detail/sample_filter_data.cuh @@ -21,4 +21,15 @@ struct bitset_filter_data_t { SourceIndexT original_nbits{}; }; +/// Multi-partition variant: a single combined bitset spanning all partitions, plus a device +/// pointer to per-partition bit offsets. The mp sample_filter impl reads +/// `partition_offsets[blockIdx.z]` and adds it to `node_id` before testing the bitset. +template +struct mp_bitset_filter_data_t { + std::uint32_t* bitset_ptr{nullptr}; + SourceIndexT bitset_len{}; + SourceIndexT original_nbits{}; + const std::int64_t* partition_offsets{nullptr}; +}; + } // namespace cuvs::neighbors::detail diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index 9f1a36c305..5d661b602c 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -11,7 +11,7 @@ com.nvidia.cuvs cuvs-java - 26.08.0 + 26.08.0-SNAPSHOT cuvs-java This project provides Java bindings for cuVS, enabling approximate nearest neighbors search and clustering @@ -56,8 +56,8 @@ - ossrh - https://oss.sonatype.org/content/repositories/snapshots + central + https://central.sonatype.com/repository/maven-snapshots/ @@ -205,14 +205,12 @@ - org.sonatype.plugins - nexus-staging-maven-plugin - 1.6.7 + org.sonatype.central + central-publishing-maven-plugin + 0.7.0 true - ossrh - https://oss.sonatype.org/ - false + central diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index c87f024124..2e48928636 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs; diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraSearchParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraSearchParams.java index 76e1f10bd9..a9ea94a440 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraSearchParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraSearchParams.java @@ -25,6 +25,9 @@ public class CagraSearchParams { private long randXORMask; private SearchAlgo searchAlgo; private HashMapMode hashMapMode; + private boolean persistent; + private float persistentLifetime; + private float persistentDeviceUsage; /** * Enum to denote algorithm used to search CAGRA Index. @@ -120,7 +123,10 @@ private CagraSearchParams( int hashmapMinBitlen, float hashmapMaxFillRate, int numRandomSamplings, - long randXORMask) { + long randXORMask, + boolean persistent, + float persistentLifetime, + float persistentDeviceUsage) { this.maxQueries = maxQueries; this.iTopKSize = iTopKSize; this.maxIterations = maxIterations; @@ -134,6 +140,9 @@ private CagraSearchParams( this.hashMapMaxFillRate = hashmapMaxFillRate; this.numRandomSamplings = numRandomSamplings; this.randXORMask = randXORMask; + this.persistent = persistent; + this.persistentLifetime = persistentLifetime; + this.persistentDeviceUsage = persistentDeviceUsage; } /** @@ -254,6 +263,33 @@ public HashMapMode getHashMapMode() { return hashMapMode; } + /** + * Gets whether the persistent kernel is enabled. + * + * @return true if the persistent kernel is enabled + */ + public boolean isPersistent() { + return persistent; + } + + /** + * Gets the persistent kernel lifetime in seconds. + * + * @return the lifetime in seconds + */ + public float getPersistentLifetime() { + return persistentLifetime; + } + + /** + * Gets the fraction of maximum grid size used by the persistent kernel. + * + * @return the device usage fraction (0.0, 1.0] + */ + public float getPersistentDeviceUsage() { + return persistentDeviceUsage; + } + @Override public String toString() { return "CagraSearchParams [maxQueries=" @@ -301,8 +337,11 @@ public static class Builder { private int numRandomSamplings = 1; private float hashMapMaxFillRate = 0.5f; private long randXORMask = 0x128394; - private SearchAlgo searchAlgo; - private HashMapMode hashMapMode; + private SearchAlgo searchAlgo = SearchAlgo.AUTO; + private HashMapMode hashMapMode = HashMapMode.AUTO_HASH; + private boolean persistent = false; + private float persistentLifetime = 2.0f; + private float persistentDeviceUsage = 1.0f; /** * Default constructor. @@ -460,6 +499,43 @@ public Builder withRandXorMask(long randXORMask) { return this; } + /** + * Enables or disables the persistent kernel. + * + *

When enabled, the CAGRA SINGLE_CTA kernel stays resident on the GPU and serves search + * jobs via system-scope atomics. The kernel is shared across all indexes. + * + * @param persistent true to enable the persistent kernel + * @return an instance of this Builder + */ + public Builder withPersistent(boolean persistent) { + this.persistent = persistent; + return this; + } + + /** + * Sets the time in seconds before an idle persistent kernel exits. + * + * @param persistentLifetime lifetime in seconds (default 2.0) + * @return an instance of this Builder + */ + public Builder withPersistentLifetime(float persistentLifetime) { + this.persistentLifetime = persistentLifetime; + return this; + } + + /** + * Sets the fraction of maximum grid size used by the persistent kernel. + * Must be greater than 0.0 and not greater than 1.0. + * + * @param persistentDeviceUsage device usage fraction (default 1.0) + * @return an instance of this Builder + */ + public Builder withPersistentDeviceUsage(float persistentDeviceUsage) { + this.persistentDeviceUsage = persistentDeviceUsage; + return this; + } + /** * Builds an instance of {@link CagraSearchParams} with passed search * parameters. @@ -480,7 +556,10 @@ public CagraSearchParams build() { hashMapMinBitlen, hashMapMaxFillRate, numRandomSamplings, - randXORMask); + randXORMask, + persistent, + persistentLifetime, + persistentDeviceUsage); } } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java index b70547b333..d9e6b2598d 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSAceParams.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs; @@ -75,8 +75,13 @@ public class CuVSAceParams { */ private final double maxGpuMemoryGb; - private CuVSAceParams(long npartitions, long efConstruction, String buildDir, boolean useDisk, - double maxHostMemoryGb, double maxGpuMemoryGb) { + private CuVSAceParams( + long npartitions, + long efConstruction, + String buildDir, + boolean useDisk, + double maxHostMemoryGb, + double maxGpuMemoryGb) { this.npartitions = npartitions; this.efConstruction = efConstruction; this.buildDir = buildDir; @@ -259,8 +264,8 @@ public Builder withMaxGpuMemoryGb(double maxGpuMemoryGb) { * @return an instance of {@link CuVSAceParams} */ public CuVSAceParams build() { - return new CuVSAceParams(npartitions, efConstruction, buildDir, useDisk, - maxHostMemoryGb, maxGpuMemoryGb); + return new CuVSAceParams( + npartitions, efConstruction, buildDir, useDisk, maxHostMemoryGb, maxGpuMemoryGb); } } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java index e0e39a4b4b..a19254ee3b 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java @@ -18,6 +18,7 @@ public interface CuVSMatrix extends AutoCloseable { enum DataType { FLOAT(4), + HALF(2), INT(4), UINT(4), BYTE(1); @@ -94,6 +95,13 @@ interface Builder { */ void addVector(int[] vector); + /** + * Adds a single vector to the matrix. Each element is a raw float16 bit pattern stored in a short. + * + * @param vector A short array of as many elements as the dimensions + */ + void addVector(short[] vector); + T build(); } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index b105580328..01d500b948 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -57,6 +57,24 @@ interface ScopedAccess extends AutoCloseable { */ Path tempDirectory(); + /** + * Configure the temporary workspace on this resources object as an uncapped pool backed by the + * current device memory resource. After the initial reservation is allocated on first use, + * subsequent calls to {@code cuvsRMMAlloc} / {@code cuvsRMMFree} on this handle hit the pool + * cache rather than calling {@code cudaMallocAsync} / {@code cudaFreeAsync}, reducing CUDA + * context lock contention under concurrent query threads. The pool grows without shrinking: + * freed allocations are returned to the pool rather than to the device, so the pool's + * high-water mark only increases until the resources object is closed. + * + *

The pool is per-resources-handle (i.e. per query thread when resources are thread-local), + * so there is no cross-thread pool mutex contention. Call this once after creating the resources + * object; calling it again replaces the pool. + * + * @param initialSizeBytes initial pool reservation in bytes; size {@code initialSizeBytes} to + * cover the steady-state working set to avoid growth after warmup + */ + void setWorkspacePool(long initialSizeBytes); + /** * Creates a new resources. * Equivalent to diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswAceParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswAceParams.java index 325f424fae..215c22838f 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswAceParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswAceParams.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs; @@ -21,8 +21,12 @@ public class HnswAceParams { private double maxHostMemoryGb; private double maxGpuMemoryGb; - private HnswAceParams(long npartitions, String buildDir, boolean useDisk, - double maxHostMemoryGb, double maxGpuMemoryGb) { + private HnswAceParams( + long npartitions, + String buildDir, + boolean useDisk, + double maxHostMemoryGb, + double maxGpuMemoryGb) { this.npartitions = npartitions; this.buildDir = buildDir; this.useDisk = useDisk; @@ -188,8 +192,7 @@ public Builder withMaxGpuMemoryGb(double maxGpuMemoryGb) { * @return an instance of {@link HnswAceParams} */ public HnswAceParams build() { - return new HnswAceParams(npartitions, buildDir, useDisk, - maxHostMemoryGb, maxGpuMemoryGb); + return new HnswAceParams(npartitions, buildDir, useDisk, maxHostMemoryGb, maxGpuMemoryGb); } } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java index 3eef491b62..84979cfe0c 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndex.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs; diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndexParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndexParams.java index 070cbedae1..d68e01b58b 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndexParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/HnswIndexParams.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs; @@ -283,13 +283,7 @@ public Builder withAceParams(HnswAceParams aceParams) { */ public HnswIndexParams build() { return new HnswIndexParams( - hierarchy, - efConstruction, - numThreads, - vectorDimension, - m, - metric, - aceParams); + hierarchy, efConstruction, numThreads, vectorDimension, m, metric, aceParams); } } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/MultiPartitionSearchResults.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/MultiPartitionSearchResults.java new file mode 100644 index 0000000000..6f78acac31 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/MultiPartitionSearchResults.java @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs; + +/** + * Holds the decoded results of a multi-partition GPU search. + * + *

Each entry {@code i} in [0, {@link #count}) identifies: + *

    + *
  • which input partition the result came from ({@link #getPartitionIndex(int)})
  • + *
  • the local vector ordinal within that partition ({@link #getOrdinal(int)})
  • + *
  • the raw CAGRA distance ({@link #getDistance(int)})
  • + *
+ * + *

The caller is responsible for mapping ordinals to its own global identifiers. + * + * @since 25.10 + */ +public class MultiPartitionSearchResults { + + private final int count; + private final int[] partitionIndices; + private final int[] ordinals; + private final float[] distances; + + MultiPartitionSearchResults( + int count, int[] partitionIndices, int[] ordinals, float[] distances) { + this.count = count; + this.partitionIndices = partitionIndices; + this.ordinals = ordinals; + this.distances = distances; + } + + /** Number of valid results (may be less than k if fewer candidates exist). */ + public int count() { + return count; + } + + /** Index into the original partition list for result {@code i}. */ + public int getPartitionIndex(int i) { + return partitionIndices[i]; + } + + /** Local vector ordinal within the partition for result {@code i}. */ + public int getOrdinal(int i) { + return ordinals[i]; + } + + /** Post-processed distance for result {@code i} (scaled + metric-transformed). */ + public float getDistance(int i) { + return distances[i]; + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/SynchronizedCuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/SynchronizedCuVSResources.java index 64a72ec32a..aa74893c6f 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/SynchronizedCuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/SynchronizedCuVSResources.java @@ -40,6 +40,11 @@ public void close() { inner.close(); } + @Override + public void setWorkspacePool(long sizeBytes) { + inner.setWorkspacePool(sizeBytes); + } + @Override public Path tempDirectory() { return inner.tempDirectory(); diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java index c39578755c..558d2e73f7 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.spi; @@ -189,6 +189,16 @@ default CagraIndex mergeCagraIndexes(CagraIndex[] indexes, CagraIndexParams merg */ void enableRMMManagedPooledMemory(int initialPoolSizePercent, int maxPoolSizePercent); + /** + * Switch RMM allocations to use stream-ordered asynchronous allocation + * ({@code cudaMallocAsync} / {@code cudaFreeAsync}). Unlike the pool resource, this resource + * returns memory to the stream without blocking the CPU, eliminating device-wide synchronization + * on deallocation. This is especially beneficial when multiple CAGRA searches run concurrently + * on separate CUDA streams, because internal workspace allocations no longer serialize kernel + * launches. This operation has a global effect and will affect all resources on the current device. + */ + void enableRMMAsyncMemory(); + /** Disables pooled memory on the current device, reverting back to the default setting. */ void resetRMMPooledMemory(); diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java index 7cbeee4e75..6701afa47b 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.spi; @@ -47,8 +47,8 @@ public HnswIndex hnswIndexFromCagra(HnswIndexParams hnswParams, CagraIndex cagra } @Override - public HnswIndex hnswIndexBuild(CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) - throws Throwable { + public HnswIndex hnswIndexBuild( + CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) throws Throwable { throw new UnsupportedOperationException(reasons); } @@ -106,6 +106,11 @@ public Level getLogLevel() { throw new UnsupportedOperationException(reasons); } + @Override + public void enableRMMAsyncMemory() { + throw new UnsupportedOperationException(reasons); + } + @Override public void enableRMMPooledMemory(int initialPoolSizePercent, int maxPoolSizePercent) { throw new UnsupportedOperationException(reasons); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/FilterBitsetHandle.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/FilterBitsetHandle.java new file mode 100644 index 0000000000..3abf09d29c --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/FilterBitsetHandle.java @@ -0,0 +1,153 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs; + +import static com.nvidia.cuvs.internal.common.CloseableRMMAllocation.allocateRMMSegment; +import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE; +import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; +import static com.nvidia.cuvs.internal.common.Util.cudaMemcpyAsync; +import static com.nvidia.cuvs.internal.common.Util.getStream; +import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync; + +import com.nvidia.cuvs.internal.common.CloseableRMMAllocation; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +/** + * Holds a precomputed multi-partition filter bitset and manages its device-memory lifecycle. + * + *

Two-level caching

+ *
    + *
  • Host level – the packed {@code long[]} arrays are owned by this object + * and shared safely across threads (immutable after construction).
  • + *
  • Device level – a single shared device allocation is uploaded once on + * first use (lazy, double-checked locking) and reused by all threads thereafter.
  • + *
+ * + *

Callers must call {@link #close()} when the handle is evicted from their host-level cache. + * + * @since 25.10 + */ +public final class FilterBitsetHandle implements AutoCloseable { + + /** Device-side allocation pair, shared across all threads. */ + static final class DeviceData { + final CloseableRMMAllocation combinedBitsetDP; + final CloseableRMMAllocation partOffsetsDP; + final long totalBits; + final int numPartitions; + + DeviceData( + CloseableRMMAllocation combinedBitsetDP, + CloseableRMMAllocation partOffsetsDP, + long totalBits, + int numPartitions) { + this.combinedBitsetDP = combinedBitsetDP; + this.partOffsetsDP = partOffsetsDP; + this.totalBits = totalBits; + this.numPartitions = numPartitions; + } + + void close() { + try { + combinedBitsetDP.close(); + } catch (Exception ignored) { + } + try { + partOffsetsDP.close(); + } catch (Exception ignored) { + } + } + } + + // Host-side immutable data. + final long[] combinedLongs; + final long[] partBitOffsets; + final long totalBits; + final int numPartitions; + + // Shared device allocation — uploaded once, visible to all threads via volatile. + private volatile DeviceData sharedDeviceData; + private final Object uploadLock = new Object(); + + private volatile boolean closed = false; + + /** + * Creates a handle from pre-packed host arrays. + * + * @param combinedLongs packed bitset words for all partitions concatenated (64-bit aligned) + * @param partBitOffsets per-partition bit offsets into {@code combinedLongs} + * @param totalBits total number of logical bits in {@code combinedLongs} + */ + public FilterBitsetHandle(long[] combinedLongs, long[] partBitOffsets, long totalBits) { + this.combinedLongs = combinedLongs; + this.partBitOffsets = partBitOffsets; + this.totalBits = totalBits; + this.numPartitions = partBitOffsets.length; + } + + /** + * Returns the shared device allocation for this filter, uploading on first call. + * + *

The upload uses stream-ordered {@code cudaMemcpyAsync} followed by + * {@code cuvsStreamSync}, so no other stream is serialized. + * + * @param cuvsRes the native cuvsResources handle for the calling thread + * @return shared device data (valid until {@link #close()} is called) + */ + DeviceData getOrUpload(long cuvsRes) { + if (closed) throw new IllegalStateException("FilterBitsetHandle has been closed"); + DeviceData data = sharedDeviceData; + if (data != null) return data; + synchronized (uploadLock) { + data = sharedDeviceData; + if (data != null) return data; + data = upload(cuvsRes); + sharedDeviceData = data; // volatile write: happens-before all subsequent reads + } + return data; + } + + private DeviceData upload(long cuvsRes) { + long combinedBitsetBytes = (long) combinedLongs.length * Long.BYTES; + long partOffsetsBytes = (long) partBitOffsets.length * Long.BYTES; + + CloseableRMMAllocation combinedBitsetDP = allocateRMMSegment(cuvsRes, combinedBitsetBytes); + CloseableRMMAllocation partOffsetsDP = allocateRMMSegment(cuvsRes, partOffsetsBytes); + + var stream = getStream(cuvsRes); + // Host arenas must outlive the stream sync that confirms the H2D copies. + try (var arena = Arena.ofConfined()) { + MemorySegment hostBitset = arena.allocate(combinedBitsetBytes, Long.BYTES); + MemorySegment.copy( + combinedLongs, 0, hostBitset, ValueLayout.JAVA_LONG, 0, combinedLongs.length); + cudaMemcpyAsync( + combinedBitsetDP.handle(), hostBitset, combinedBitsetBytes, HOST_TO_DEVICE, stream); + + MemorySegment hostOffsets = arena.allocate(partOffsetsBytes, Long.BYTES); + MemorySegment.copy( + partBitOffsets, 0, hostOffsets, ValueLayout.JAVA_LONG, 0, partBitOffsets.length); + cudaMemcpyAsync( + partOffsetsDP.handle(), hostOffsets, partOffsetsBytes, HOST_TO_DEVICE, stream); + + checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync in FilterBitsetHandle.upload"); + } + // Stream sync has returned — device memory is fully populated. + return new DeviceData(combinedBitsetDP, partOffsetsDP, totalBits, numPartitions); + } + + /** Marks this handle closed and releases the shared device allocation. */ + @Override + public void close() { + closed = true; + DeviceData data; + synchronized (uploadLock) { + data = sharedDeviceData; + sharedDeviceData = null; + } + if (data != null) data.close(); + } +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/MultiPartitionCagraSearch.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/MultiPartitionCagraSearch.java new file mode 100644 index 0000000000..42d559c7d8 --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/MultiPartitionCagraSearch.java @@ -0,0 +1,235 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs; + +import static com.nvidia.cuvs.internal.common.CloseableRMMAllocation.allocateRMMSegment; +import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.DEVICE_TO_HOST; +import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; +import static com.nvidia.cuvs.internal.common.Util.cudaMemcpyAsync; +import static com.nvidia.cuvs.internal.common.Util.getStream; +import static com.nvidia.cuvs.internal.common.Util.prepareTensor; +import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraSearchMultiPartition; +import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLCUDA; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLFloat; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLInt; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLUInt; + +import com.nvidia.cuvs.FilterBitsetHandle.DeviceData; +import com.nvidia.cuvs.internal.BufferedCagraSearch; +import com.nvidia.cuvs.internal.CuVSMatrixInternal; +import com.nvidia.cuvs.internal.CuVSParamsHelper; +import com.nvidia.cuvs.internal.panama.cuvsFilter; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.util.List; + +/** + * Performs an approximate nearest neighbor search across multiple CAGRA index partitions in a + * single native call. The caller supplies one {@link CagraQuery} whose query matrix is searched + * against every partition; cuVS performs the per-partition searches, the cross-partition top-k + * merge, and the post-processing internally, then returns the merged results. + * + *

Algorithm (executed natively)

+ *
    + *
  1. For each (query, partition) pair, run CAGRA search into an internal + * [num_partitions, n_queries, k] device buffer.
  2. + *
  3. Apply per-partition distance post-processing on the intermediate buffer.
  4. + *
  5. Run a batched {@code raft::matrix::select_k} to pick the global top-k per query.
  6. + *
  7. Decode the select_k positions into {@code partition_ids} and {@code neighbors} outputs.
  8. + *
+ * + * @since 25.10 + */ +public class MultiPartitionCagraSearch { + + private MultiPartitionCagraSearch() {} + + /** + * Searches multiple CAGRA index partitions for the global top-k nearest neighbors. + * + * @param resources shared {@link CuVSResources} handle + * @param indices one {@link CagraIndex} per partition, in partition order + * @param query a single {@link CagraQuery} whose query matrix is searched against every + * partition; its search parameters are shared across all partitions + * @param k number of global nearest neighbors to return per query + */ + public static MultiPartitionSearchResults search( + CuVSResources resources, List indices, CagraQuery query, int k) throws Throwable { + return search(resources, indices, query, k, /* filter= */ null); + } + + /** + * Searches multiple CAGRA index partitions with an optional pre-cached device-side filter. + * + *

When {@code filter} is non-null, the filter is applied via the pre-uploaded combined + * bitset, avoiding both host-side O(N) bit evaluation and H2D transfers on cache hits. + * + * @param resources shared {@link CuVSResources} handle + * @param indices one {@link CagraIndex} per partition, in partition order + * @param query a single {@link CagraQuery} whose query matrix is searched against every + * partition + * @param k number of global nearest neighbors to return per query + * @param filter pre-built combined bitset handle, or {@code null} for unfiltered search + */ + public static MultiPartitionSearchResults search( + CuVSResources resources, + List indices, + CagraQuery query, + int k, + FilterBitsetHandle filter) + throws Throwable { + int numPartitions = indices.size(); + if (numPartitions == 0) { + return new MultiPartitionSearchResults(0, new int[0], new int[0], new float[0]); + } + + BufferedCagraSearch[] buffered = new BufferedCagraSearch[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + CagraIndex idx = indices.get(i); + if (!(idx instanceof BufferedCagraSearch)) { + throw new IllegalArgumentException( + "Index at position " + i + " does not support buffered search"); + } + buffered[i] = (BufferedCagraSearch) idx; + } + + var queryVectors = (CuVSMatrixInternal) query.getQueryVectors(); + int nQueries = (int) queryVectors.size(); + + long partitionIdsBytes = (long) nQueries * k * Integer.BYTES; // uint32 + long neighborsBytes = (long) nQueries * k * Integer.BYTES; // uint32 + long distancesBytes = (long) nQueries * k * Float.BYTES; + + CagraSearchParams searchParameters = query.getCagraSearchParameters(); + + try (var resourcesAccessor = resources.access()) { + long cuvsRes = resourcesAccessor.handle(); + var cuvsStream = getStream(cuvsRes); + + try (var partitionIdsDP = allocateRMMSegment(cuvsRes, partitionIdsBytes); + var neighborsDP = allocateRMMSegment(cuvsRes, neighborsBytes); + var distancesDP = allocateRMMSegment(cuvsRes, distancesBytes)) { + + try (var arena = Arena.ofConfined()) { + MemorySegment sp = CuVSParamsHelper.buildCagraSearchParams(arena, searchParameters); + + MemorySegment indexArray = arena.allocate(ValueLayout.ADDRESS, numPartitions); + for (int i = 0; i < numPartitions; i++) { + indexArray.setAtIndex(ValueLayout.ADDRESS, i, buffered[i].getIndexHandle()); + } + + MemorySegment queriesTensor = queryVectors.toTensor(arena); + + long[] outShape = {nQueries, k}; + MemorySegment partitionIdsTensor = + prepareTensor(arena, partitionIdsDP.handle(), outShape, kDLUInt(), 32, kDLCUDA()); + MemorySegment neighborsTensor = + prepareTensor(arena, neighborsDP.handle(), outShape, kDLUInt(), 32, kDLCUDA()); + MemorySegment distancesTensor = + prepareTensor(arena, distancesDP.handle(), outShape, kDLFloat(), 32, kDLCUDA()); + + MemorySegment filterSeg = cuvsFilter.allocate(arena); + if (filter != null) { + DeviceData dev = filter.getOrUpload(cuvsRes); + buildCuvsFilterStruct( + arena, + filterSeg, + dev.combinedBitsetDP.handle(), + dev.partOffsetsDP.handle(), + dev.totalBits, + dev.numPartitions); + } else { + cuvsFilter.type(filterSeg, 0 /* NO_FILTER */); + cuvsFilter.addr(filterSeg, 0L); + } + + checkCuVSError( + cuvsCagraSearchMultiPartition( + cuvsRes, + sp, + numPartitions, + indexArray, + queriesTensor, + partitionIdsTensor, + neighborsTensor, + distancesTensor, + filterSeg), + "cuvsCagraSearchMultiPartition"); + } + + // Copy the three small output arrays to host in a single allocation. + try (var hostArena = Arena.ofConfined()) { + MemorySegment hostBuf = + hostArena.allocate(partitionIdsBytes + neighborsBytes + distancesBytes, Long.BYTES); + MemorySegment hostPartitionIds = hostBuf.asSlice(0, partitionIdsBytes); + MemorySegment hostNeighbors = hostBuf.asSlice(partitionIdsBytes, neighborsBytes); + MemorySegment hostDistances = + hostBuf.asSlice(partitionIdsBytes + neighborsBytes, distancesBytes); + + cudaMemcpyAsync( + hostPartitionIds, + partitionIdsDP.handle(), + partitionIdsBytes, + DEVICE_TO_HOST, + cuvsStream); + cudaMemcpyAsync( + hostNeighbors, neighborsDP.handle(), neighborsBytes, DEVICE_TO_HOST, cuvsStream); + cudaMemcpyAsync( + hostDistances, distancesDP.handle(), distancesBytes, DEVICE_TO_HOST, cuvsStream); + + checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync after D2H copy"); + + int total = nQueries * k; + int[] partitionIds = new int[total]; + int[] selectedNeighbors = new int[total]; + float[] selectedDistances = new float[total]; + int count = 0; + for (int j = 0; j < total; j++) { + int neighbor = hostNeighbors.getAtIndex(ValueLayout.JAVA_INT, j); + if (neighbor < 0) continue; // sentinel from unfilled top-k slots + partitionIds[count] = hostPartitionIds.getAtIndex(ValueLayout.JAVA_INT, j); + selectedNeighbors[count] = neighbor; + selectedDistances[count] = hostDistances.getAtIndex(ValueLayout.JAVA_FLOAT, j); + count++; + } + + return new MultiPartitionSearchResults( + count, partitionIds, selectedNeighbors, selectedDistances); + } + } + } + } + + /** + * Populates a {@code cuvsFilter} MemorySegment for a MULTI_PARTITION_BITSET filter using + * pre-uploaded device buffers. + */ + private static void buildCuvsFilterStruct( + Arena arena, + MemorySegment filterSeg, + MemorySegment combinedBitsetHandle, + MemorySegment partOffsetsHandle, + long totalBits, + int numPartitions) { + long[] bitsetShape = {(totalBits + 31) / 32}; + MemorySegment combinedBitsetTensor = + prepareTensor(arena, combinedBitsetHandle, bitsetShape, kDLUInt(), 32, kDLCUDA()); + long[] offsetsShape = {numPartitions}; + MemorySegment partOffsetsTensor = + prepareTensor(arena, partOffsetsHandle, offsetsShape, kDLInt(), 64, kDLCUDA()); + + // cuvsMultiPartitionBitsetFilter: + // {ptr combined_bitset, int64 total_bits, ptr partition_offsets} + MemorySegment mpbFilter = arena.allocate(24, 8); + mpbFilter.set(ValueLayout.JAVA_LONG, 0, combinedBitsetTensor.address()); + mpbFilter.set(ValueLayout.JAVA_LONG, 8, totalBits); + mpbFilter.set(ValueLayout.JAVA_LONG, 16, partOffsetsTensor.address()); + + cuvsFilter.type(filterSeg, 3 /* MULTI_PARTITION_BITSET */); + cuvsFilter.addr(filterSeg, mpbFilter.address()); + } +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BufferedCagraSearch.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BufferedCagraSearch.java new file mode 100644 index 0000000000..0a7b521b29 --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BufferedCagraSearch.java @@ -0,0 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs.internal; + +import java.lang.foreign.MemorySegment; + +/** + * Internal interface implemented by CAGRA index classes that expose their underlying native + * {@code cuvsCagraIndex_t} handle. + * + *

Used by {@link com.nvidia.cuvs.MultiPartitionCagraSearch} to build the index pointer array + * passed to {@code cuvsCagraSearchMultiPartition}. + */ +public interface BufferedCagraSearch { + + /** Returns the raw {@code cuvsCagraIndex_t} handle as a {@link MemorySegment}. */ + MemorySegment getIndexHandle(); +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java index abc53a5945..d7bebcbf22 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.internal; @@ -47,7 +47,7 @@ * * @since 25.02 */ -public class CagraIndexImpl implements CagraIndex { +public class CagraIndexImpl implements CagraIndex, BufferedCagraSearch { private final CuVSResources resources; private final IndexReference cagraIndexReference; private boolean destroyed; @@ -338,6 +338,12 @@ public SearchResults search(CagraQuery query) throws Throwable { } } + /** Returns the underlying {@code cuvsCagraIndex_t} handle for native-side index passing. */ + @Override + public MemorySegment getIndexHandle() { + return cagraIndexReference.getMemorySegment(); + } + @Override public void serialize(OutputStream outputStream) throws Throwable { Path path = @@ -632,8 +638,10 @@ private static void populateNativeIndexParams( cuvsAceParams.npartitions(cuvsAceParamsMemorySegment, cuVSAceParams.getNpartitions()); cuvsAceParams.ef_construction(cuvsAceParamsMemorySegment, cuVSAceParams.getEfConstruction()); cuvsAceParams.use_disk(cuvsAceParamsMemorySegment, cuVSAceParams.isUseDisk()); - cuvsAceParams.max_host_memory_gb(cuvsAceParamsMemorySegment, cuVSAceParams.getMaxHostMemoryGb()); - cuvsAceParams.max_gpu_memory_gb(cuvsAceParamsMemorySegment, cuVSAceParams.getMaxGpuMemoryGb()); + cuvsAceParams.max_host_memory_gb( + cuvsAceParamsMemorySegment, cuVSAceParams.getMaxHostMemoryGb()); + cuvsAceParams.max_gpu_memory_gb( + cuvsAceParamsMemorySegment, cuVSAceParams.getMaxGpuMemoryGb()); String buildDir = cuVSAceParams.getBuildDir(); if (buildDir != null && !buildDir.isEmpty()) { diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java index 98f4095ffc..03e08ac4c4 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java @@ -7,6 +7,7 @@ import static com.nvidia.cuvs.internal.common.LinkerHelper.C_CHAR; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT; +import static com.nvidia.cuvs.internal.common.LinkerHelper.C_SHORT; import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; import static com.nvidia.cuvs.internal.panama.headers_h.*; @@ -95,6 +96,7 @@ public ValueLayout valueLayout() { protected static ValueLayout valueLayoutFromType(DataType dataType) { return switch (dataType) { case FLOAT -> C_FLOAT; + case HALF -> C_SHORT; case INT, UINT -> C_INT; case BYTE -> C_CHAR; }; @@ -177,6 +179,8 @@ private static DataType dataTypeFromTensor(byte code, byte bits) { dataType = DataType.INT; } else if (code == kDLFloat() && bits == 32) { dataType = DataType.FLOAT; + } else if (code == kDLFloat() && bits == 16) { + dataType = DataType.HALF; } else if ((code == kDLInt() || code == kDLUInt()) && bits == 8) { dataType = DataType.BYTE; } else { diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixInternal.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixInternal.java index 35715b8336..6b4b617493 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixInternal.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixInternal.java @@ -38,7 +38,7 @@ default int code() { static int code(DataType dataType) { return switch (dataType) { - case FLOAT -> kDLFloat(); + case FLOAT, HALF -> kDLFloat(); case INT -> kDLInt(); case UINT, BYTE -> kDLUInt(); }; diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java index 950504bc5a..a4c914dbf8 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSParamsHelper.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.internal; @@ -7,7 +7,9 @@ import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; import static com.nvidia.cuvs.internal.panama.headers_h.*; +import com.nvidia.cuvs.CagraSearchParams; import com.nvidia.cuvs.internal.common.CloseableHandle; +import com.nvidia.cuvs.internal.panama.*; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; @@ -25,6 +27,30 @@ public final class CuVSParamsHelper { private CuVSParamsHelper() {} + /** + * Allocates and populates a {@code cuvsCagraSearchParams} struct into {@code arena}. + * The returned segment is valid for the lifetime of {@code arena}. + */ + public static MemorySegment buildCagraSearchParams(Arena arena, CagraSearchParams params) { + MemorySegment seg = cuvsCagraSearchParams.allocate(arena); + cuvsCagraSearchParams.max_queries(seg, params.getMaxQueries()); + cuvsCagraSearchParams.itopk_size(seg, params.getITopKSize()); + cuvsCagraSearchParams.max_iterations(seg, params.getMaxIterations()); + cuvsCagraSearchParams.algo(seg, params.getCagraSearchAlgo().value); + cuvsCagraSearchParams.team_size(seg, params.getTeamSize()); + cuvsCagraSearchParams.search_width(seg, params.getSearchWidth()); + cuvsCagraSearchParams.min_iterations(seg, params.getMinIterations()); + cuvsCagraSearchParams.thread_block_size(seg, params.getThreadBlockSize()); + cuvsCagraSearchParams.hashmap_mode(seg, params.getHashMapMode().value); + cuvsCagraSearchParams.hashmap_max_fill_rate(seg, params.getHashMapMaxFillRate()); + cuvsCagraSearchParams.num_random_samplings(seg, params.getNumRandomSamplings()); + cuvsCagraSearchParams.rand_xor_mask(seg, params.getRandXORMask()); + cuvsCagraSearchParams.persistent(seg, params.isPersistent()); + cuvsCagraSearchParams.persistent_lifetime(seg, params.getPersistentLifetime()); + cuvsCagraSearchParams.persistent_device_usage(seg, params.getPersistentDeviceUsage()); + return seg; + } + public static CloseableHandle createCagraIndexParams() { try (var localArena = Arena.ofConfined()) { var paramsPtrPtr = localArena.allocate(cuvsCagraIndexParams_t); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java index efdf7283ac..e421cad660 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java @@ -65,6 +65,12 @@ public void close() { } } + @Override + public void setWorkspacePool(long sizeBytes) { + checkCuVSError( + cuvsResourcesSetWorkspacePool(resourceHandle, sizeBytes), "cuvsResourcesSetWorkspacePool"); + } + @Override public Path tempDirectory() { return tempDirectory; diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java index ca528ac010..66a53fa6d3 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/HnswIndexImpl.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.internal; @@ -251,8 +251,8 @@ public static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) { * @return A new HNSW index ready for search * @throws Throwable if an error occurs during building */ - public static HnswIndex build(CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) - throws Throwable { + public static HnswIndex build( + CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) throws Throwable { Objects.requireNonNull(resources); Objects.requireNonNull(hnswParams); Objects.requireNonNull(dataset); @@ -288,7 +288,8 @@ public static HnswIndex build(CuVSResources resources, HnswIndexParams hnswParam return new HnswIndexImpl(new IndexReference(hnswIndex), resources, hnswParams); } - private static CloseableHandle createHnswIndexParamsForBuild(Arena arena, HnswIndexParams params) { + private static CloseableHandle createHnswIndexParamsForBuild( + Arena arena, HnswIndexParams params) { var hnswParams = createHnswIndexParams(); MemorySegment seg = hnswParams.handle(); @@ -324,7 +325,7 @@ private static MemorySegment prepareTensorFromMatrix(Arena arena, CuVSMatrix dat return prepareTensor( arena, matrixInternal.memorySegment(), - new long[]{dataset.size(), dataset.columns()}, + new long[] {dataset.size(), dataset.columns()}, matrixInternal.code(), matrixInternal.bits(), kDLCPU()); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/SelectKHelper.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/SelectKHelper.java new file mode 100644 index 0000000000..f859ea26c7 --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/SelectKHelper.java @@ -0,0 +1,94 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +package com.nvidia.cuvs.internal; + +import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT; +import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG; +import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER; +import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; +import static com.nvidia.cuvs.internal.common.Util.prepareTensor; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLCUDA; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLFloat; +import static com.nvidia.cuvs.internal.panama.headers_h.kDLInt; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SymbolLookup; +import java.lang.invoke.MethodHandle; + +/** + * Panama FFI binding for {@code cuvsSelectK}. + * + *

Selects the k smallest float values from a flat device array of n candidates, writing output + * distances and their flat-array positions (int64) into caller-supplied device buffers. + */ +public class SelectKHelper { + + private static final MethodHandle cuvsSelectK$mh; + + static { + var linker = Linker.nativeLinker(); + SymbolLookup lookup = + SymbolLookup.libraryLookup(System.mapLibraryName("cuvs_c"), Arena.ofAuto()) + .or(SymbolLookup.loaderLookup()) + .or(linker.defaultLookup()); + + cuvsSelectK$mh = + linker.downcallHandle( + lookup + .find("cuvsSelectK") + .orElseThrow(() -> new UnsatisfiedLinkError("cuvsSelectK not found in libcuvs_c")), + FunctionDescriptor.of( + C_INT, // return: cuvsError_t + C_LONG, // cuvsResources_t res + C_POINTER, // DLManagedTensor* in_val + C_POINTER, // DLManagedTensor* out_val + C_POINTER // DLManagedTensor* out_idx + )); + } + + private SelectKHelper() {} + + /** + * Selects the {@code k} smallest distances from a flat device array of {@code n} candidates. + * + *

Output positions in {@code outIdxDP} are int64 column indices into [0, n). The caller + * recovers per-segment identity as {@code segment = position / segmentK}. + * + * @param cuvsRes cuvsResources_t handle (raw long) + * @param inValDP device pointer to float[n] input distances + * @param n number of input candidates + * @param outValDP device pointer to float[k] output distances + * @param outIdxDP device pointer to int64[k] output positions + * @param k number of results to select + */ + public static void selectK( + long cuvsRes, + MemorySegment inValDP, + long n, + MemorySegment outValDP, + MemorySegment outIdxDP, + long k) { + try (var arena = Arena.ofConfined()) { + long[] inShape = {1, n}; + long[] outShape = {1, k}; + + MemorySegment inValTensor = prepareTensor(arena, inValDP, inShape, kDLFloat(), 32, kDLCUDA()); + MemorySegment outValTensor = + prepareTensor(arena, outValDP, outShape, kDLFloat(), 32, kDLCUDA()); + MemorySegment outIdxTensor = + prepareTensor(arena, outIdxDP, outShape, kDLInt(), 64, kDLCUDA()); + + int rc = (int) cuvsSelectK$mh.invokeExact(cuvsRes, inValTensor, outValTensor, outIdxTensor); + checkCuVSError(rc, "cuvsSelectK"); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable t) { + throw new RuntimeException("cuvsSelectK failed", t); + } + } +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/LinkerHelper.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/LinkerHelper.java index 6de70ce920..79b440b4a8 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/LinkerHelper.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/LinkerHelper.java @@ -25,6 +25,9 @@ public class LinkerHelper { public static final ValueLayout.OfLong C_LONG = (ValueLayout.OfLong) LINKER.canonicalLayouts().get("long"); + public static final ValueLayout.OfShort C_SHORT = + (ValueLayout.OfShort) LINKER.canonicalLayouts().get("short"); + public static final ValueLayout.OfFloat C_FLOAT = (ValueLayout.OfFloat) LINKER.canonicalLayouts().get("float"); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java index 7ca7640cd8..f6b8ef6ee4 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ package com.nvidia.cuvs.internal.common; @@ -27,6 +27,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.VarHandle; import java.util.BitSet; +import java.util.logging.Logger; public class Util { @@ -35,11 +36,46 @@ public class Util { private Util() {} + private static final Logger log = Logger.getLogger(Util.class.getName()); + + static { + if (!tryLoadCudart()) { + log.warning( + "Could not load libcudart from java.library.path, LD_LIBRARY_PATH, or" + + " /usr/local/cuda/lib64. If libcuvs_c.so was built with static CUDA," + + " initialization will fail. Set -Djava.library.path to your CUDA lib64" + + " directory."); + } + } + + private static boolean tryLoadCudart() { + try { + System.loadLibrary("cudart"); + return true; + } catch (UnsatisfiedLinkError ignored) { + } + String ldLibPath = System.getenv("LD_LIBRARY_PATH"); + if (ldLibPath != null) { + for (String dir : ldLibPath.split(":")) { + try { + System.load(dir + "/" + System.mapLibraryName("cudart")); + return true; + } catch (UnsatisfiedLinkError ignored) { + } + } + } + try { + System.load("/usr/local/cuda/lib64/" + System.mapLibraryName("cudart")); + return true; + } catch (UnsatisfiedLinkError ignored) { + } + return false; + } + private static final Linker LINKER = Linker.nativeLinker(); static final SymbolLookup SYMBOL_LOOKUP = SymbolLookup.libraryLookup(System.mapLibraryName("cuvs_c"), Arena.ofAuto()) - .or(SymbolLookup.libraryLookup(System.mapLibraryName("cudart"), Arena.ofAuto())) .or(SymbolLookup.loaderLookup()) .or(Linker.nativeLinker().defaultLookup()); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java index 1d3199f26f..c2b64cc0c2 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java @@ -7,7 +7,6 @@ import static com.nvidia.cuvs.internal.CuVSParamsHelper.*; import static com.nvidia.cuvs.internal.common.Util.*; import static com.nvidia.cuvs.internal.panama.headers_h.*; -import static com.nvidia.cuvs.internal.panama.headers_h_1.cudaStreamSynchronize; import com.nvidia.cuvs.*; import com.nvidia.cuvs.internal.*; @@ -138,6 +137,9 @@ public String toString() { private final cuvsRMMMemoryResourceReset cuvsRMMMemoryResourceResetInvoker = cuvsRMMMemoryResourceReset.makeInvoker(); + private final cuvsRMMAsyncMemoryResourceEnable cuvsRMMAsyncMemoryResourceEnableInvoker = + cuvsRMMAsyncMemoryResourceEnable.makeInvoker(); + private final cuvsGetLogLevel GET_LOG_LEVEL_INVOKER = cuvsGetLogLevel.makeInvoker(); private JDKProvider() {} @@ -255,8 +257,8 @@ public HnswIndex hnswIndexFromCagra(HnswIndexParams hnswParams, CagraIndex cagra } @Override - public HnswIndex hnswIndexBuild(CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) - throws Throwable { + public HnswIndex hnswIndexBuild( + CuVSResources resources, HnswIndexParams hnswParams, CuVSMatrix dataset) throws Throwable { return HnswIndexImpl.build(resources, hnswParams, dataset); } @@ -436,6 +438,12 @@ public Level getLogLevel() { throw new IllegalArgumentException("Unexpected log level [" + logLevel + "]"); } + @Override + public void enableRMMAsyncMemory() { + checkCuVSError( + cuvsRMMAsyncMemoryResourceEnableInvoker.apply(), "cuvsRMMAsyncMemoryResourceEnable"); + } + @Override public void enableRMMPooledMemory(int initialPoolSizePercent, int maxPoolSizePercent) { checkCuVSError( @@ -603,6 +611,15 @@ public void addVector(int[] vector) { internalAddVector(MemorySegment.ofArray(vector)); } + public void addVector(short[] vector) { + if (vector.length != columns) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, "Expected a vector of size [%d], got [%d]", columns, vector.length)); + } + internalAddVector(MemorySegment.ofArray(vector)); + } + protected abstract void internalAddVector(MemorySegment vector); } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CheckedCuVSResources.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CheckedCuVSResources.java index e880edc85d..c615e5ee30 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CheckedCuVSResources.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CheckedCuVSResources.java @@ -56,6 +56,11 @@ public void close() { inner.close(); } + @Override + public void setWorkspacePool(long sizeBytes) { + inner.setWorkspacePool(sizeBytes); + } + @Override public Path tempDirectory() { return inner.tempDirectory(); diff --git a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py index 4315c8e3ac..930c816519 100644 --- a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py +++ b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py @@ -83,6 +83,7 @@ def force_fallback_to_numpy(): from rmm.allocators.cupy import rmm_cupy_allocator from cuvs.common import Resources + from cuvs.neighbors import filters from cuvs.neighbors.brute_force import build, search except ImportError: # RMM is available, cupy is available, but cuVS is not @@ -111,7 +112,37 @@ def choose_random_queries(dataset, n_queries): return dataset[query_idx, :] -def cpu_search(dataset, queries, k, metric="squeclidean"): +def create_bitset_filter(n_samples, filter_reject_rate): + """ + Creates a packed uint32 bitset where bit i is set iff vector i passes the + filter. Uses a modulo-1000 bucket scheme: vector i passes when + ``i % 1000 >= round(filter_reject_rate * 1000)``, giving a reject rate + within 0.1% of the requested value. + + Parameters + ---------- + n_samples : int + Number of vectors in the dataset. + filter_reject_rate : float + Fraction of vectors to reject, in [0.0, 1.0). + + Returns + ------- + numpy.ndarray + Packed uint32 array of shape ``(ceil(n_samples / 32),)``. + """ + import numpy as np + + fail_buckets = round(filter_reject_rate * 1000) + n_padded = ((n_samples + 31) // 32) * 32 + bool_mask = np.zeros(n_padded, dtype=bool) + bool_mask[:n_samples] = (np.arange(n_samples) % 1000) >= fail_buckets + # Pack with little-endian bit order: bit j maps to bit (j%32) of uint32 + # word (j//32), LSB first — matching cuVS bitset layout. + return np.packbits(bool_mask, bitorder="little").view(np.uint32) + + +def cpu_search(dataset, queries, k, metric="squeclidean", accept_mask=None): """ Find the k nearest neighbors for each query point in the dataset using the specified metric. @@ -128,6 +159,9 @@ def cpu_search(dataset, queries, k, metric="squeclidean"): metric : str, optional The distance metric to use. Can be 'squeclidean' or 'inner_product'. Default is 'squeclidean'. + accept_mask : numpy.ndarray, optional + Boolean array of shape (n_samples,). Where False, the corresponding + dataset vector is excluded from results. Returns ------- @@ -144,6 +178,9 @@ def cpu_search(dataset, queries, k, metric="squeclidean"): diff = queries[:, xp.newaxis, :] - dataset[xp.newaxis, :, :] dist_sq = xp.sum(diff**2, axis=2) # Shape: (n_queries, n_samples) + if accept_mask is not None: + dist_sq[:, ~accept_mask] = xp.inf + indices = xp.argpartition(dist_sq, kth=k - 1, axis=1)[:, :k] distances = xp.take_along_axis(dist_sq, indices, axis=1) @@ -156,6 +193,9 @@ def cpu_search(dataset, queries, k, metric="squeclidean"): queries, dataset.T ) # Shape: (n_queries, n_samples) + if accept_mask is not None: + similarities[:, ~accept_mask] = -xp.inf + neg_similarities = -similarities indices = xp.argpartition(neg_similarities, kth=k - 1, axis=1)[:, :k] distances = xp.take_along_axis(similarities, indices, axis=1) @@ -175,7 +215,27 @@ def cpu_search(dataset, queries, k, metric="squeclidean"): return distances, indices -def calc_truth(dataset, queries, k, metric="sqeuclidean"): +def calc_truth(dataset, queries, k, metric="sqeuclidean", bitset=None): + """ + Calculate exact nearest neighbors, optionally with a prefilter. + + Parameters + ---------- + dataset : array-like + Dataset of shape (n_samples, n_features). + queries : array-like + Queries of shape (n_queries, n_features). + k : int + Number of neighbors. + metric : str + Distance metric. + bitset : numpy.ndarray, optional + Packed uint32 array of shape (ceil(n_samples / 32),) as returned by + :func:`create_bitset_filter`. Bit i set means vector i passes the + filter. When None, all vectors are considered. + """ + import numpy as np + n_samples = dataset.shape[0] n = 500000 # batch size for processing neighbors i = 0 @@ -194,10 +254,28 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): if gpu_system: index = build(X, metric=metric, resources=resources) - D, Ind = search(index, queries, k, resources=resources) + prefilter = None + if bitset is not None: + word_start = i // 32 + word_end = (i + n_batch + 31) // 32 + batch_words = xp.asarray(bitset[word_start:word_end]) + prefilter = filters.from_bitset(batch_words) + D, Ind = search( + index, queries, k, resources=resources, prefilter=prefilter + ) resources.sync() else: - D, Ind = cpu_search(X, queries, k, metric=metric) + accept_mask = None + if bitset is not None: + word_start = i // 32 + word_end = (i + n_batch + 31) // 32 + batch_bytes = bitset[word_start:word_end].view(np.uint8) + accept_mask = np.unpackbits(batch_bytes, bitorder="little")[ + :n_batch + ].astype(bool) + D, Ind = cpu_search( + X, queries, k, metric=metric, accept_mask=accept_mask + ) D, Ind = xp.asarray(D), xp.asarray(Ind) Ind = offset_neighbor_indices(Ind, i, n_samples) @@ -247,6 +325,16 @@ def main(): python -m cuvs_bench.generate_groundtruth --dataset /dataset/base.\ fbin --nrows=2000000 --cols=128 --output=groundtruth_dir \ --queries=random-choice --n_queries=10000 + + # Prefiltered ground truth using a saved bitset file + python -m cuvs_bench.generate_groundtruth --dataset /dataset/base.\ +fbin --output=groundtruth_dir --queries=/dataset/query.fbin \ +--bitset=/dataset/filter.npy + + # Prefiltered ground truth generated on-the-fly from a reject rate + python -m cuvs_bench.generate_groundtruth --dataset /dataset/base.\ +fbin --output=groundtruth_dir --queries=/dataset/query.fbin \ +--filter_reject_rate=0.1 """, formatter_class=argparse.RawDescriptionHelpFormatter, ) @@ -313,6 +401,25 @@ def main(): " commonly used with cuVS are 'sqeuclidean' and 'inner_product'", ) + filter_group = parser.add_mutually_exclusive_group() + filter_group.add_argument( + "--bitset", + type=str, + default=None, + help="Path to a .npy file containing a packed uint32 prefilter " + "bitset of shape (ceil(n_samples / 32),). Bit i set means vector i " + "passes the filter. Mutually exclusive with --filter_reject_rate.", + ) + filter_group.add_argument( + "--filter_reject_rate", + type=float, + default=None, + help="Fraction of vectors to reject, in [0.0, 1.0). Generates a " + "bitset using a modulo-1000 bucket scheme (vector i passes when " + "i %% 1000 >= round(filter_reject_rate * 1000)). Mutually exclusive " + "with --bitset.", + ) + if len(sys.argv) == 1: parser.print_help() sys.exit(1) @@ -328,6 +435,7 @@ def main(): args.dataset, args.dtype, shape=(args.rows, args.cols) ) n_features = dataset.shape[1] + n_samples = dataset.shape[0] dtype = dataset.dtype print( @@ -362,8 +470,24 @@ def main(): print("Reading queries from file", args.queries) queries = memmap_bin_file(args.queries, dtype) + # Resolve prefilter bitset. + bitset = None + if args.bitset is not None: + import numpy as np + + print("Loading prefilter bitset from", args.bitset) + bitset = np.load(args.bitset) + elif args.filter_reject_rate is not None: + print( + f"Generating prefilter bitset for filter_reject_rate=" + f"{args.filter_reject_rate}" + ) + bitset = create_bitset_filter(n_samples, args.filter_reject_rate) + print("Calculating true nearest neighbors") - distances, indices = calc_truth(dataset, queries, args.k, args.metric) + distances, indices = calc_truth( + dataset, queries, args.k, args.metric, bitset=bitset + ) n_base = dataset.shape[0] write_groundtruth_neighbors(