Benchmark HNSW for Jaccard (#226)
ekzhu committed Sep 8, 2023
1 parent e11bb70 commit 8baa603
Showing 9 changed files with 482 additions and 332 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -37,7 +37,7 @@ sub-linear query time:
 +---------------------------+-----------------------------+------------------------+
 | `MinHash LSH Ensemble`_   | MinHash                     | Containment Threshold  |
 +---------------------------+-----------------------------+------------------------+
-| `HNSW`_                   | Customizable                | Metric Distances       |
+| `HNSW`_                   | Any                         | Custom Metric Top-K    |
 +---------------------------+-----------------------------+------------------------+
 
 datasketch must be used with Python 3.7 or above, NumPy 1.11 or above, and Scipy.
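
The updated row advertises HNSW over any data type with a user-defined metric. A minimal sketch of that usage, based on the datasketch HNSW API as the benchmark code below exercises it; the toy sets and the jaccard_distance helper are illustrative, not part of this commit:

from datasketch.hnsw import HNSW

def jaccard_distance(a, b):
    # Jaccard distance on plain Python sets: 1 - |a & b| / |a | b|.
    a, b = set(a), set(b)
    return 1.0 - len(a & b) / len(a | b)

# Index a few small sets under the custom metric.
index = HNSW(distance_func=jaccard_distance)
for key, point in enumerate([{1, 2, 3}, {2, 3, 4}, {10, 11}]):
    index.insert(key, point)

# Top-2 neighbors of a query set, returned as (key, distance) pairs.
print(index.query({1, 2, 4}, 2))
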
41 changes: 26 additions & 15 deletions benchmark/indexes/jaccard/exact.py
@@ -1,8 +1,10 @@
+import json
 import time
 import sys
 import collections
 
 from SetSimilaritySearch import SearchIndex
+import tqdm
 
 
 def _query_jaccard_topk(index, query, k):
@@ -22,20 +24,31 @@ def _query_jaccard_topk(index, query, k):
     return candidates[:k]
 
 
-def search_jaccard_topk(index_data, query_data, k):
-    (index_sets, index_keys) = index_data
-    (query_sets, query_keys) = query_data
-    print("Building jaccard search index.")
-    start = time.perf_counter()
-    # Build the search index with the 0 threshold to index all tokens.
-    index = SearchIndex(
-        index_sets, similarity_func_name="jaccard", similarity_threshold=0.0
-    )
-    indexing_time = time.perf_counter() - start
-    print("Finished building index in {:.3f}.".format(indexing_time))
+def search_jaccard_topk(index_data, query_data, index_params, k):
+    (index_sets, index_keys, _, index_cache) = index_data
+    (query_sets, query_keys, _) = query_data
+    cache_key = json.dumps(index_params)
+    if cache_key not in index_cache:
+        print("Building jaccard search index.")
+        start = time.perf_counter()
+        # Build the search index with the 0 threshold to index all tokens.
+        index = SearchIndex(
+            index_sets, similarity_func_name="jaccard", similarity_threshold=0.0
+        )
+        indexing_time = time.perf_counter() - start
+        print("Finished building index in {:.3f}.".format(indexing_time))
+        index_cache[cache_key] = (
+            index,
+            {
+                "indexing_time": indexing_time,
+            },
+        )
+    index, indexing = index_cache[cache_key]
     times = []
     results = []
-    for query_set, query_key in zip(query_sets, query_keys):
+    for query_set, query_key in tqdm.tqdm(
+        zip(query_sets, query_keys), total=len(query_keys), desc="Querying", unit=" set"
+    ):
         start = time.perf_counter()
         result = [
             [index_keys[i], similarity]
@@ -44,6 +57,4 @@ def search_jaccard_topk(index_data, query_data, k):
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
-    sys.stdout.write("\rQueried {} sets.".format(len(results)))
-    sys.stdout.write("\n")
-    return (indexing_time, results, times)
+    return (indexing, results, times)
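
All of the reworked search functions share this caching pattern: index_data now carries a cache dict, and each parameter combination is serialized with json.dumps to key a previously built index. A condensed sketch of the pattern; get_or_build and build_index are illustrative names, not code from this commit:

import json
import time

def get_or_build(index_cache, index_params, build_index):
    # json.dumps of the params dict serves as the cache key; this assumes
    # a consistent key order across calls (sort_keys=True would make the
    # key order-independent).
    cache_key = json.dumps(index_params)
    if cache_key not in index_cache:
        start = time.perf_counter()
        index = build_index(index_params)
        indexing_time = time.perf_counter() - start
        index_cache[cache_key] = (index, {"indexing_time": indexing_time})
    return index_cache[cache_key]
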
169 changes: 147 additions & 22 deletions benchmark/indexes/jaccard/hnsw.py
@@ -1,32 +1,55 @@
+import json
 import time
 import sys
 
-import nmslib
+import tqdm
+from datasketch.hnsw import HNSW
 
-from utils import compute_jaccard
+from utils import (
+    compute_jaccard,
+    compute_jaccard_distance,
+    compute_minhash_jaccard_distance,
+    lazy_create_minhashes_from_sets,
+)
 
 
-def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
-    (index_sets, index_keys) = index_data
-    (query_sets, query_keys) = query_data
-    print("Building HNSW Index.")
-    start = time.perf_counter()
-    index = nmslib.init(
-        method="hnsw",
-        space="jaccard_sparse",
-        data_type=nmslib.DataType.OBJECT_AS_STRING,
-    )
-    index.addDataPointBatch(
-        [" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
-    )
-    index.createIndex(index_params)
-    indexing_time = time.perf_counter() - start
-    print("Indexing time: {:.3f}.".format(indexing_time))
+def search_nswlib_jaccard_topk(index_data, query_data, index_params, k):
+    import nmslib
+
+    (index_sets, index_keys, _, index_cache) = index_data
+    (query_sets, query_keys, _) = query_data
+    cache_key = json.dumps(index_params)
+    if cache_key not in index_cache:
+        print("Building HNSW Index.")
+        start = time.perf_counter()
+        index = nmslib.init(
+            method="hnsw",
+            space="jaccard_sparse",
+            data_type=nmslib.DataType.OBJECT_AS_STRING,
+        )
+        index.addDataPointBatch(
+            [" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
+        )
+        index.createIndex(index_params)
+        indexing_time = time.perf_counter() - start
+        print("Indexing time: {:.3f}.".format(indexing_time))
+        index_cache[cache_key] = (
+            index,
+            {
+                "indexing_time": indexing_time,
+            },
+        )
+    index, indexing = index_cache[cache_key]
     print("Querying.")
     times = []
     results = []
     index.setQueryTimeParams({"efSearch": index_params["efConstruction"]})
-    for query_set, query_key in zip(query_sets, query_keys):
+    for query_set, query_key in tqdm.tqdm(
+        zip(query_sets, query_keys),
+        total=len(query_keys),
+        desc="Querying",
+        unit=" query",
+    ):
         start = time.perf_counter()
         result, _ = index.knnQuery(" ".join(str(v) for v in query_set), k)
         result = [
@@ -36,6 +59,108 @@ def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
-    sys.stdout.write(f"\rQueried {len(results)} sets")
-    sys.stdout.write("\n")
-    return (indexing_time, results, times)
+    return (indexing, results, times)
+
+
+def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
+    (index_sets, index_keys, _, index_cache) = index_data
+    (query_sets, query_keys, _) = query_data
+    cache_key = json.dumps(index_params)
+    if cache_key not in index_cache:
+        print("Building HNSW Index.")
+        start = time.perf_counter()
+        index = HNSW(distance_func=compute_jaccard_distance, **index_params)
+        for i in tqdm.tqdm(
+            range(len(index_keys)),
+            desc="Indexing",
+            unit=" set",
+            total=len(index_keys),
+        ):
+            index.insert(i, index_sets[i])
+        indexing_time = time.perf_counter() - start
+        print("Indexing time: {:.3f}.".format(indexing_time))
+        index_cache[cache_key] = (
+            index,
+            {
+                "indexing_time": indexing_time,
+            },
+        )
+    index, indexing = index_cache[cache_key]
+    print("Querying.")
+    times = []
+    results = []
+    for query_set, query_key in tqdm.tqdm(
+        zip(query_sets, query_keys),
+        total=len(query_keys),
+        desc="Querying",
+        unit=" query",
+    ):
+        start = time.perf_counter()
+        result = index.query(query_set, k)
+        # Convert distances to similarities.
+        result = [(index_keys[i], 1.0 - dist) for i, dist in result]
+        duration = time.perf_counter() - start
+        times.append(duration)
+        results.append((query_key, result))
+    return (indexing, results, times)
+
+
+def search_hnsw_minhash_jaccard_topk(index_data, query_data, index_params, k):
+    (index_sets, index_keys, index_minhashes, index_cache) = index_data
+    (query_sets, query_keys, query_minhashes) = query_data
+    num_perm = index_params["num_perm"]
+    cache_key = json.dumps(index_params)
+    if cache_key not in index_cache:
+        # Create minhashes
+        index_minhash_time, query_minhash_time = lazy_create_minhashes_from_sets(
+            index_minhashes,
+            index_sets,
+            query_minhashes,
+            query_sets,
+            num_perm,
+        )
+        print("Building HNSW Index for MinHash.")
+        start = time.perf_counter()
+        kwargs = index_params.copy()
+        kwargs.pop("num_perm")
+        index = HNSW(distance_func=compute_minhash_jaccard_distance, **kwargs)
+        for i in tqdm.tqdm(
+            range(len(index_keys)),
+            desc="Indexing",
+            unit=" query",
+            total=len(index_keys),
+        ):
+            index.insert(i, index_minhashes[num_perm][i])
+        indexing_time = time.perf_counter() - start
+        print("Indexing time: {:.3f}.".format(indexing_time))
+        index_cache[cache_key] = (
+            index,
+            {
+                "index_minhash_time": index_minhash_time,
+                "query_minhash_time": query_minhash_time,
+                "indexing_time": indexing_time,
+            },
+        )
+    index, indexing = index_cache[cache_key]
+    print("Querying.")
+    times = []
+    results = []
+    for query_minhash, query_key, query_set in tqdm.tqdm(
+        zip(query_minhashes[num_perm], query_keys, query_sets),
+        total=len(query_keys),
+        desc="Querying",
+        unit=" query",
+    ):
+        start = time.perf_counter()
+        result = index.query(query_minhash, k)
+        # Recover the retrieved indexed sets and
+        # compute the exact Jaccard similarities.
+        result = [
+            [index_keys[i], compute_jaccard(query_set, index_sets[i])] for i, _ in result
+        ]
+        # Sort by similarity.
+        result.sort(key=lambda x: x[1], reverse=True)
+        duration = time.perf_counter() - start
+        times.append(duration)
+        results.append((query_key, result))
+    return (indexing, results, times)
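
Both new functions hand a distance function to HNSW: exact Jaccard for search_hnsw_jaccard_topk, and a MinHash estimate for search_hnsw_minhash_jaccard_topk, whose candidates are then re-ranked by exact Jaccard. The two helpers are imported from utils.py, which this page does not show; a plausible sketch of what they compute, inferred from how the benchmark uses them:

from datasketch import MinHash

def compute_jaccard_distance(x, y):
    # Exact Jaccard distance on the raw token sets.
    x, y = set(x), set(y)
    if not x and not y:
        return 0.0
    return 1.0 - len(x & y) / len(x | y)

def compute_minhash_jaccard_distance(m1: MinHash, m2: MinHash) -> float:
    # MinHash.jaccard() is an unbiased estimate of the true similarity,
    # so this is an approximate distance suitable for HNSW navigation.
    return 1.0 - m1.jaccard(m2)
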
66 changes: 44 additions & 22 deletions benchmark/indexes/jaccard/lsh.py
@@ -1,35 +1,59 @@
+import json
 import time
 import sys
 
+import tqdm
+
 from datasketch import MinHashLSH
 
-from utils import compute_jaccard
+from utils import compute_jaccard, lazy_create_minhashes_from_sets
 
 
-def search_lsh_jaccard_topk(index_data, query_data, b, r, k):
-    (index_sets, index_keys, index_minhashes) = index_data
+def search_lsh_jaccard_topk(index_data, query_data, index_params, k):
+    (index_sets, index_keys, index_minhashes, index_cache) = index_data
     (query_sets, query_keys, query_minhashes) = query_data
+    b, r = index_params["b"], index_params["r"]
     num_perm = b * r
-    print("Building LSH Index.")
-    start = time.perf_counter()
-    index = MinHashLSH(
-        num_perm=num_perm,
-        params=(b, r),
-    )
-    # Use the indices of the indexed sets as keys in LSH.
-    for i in range(len(index_keys)):
-        index.insert(
-            i,
-            index_minhashes[num_perm][i],
-            check_duplication=False,
+    cache_key = json.dumps(index_params)
+    if cache_key not in index_cache:
+        # Create minhashes
+        index_minhash_time, query_minhash_time = lazy_create_minhashes_from_sets(
+            index_minhashes,
+            index_sets,
+            query_minhashes,
+            query_sets,
+            num_perm,
+        )
+        print("Building LSH Index.")
+        start = time.perf_counter()
+        index = MinHashLSH(num_perm=num_perm, params=(b, r))
+        # Use the indices of the indexed sets as keys in LSH.
+        for i in tqdm.tqdm(
+            range(len(index_keys)),
+            desc="Indexing",
+            unit=" minhash",
+            total=len(index_keys),
+        ):
+            index.insert(i, index_minhashes[num_perm][i], check_duplication=False)
+        indexing_time = time.perf_counter() - start
+        print("Indexing time: {:.3f}.".format(indexing_time))
+        index_cache[cache_key] = (
+            index,
+            {
+                "index_minhash_time": index_minhash_time,
+                "query_minhash_time": query_minhash_time,
+                "indexing_time": indexing_time,
+            },
         )
-    indexing_time = time.perf_counter() - start
-    print("Indexing time: {:.3f}.".format(indexing_time))
+    index, indexing = index_cache[cache_key]
     print("Querying.")
     times = []
     results = []
-    for query_minhash, query_key, query_set in zip(
-        query_minhashes[num_perm], query_keys, query_sets
+    for query_minhash, query_key, query_set in tqdm.tqdm(
+        zip(query_minhashes[num_perm], query_keys, query_sets),
+        total=len(query_keys),
+        desc="Querying",
+        unit=" query",
     ):
         start = time.perf_counter()
         result = index.query(query_minhash)
@@ -45,6 +69,4 @@ def search_lsh_jaccard_topk(index_data, query_data, b, r, k):
         duration = time.perf_counter() - start
         times.append(duration)
         results.append((query_key, result))
-    sys.stdout.write(f"\rQueried {len(results)} sets")
-    sys.stdout.write("\n")
-    return (indexing_time, results, times)
+    return (indexing, results, times)
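
The b and r parameters map onto MinHash LSH's banding scheme: b bands of r rows each require signatures with num_perm = b * r permutations, which is why the function derives num_perm rather than taking it directly. A standalone sketch of that parameterization; the toy tokens and key are illustrative:

from datasketch import MinHash, MinHashLSH

b, r = 8, 4
num_perm = b * r  # 32 permutations: b bands of r rows each

lsh = MinHashLSH(num_perm=num_perm, params=(b, r))
m = MinHash(num_perm=num_perm)
for token in ("apple", "banana", "cherry"):
    m.update(token.encode("utf8"))
lsh.insert(0, m, check_duplication=False)
print(lsh.query(m))  # a MinHash identical to an indexed one returns its key: [0]
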
