Skip to content

Commit

Permalink
disable hgraph filter check when filter is nullptr
Browse files Browse the repository at this point in the history
- disable filter check while filter is nullptr

Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Dec 30, 2024
1 parent 2fc80bb commit 4dddb82
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ HGraph::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
BitsetOrCallbackFilter ft(filter);
std::unique_ptr<BitsetOrCallbackFilter> ft = nullptr;
if (filter != nullptr) {
ft = std::make_unique<BitsetOrCallbackFilter>(filter);
}
try {
int64_t query_dim = query->GetDim();
CHECK_ARGUMENT(
Expand Down Expand Up @@ -183,7 +186,7 @@ HGraph::KnnSearch(const DatasetPtr& query,
auto params = HGraphSearchParameters::FromJson(parameters);

search_param.ef_ = params.ef_search;
search_param.is_id_allowed_ = &ft;
search_param.is_id_allowed_ = ft.get();
auto search_result = this->search_one_graph(query->GetFloat32Vectors(),
this->bottom_graph_,
this->basic_flatten_codes_,
Expand Down
8 changes: 4 additions & 4 deletions src/index/hgraph_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ class HGraphIndex : public Index {
int64_t k,
const std::string& parameters,
BitsetPtr invalid = nullptr) const override {
auto func = [&invalid](int64_t id) -> bool {
if (invalid == nullptr) {
return false;
}
std::function<bool(int64_t)> func = [&invalid](int64_t id) -> bool {
int64_t bit_index = id & ROW_ID_MASK;
return invalid->Test(bit_index);
};
if (invalid == nullptr) {
func = nullptr;
}
SAFE_CALL(return this->hgraph_->KnnSearch(query, k, parameters, func));
}

Expand Down
7 changes: 1 addition & 6 deletions src/simd/avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t
return 0.0f;
}

alignas(64) int32_t temp[16];
int32_t result = 0;
uint64_t d = 0;
__m512i sum = _mm512_setzero_si512();
__m512i mask = _mm512_set1_epi16(0xff);
Expand All @@ -370,10 +368,7 @@ SQ8UniformComputeCodesIP(const uint8_t* codes1, const uint8_t* codes2, uint64_t
sum = _mm512_add_epi32(sum, _mm512_madd_epi16(xx1, yy1));
sum = _mm512_add_epi32(sum, _mm512_madd_epi16(xx2, yy2));
}
_mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum);
for (int i = 0; i < 16; ++i) {
result += temp[i];
}
int32_t result = _mm512_reduce_add_epi32(sum);
result += static_cast<int32_t>(avx2::SQ8UniformComputeCodesIP(codes1 + d, codes2 + d, dim - d));
return static_cast<float>(result);
#else
Expand Down

0 comments on commit 4dddb82

Please sign in to comment.