Skip to content

Commit

Permalink
fix: fix redis search w/o filters, fix weaviate for geo, add arxiv-no… (
Browse files Browse the repository at this point in the history
#38)

* fix: fix redis search w/o filters, fix weaviate for geo, add arxiv-no-filters dataset
  • Loading branch information
joein authored Jan 31, 2023
1 parent d4dc925 commit 5585244
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 4 deletions.
8 changes: 8 additions & 0 deletions datasets/datasets.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@
"labels": "keyword",
"submitter": "keyword"
}
},
{
"name": "arxiv-titles-384-angular-no-filters",
"vector_size": 384,
"distance": "cosine",
"type": "tar",
"path": "arxiv-titles-384-angular-no-filters/arxiv_no_filters",
"link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/arxiv_no_filters.tar.gz"
},
{
"name": "random-match-keyword-100-angular-filters",
Expand Down
8 changes: 7 additions & 1 deletion engine/clients/redis/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
prefilter_condition, params = cls.parser.parse(meta_conditions)
conditions = cls.parser.parse(meta_conditions)
if conditions is None:
prefilter_condition = "*"
params = {}
else:
prefilter_condition, params = conditions

q = (
Query(
f"{prefilter_condition}=>[KNN $K @vector $vec_param EF_RUNTIME $EF AS vector_score]"
Expand Down
18 changes: 17 additions & 1 deletion engine/clients/weaviate/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,25 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
query = cls.client.query.get(
WEAVIATE_CLASS_NAME, ["_additional {id distance}"]
).with_near_vector(near_vector)

is_geo_query = False
if where_conditions is not None:
operands = where_conditions["operands"]
is_geo_query = any(
operand["operator"] == "WithinGeoRange" for operand in operands
)
query = query.with_where(where_conditions)
res = (query.with_limit(top).do())["data"]["Get"][WEAVIATE_CLASS_NAME]

query_obj = query.with_limit(top)
if is_geo_query:
# weaviate can't handle geo queries in python due to excess quotes in generated queries
gql_query = query_obj.build()
for field in ("geoCoordinates", "latitude", "longitude", "distance", "max"):
gql_query = gql_query.replace(f'"{field}"', field) # get rid of quotes
response = cls.client.query.raw(gql_query)
else:
response = query_obj.do()
res = response["data"]["Get"][WEAVIATE_CLASS_NAME]

id_score_pairs: List[Tuple[int, float]] = []
for obj in res:
Expand Down
14 changes: 13 additions & 1 deletion engine/clients/weaviate/upload.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
import uuid
from typing import List, Optional

Expand All @@ -20,6 +19,18 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.upload_params = upload_params
cls.connection_params = connection_params

@staticmethod
def _update_geo_data(data_object):
keys = data_object.keys()
for key in keys:
if isinstance(data_object[key], dict):
if lat := data_object[key].pop("lat", None):
data_object[key]["latitude"] = lat
if lon := data_object[key].pop("lon", None):
data_object[key]["longitude"] = lon

return data_object

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
Expand All @@ -33,6 +44,7 @@ def upload_batch(

with cls.client.batch as batch:
for id_, vector, data_object in zip(ids, vectors, metadata):
data_object = cls._update_geo_data(data_object)
batch.add_data_object(
data_object=data_object or {},
class_name=WEAVIATE_CLASS_NAME,
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions sync.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ rsync -avP \
--exclude='__pycache__' \
--exclude='frontend' \
--exclude='.idea' \
--exclude='.git' \
--exclude='datasets/*/' \
. $1:./projects/vector-db-benchmark/

0 comments on commit 5585244

Please sign in to comment.