From c71920fd503b719e8b48de94ed6cb8d02274b4a4 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Mon, 30 Dec 2024 02:49:20 +0000 Subject: [PATCH] disable hgraph filter check when filter is nullptr - disable filter check while filter is nullptr Signed-off-by: LHT129 --- src/algorithm/hgraph.cpp | 7 +++++-- src/index/hgraph_index.h | 8 ++++---- src/simd/avx512.cpp | 7 +------ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 5ac04632..ba417aa9 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -155,7 +155,10 @@ HGraph::KnnSearch(const DatasetPtr& query, int64_t k, const std::string& parameters, const std::function& filter) const { - BitsetOrCallbackFilter ft(filter); + std::unique_ptr ft = nullptr; + if (filter != nullptr) { + ft = std::make_unique(filter); + } try { int64_t query_dim = query->GetDim(); CHECK_ARGUMENT( @@ -183,7 +186,7 @@ HGraph::KnnSearch(const DatasetPtr& query, auto params = HGraphSearchParameters::FromJson(parameters); search_param.ef_ = params.ef_search; - search_param.is_id_allowed_ = &ft; + search_param.is_id_allowed_ = ft.get(); auto search_result = this->search_one_graph(query->GetFloat32Vectors(), this->bottom_graph_, this->basic_flatten_codes_, diff --git a/src/index/hgraph_index.h b/src/index/hgraph_index.h index 07da8c76..373058fe 100644 --- a/src/index/hgraph_index.h +++ b/src/index/hgraph_index.h @@ -47,13 +47,13 @@ class HGraphIndex : public Index { int64_t k, const std::string& parameters, BitsetPtr invalid = nullptr) const override { - auto func = [&invalid](int64_t id) -> bool { - if (invalid == nullptr) { - return false; - } + std::function func = [&invalid](int64_t id) -> bool { int64_t bit_index = id & ROW_ID_MASK; return invalid->Test(bit_index); }; + if (invalid == nullptr) { + func = nullptr; + } SAFE_CALL(return this->hgraph_->KnnSearch(query, k, parameters, func)); } diff --git a/src/simd/avx512.cpp b/src/simd/avx512.cpp index 9b443c0b..cdd382bf 100644 --- a/src/simd/avx512.cpp +++ b/src/simd/avx512.cpp @@ -353,8 +353,6 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t return 0.0f; } - alignas(64) int32_t temp[16]; - int32_t result = 0; uint64_t d = 0; __m512i sum = _mm512_setzero_si512(); __m512i mask = _mm512_set1_epi16(0xff); @@ -370,10 +368,7 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t sum = _mm512_add_epi32(sum, _mm512_madd_epi16(xx1, yy1)); sum = _mm512_add_epi32(sum, _mm512_madd_epi16(xx2, yy2)); } - _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum); - for (int i = 0; i < 16; ++i) { - result += temp[i]; - } + int32_t result = _mm512_reduce_add_epi32(sum); result += static_cast(avx2::SQ8UniformComputeCodesIP(codes1 + d, codes2 + d, dim - d)); return static_cast(result); #else