Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] fix panic for can not search all the vector for not clear the searcher.seen when lower_search #36

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/hnsw/hnsw_const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ where
params,
}
}

pub fn new_with_capacity(metric: Met, params: Params, capacity: usize) -> Self {
Self {
metric,
zero: Vec::with_capacity(capacity),
features: Vec::with_capacity(capacity),
layers: vec![],
prng: R::from_seed(R::Seed::default()),
params,
}
}
}

impl<Met, T, R, const M: usize, const M0: usize> Knn for Hnsw<Met, T, R, M, M0>
Expand Down Expand Up @@ -374,12 +385,14 @@ where
// See Algorithm 5 line 5 of the paper. The paper makes no further comment on why `1` was chosen.
let &Neighbor { index, distance } = searcher.nearest.first().unwrap();
searcher.nearest.clear();
searcher.seen.clear();
// Update the node to the next layer.
let new_index = layer[index].next_node as usize;
let candidate = Neighbor {
index: new_index,
distance,
};
searcher.seen.insert(layer[index].zero_node);
// Insert the index of the nearest neighbor into the nearest pool for the next layer.
searcher.nearest.push(candidate);
// Insert the index into the candidate pool as well.
Expand Down
55 changes: 50 additions & 5 deletions tests/simple.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Useful tests for debugging since they are hand-written and easy to see the debugging output.

use hnsw::{Hnsw, Searcher};
use itertools::Itertools;
use rand_pcg::Pcg64;
use space::{Metric, Neighbor};

Expand All @@ -18,13 +19,43 @@ impl Metric<&[f64]> for Euclidean {
}
}

struct TestBruteForceHelper {
vectors: Vec<(usize, Vec<f64>)>,
}

impl TestBruteForceHelper {
fn new() -> Self {
Self {
vectors: Vec::new(),
}
}

fn push(&mut self, v: (usize, Vec<f64>)) {
self.vectors.push(v);
}

fn search(&self, query: &[f64], top_k: usize) -> Vec<usize> {
let metric = Euclidean;
let mut candidates: Vec<(usize, u64)> = self
.vectors
.iter()
.map(|v| (v.0.clone(), metric.distance(&query, &v.1.as_slice())))
.collect_vec();

candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

candidates.into_iter().take(top_k).map(|v| v.0).collect()
}
}

fn test_hnsw() -> (
Hnsw<Euclidean, &'static [f64], Pcg64, 12, 24>,
Searcher<u64>,
TestBruteForceHelper,
) {
let mut searcher = Searcher::default();
let mut hnsw = Hnsw::new(Euclidean);

let mut helper = TestBruteForceHelper::new();
let features = [
&[0.0, 0.0, 0.0, 1.0],
&[0.0, 0.0, 1.0, 0.0],
Expand All @@ -36,11 +67,12 @@ fn test_hnsw() -> (
&[1.0, 0.0, 0.0, 1.0],
];

for &feature in &features {
hnsw.insert(feature, &mut searcher);
for (index, feature) in features.iter().enumerate() {
helper.push((index, feature.to_vec()));
hnsw.insert(*feature, &mut searcher);
}

(hnsw, searcher)
(hnsw, searcher, helper)
}

#[test]
Expand All @@ -50,7 +82,7 @@ fn insertion() {

#[test]
fn nearest_neighbor() {
let (hnsw, mut searcher) = test_hnsw();
let (hnsw, mut searcher, helper) = test_hnsw();
let searcher = &mut searcher;
let mut neighbors = [Neighbor {
index: !0,
Expand Down Expand Up @@ -101,4 +133,17 @@ fn nearest_neighbor() {
}
]
);
// test for not panicking
for topk in 0..8 {
let mut neighbors = vec![
Neighbor {
index: !0,
distance: !0,
};
topk
];
hnsw.nearest(&&[0.0, 0.0, 0.0, 1.0][..], 24, searcher, &mut neighbors);
let result = neighbors.iter().map(|item| item.index).collect_vec();
assert_eq!(result, helper.search(&[0.0, 0.0, 0.0, 1.0], topk));
}
}