From 5585244d9940745c10cc1d0d2e6c0dd3857d6a03 Mon Sep 17 00:00:00 2001 From: George Date: Tue, 31 Jan 2023 16:18:32 +0400 Subject: [PATCH] =?UTF-8?q?fix:=20fix=20redis=20search=20w/o=20filters,=20?= =?UTF-8?q?fix=20weaviate=20for=20geo,=20add=20arxiv-no=E2=80=A6=20(#38)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: fix redis search w/o filters, fix weaviate for geo, add arxiv-no-filters dataset --- datasets/datasets.json | 8 ++++++++ engine/clients/redis/search.py | 8 +++++++- engine/clients/weaviate/search.py | 18 +++++++++++++++++- engine/clients/weaviate/upload.py | 14 +++++++++++++- poetry.lock | 2 +- sync.sh | 2 ++ 6 files changed, 48 insertions(+), 4 deletions(-) mode change 100644 => 100755 sync.sh diff --git a/datasets/datasets.json b/datasets/datasets.json index ec6cf94d..ba9cfcbc 100644 --- a/datasets/datasets.json +++ b/datasets/datasets.json @@ -85,6 +85,14 @@ "labels": "keyword", "submitter": "keyword" } + }, + { + "name": "arxiv-titles-384-angular-no-filters", + "vector_size": 384, + "distance": "cosine", + "type": "tar", + "path": "arxiv-titles-384-angular-no-filters/arxiv_no_filters", + "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/arxiv_no_filters.tar.gz" }, { "name": "random-match-keyword-100-angular-filters", diff --git a/engine/clients/redis/search.py b/engine/clients/redis/search.py index 4c6d0c92..04b868b2 100644 --- a/engine/clients/redis/search.py +++ b/engine/clients/redis/search.py @@ -21,7 +21,13 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic @classmethod def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: - prefilter_condition, params = cls.parser.parse(meta_conditions) + conditions = cls.parser.parse(meta_conditions) + if conditions is None: + prefilter_condition = "*" + params = {} + else: + prefilter_condition, params = conditions + q = ( Query( f"{prefilter_condition}=>[KNN $K @vector $vec_param EF_RUNTIME $EF AS vector_score]" diff --git a/engine/clients/weaviate/search.py b/engine/clients/weaviate/search.py index 5adcda4b..126e95b3 100644 --- a/engine/clients/weaviate/search.py +++ b/engine/clients/weaviate/search.py @@ -26,9 +26,25 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]: query = cls.client.query.get( WEAVIATE_CLASS_NAME, ["_additional {id distance}"] ).with_near_vector(near_vector) + + is_geo_query = False if where_conditions is not None: + operands = where_conditions["operands"] + is_geo_query = any( + operand["operator"] == "WithinGeoRange" for operand in operands + ) query = query.with_where(where_conditions) - res = (query.with_limit(top).do())["data"]["Get"][WEAVIATE_CLASS_NAME] + + query_obj = query.with_limit(top) + if is_geo_query: + # weaviate can't handle geo queries in python due to excess quotes in generated queries + gql_query = query_obj.build() + for field in ("geoCoordinates", "latitude", "longitude", "distance", "max"): + gql_query = gql_query.replace(f'"{field}"', field) # get rid of quotes + response = cls.client.query.raw(gql_query) + else: + response = query_obj.do() + res = response["data"]["Get"][WEAVIATE_CLASS_NAME] id_score_pairs: List[Tuple[int, float]] = [] for obj in res: diff --git a/engine/clients/weaviate/upload.py b/engine/clients/weaviate/upload.py index 9ef17a1a..eca19221 100644 --- a/engine/clients/weaviate/upload.py +++ b/engine/clients/weaviate/upload.py @@ -1,4 +1,3 @@ -import time import uuid from typing import List, Optional @@ -20,6 +19,18 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.upload_params = upload_params cls.connection_params = connection_params + @staticmethod + def _update_geo_data(data_object): + keys = data_object.keys() + for key in keys: + if isinstance(data_object[key], dict): + if lat := data_object[key].pop("lat", None): + data_object[key]["latitude"] = lat + if lon := data_object[key].pop("lon", None): + data_object[key]["longitude"] = lon + + return data_object + @classmethod def upload_batch( cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]] @@ -33,6 +44,7 @@ def upload_batch( with cls.client.batch as batch: for id_, vector, data_object in zip(ids, vectors, metadata): + data_object = cls._update_geo_data(data_object) batch.add_data_object( data_object=data_object or {}, class_name=WEAVIATE_CLASS_NAME, diff --git a/poetry.lock b/poetry.lock index 17dcdc93..1f75193f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1804,4 +1804,4 @@ validators = ">=0.18.2,<0.20.0" [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "fa147cdc33d2be1373deb6755e3f8f89582b0cede7adbf245b201e0c89e10627" +content-hash = "fa147cdc33d2be1373deb6755e3f8f89582b0cede7adbf245b201e0c89e10627" \ No newline at end of file diff --git a/sync.sh b/sync.sh old mode 100644 new mode 100755 index 169e6724..c0b540f6 --- a/sync.sh +++ b/sync.sh @@ -9,4 +9,6 @@ rsync -avP \ --exclude='__pycache__' \ --exclude='frontend' \ --exclude='.idea' \ + --exclude='.git' \ + --exclude='datasets/*/' \ . $1:./projects/vector-db-benchmark/