Skip to content

Commit

Permalink
Use Weaviate v4 gRPC based client
Browse files: browse the repository at this point in the history
  • Loading branch information
trengrj committed Mar 12, 2024
1 parent 20ec952 commit 3825c98
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 165 deletions.
21 changes: 14 additions & 7 deletions engine/clients/weaviate/configure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from weaviate import Client
from weaviate import WeaviateClient
from weaviate.connect import ConnectionParams

from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator
Expand All @@ -23,16 +24,17 @@ class WeaviateConfigurator(BaseConfigurator):
def __init__(self, host, collection_params: dict, connection_params: dict):
super().__init__(host, collection_params, connection_params)
url = f"http://{host}:{connection_params.get('port', WEAVIATE_DEFAULT_PORT)}"
self.client = Client(url, **connection_params)
client = WeaviateClient(
ConnectionParams.from_url(url, 50051), skip_init_checks=True
)
client.connect()
self.client = client

def clean(self):
classes = self.client.schema.get()
for cl in classes["classes"]:
if cl["class"] == WEAVIATE_CLASS_NAME:
self.client.schema.delete_class(WEAVIATE_CLASS_NAME)
self.client.collections.delete(WEAVIATE_CLASS_NAME)

def recreate(self, dataset: Dataset, collection_params):
self.client.schema.create_class(
self.client.collections.create_from_dict(
{
"class": WEAVIATE_CLASS_NAME,
"vectorizer": "none",
Expand All @@ -55,3 +57,8 @@ def recreate(self, dataset: Dataset, collection_params):
},
}
)
self.client.close()

def __del__(self):
if self.client.is_connected():
self.client.close()
101 changes: 40 additions & 61 deletions engine/clients/weaviate/parser.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
from typing import Any, Dict, List, Optional

from engine.base_client import IncompatibilityError
from weaviate.collections.classes.filters import _Filters

from engine.base_client.parser import BaseConditionParser, FieldValue
import weaviate.classes as wvc


class WeaviateConditionParser(BaseConditionParser):
def parse(self, meta_conditions: Dict[str, Any]) -> Optional[Any]:
def parse(self, meta_conditions: Dict[str, Any]) -> Optional[_Filters]:
if meta_conditions is None or len(meta_conditions) == 0:
return None
return super().parse(meta_conditions)

def build_condition(
self, and_subfilters: Optional[List[Any]], or_subfilters: Optional[List[Any]]
) -> Optional[Any]:
clause = {}
self,
and_subfilters: Optional[List[_Filters]],
or_subfilters: Optional[List[_Filters]],
) -> Optional[_Filters]:
weaviate_filter = None
if or_subfilters is not None and len(or_subfilters) > 0:
clause = {
"operator": "Or",
"operands": or_subfilters,
}
weaviate_filter = or_subfilters[0]
for filt in or_subfilters[1:]:
weaviate_filter = weaviate_filter | filt

if and_subfilters is not None and len(and_subfilters) > 0:
clause = {
"operator": "And",
"operands": and_subfilters + [clause]
if len(clause) > 0
else and_subfilters,
}
return clause
if weaviate_filter is not None:
weaviate_filter = and_subfilters[0] & weaviate_filter
else:
weaviate_filter = and_subfilters[0]
for filt in and_subfilters[1:]:
weaviate_filter = weaviate_filter & filt
return weaviate_filter

def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
return {
"operator": "Equal",
"path": [field_name],
self.value_key(value): value,
}
def build_exact_match_filter(self, field_name: str, value: FieldValue) -> _Filters:
return wvc.query.Filter.by_property(field_name).equal(value)

def build_range_filter(
self,
Expand All @@ -43,45 +43,24 @@ def build_range_filter(
lte: Optional[FieldValue],
gte: Optional[FieldValue],
) -> Any:
clauses = {
"LessThan": lt,
"GreaterThan": gt,
"LessThanEqual": lte,
"GreaterThanEqual": gte,
}
return {
"operator": "And",
"operands": [
{
"operator": op,
"path": [field_name],
self.value_key(value): value,
}
for op, value in clauses.items()
if value is not None
],
}
prop = wvc.query.Filter.by_property(field_name)
ltf = prop.less_than(lt) if lt is not None else None
ltef = prop.less_or_equal(lte) if lte is not None else None
gtf = prop.greater_than(gt) if gt is not None else None
gtef = prop.greater_or_equal(gte) if gte is not None else None
filtered_lst = list(filter(lambda x: x is not None, [ltf, ltef, gtf, gtef]))
if len(filtered_lst) == 0:
return filtered_lst

result = filtered_lst[0]
for filt in filtered_lst[1:]:
result = result & filt
return result

def build_geo_filter(
self, field_name: str, lat: float, lon: float, radius: float
) -> Any:
return {
"operator": "WithinGeoRange",
"path": [field_name],
"valueGeoRange": {
"geoCoordinates": {
"latitude": lat,
"longitude": lon,
},
"distance": {"max": radius},
},
}

def value_key(self, value: FieldValue) -> str:
if isinstance(value, str):
return "valueString"
if isinstance(value, int):
return "valueInt"
if isinstance(value, float):
return "valueNumber"
raise IncompatibilityError
) -> _Filters:
return wvc.query.Filter.by_property(field_name).within_geo_range(
distance=radius,
coordinate=wvc.query.GeoCoordinate(latitude=lat, longitude=lon),
)
71 changes: 34 additions & 37 deletions engine/clients/weaviate/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import uuid
from typing import List, Tuple

from weaviate import Client
from weaviate import WeaviateClient
from weaviate.collections import Collection
from weaviate.connect import ConnectionParams
from weaviate.classes.query import MetadataQuery
from weaviate.classes.config import Reconfigure

from engine.base_client.search import BaseSearcher
from engine.clients.weaviate.config import WEAVIATE_CLASS_NAME, WEAVIATE_DEFAULT_PORT
Expand All @@ -10,49 +14,42 @@

class WeaviateSearcher(BaseSearcher):
search_params = {}
client: Client = None
parser = WeaviateConditionParser()
collection: Collection
client: WeaviateClient

@classmethod
def init_client(cls, host, distance, connection_params: dict, search_params: dict):
url = f"http://{host}:{connection_params.get('port', WEAVIATE_DEFAULT_PORT)}"
cls.client = Client(url, **connection_params)
client = WeaviateClient(
ConnectionParams.from_url(url, 50051), skip_init_checks=True
)
client.connect()
cls.collection = client.collections.get(
WEAVIATE_CLASS_NAME, skip_argument_validation=True
)
cls.search_params = search_params
cls.client = client

@classmethod
def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
near_vector = {"vector": vector}
where_conditions = cls.parser.parse(meta_conditions)
query = cls.client.query.get(
WEAVIATE_CLASS_NAME, ["_additional {id distance}"]
).with_near_vector(near_vector)

is_geo_query = False
if where_conditions is not None:
operands = where_conditions["operands"]
is_geo_query = any(
operand["operator"] == "WithinGeoRange" for operand in operands
)
query = query.with_where(where_conditions)

query_obj = query.with_limit(top)
if is_geo_query:
# weaviate can't handle geo queries in python due to excess quotes in generated queries
gql_query = query_obj.build()
for field in ("geoCoordinates", "latitude", "longitude", "distance", "max"):
gql_query = gql_query.replace(f'"{field}"', field) # get rid of quotes
response = cls.client.query.raw(gql_query)
else:
response = query_obj.do()
res = response["data"]["Get"][WEAVIATE_CLASS_NAME]

id_score_pairs: List[Tuple[int, float]] = []
for obj in res:
description = obj["_additional"]
score = description["distance"]
id_ = uuid.UUID(hex=description["id"]).int
id_score_pairs.append((id_, score))
return id_score_pairs
def search_one(self, vector, meta_conditions, top) -> List[Tuple[int, float]]:
res = self.collection.query.near_vector(
near_vector=vector,
filters=self.parser.parse(meta_conditions),
limit=top,
return_metadata=MetadataQuery(distance=True),
return_properties=[],
)
return [(hit.uuid.int, hit.metadata.distance) for hit in res.objects]

def setup_search(self):
self.client.schema.update_config(WEAVIATE_CLASS_NAME, self.search_params)
self.collection.config.update(
vector_index_config=Reconfigure.VectorIndex.hnsw(
ef=self.search_params["vectorIndexConfig"]["ef"]
)
)

@classmethod
def delete_client(cls):
if cls.client is not None:
cls.client.close()
56 changes: 23 additions & 33 deletions engine/clients/weaviate/upload.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,45 @@
import uuid
from typing import List, Optional

from weaviate import Client
from weaviate import WeaviateClient
from weaviate.connect import ConnectionParams
from weaviate.classes.data import DataObject

from engine.base_client.upload import BaseUploader
from engine.clients.weaviate.config import WEAVIATE_CLASS_NAME, WEAVIATE_DEFAULT_PORT


class WeaviateUploader(BaseUploader):
client = None
client: WeaviateClient = None
upload_params = {}
collection = None

@classmethod
def init_client(cls, host, distance, connection_params, upload_params):
url = f"http://{host}:{connection_params.get('port', WEAVIATE_DEFAULT_PORT)}"
cls.client = Client(url, **connection_params)

cls.client = WeaviateClient(
ConnectionParams.from_url(url, 50051), skip_init_checks=True
)
cls.client.connect()
cls.upload_params = upload_params
cls.connection_params = connection_params

@staticmethod
def _update_geo_data(data_object):
keys = data_object.keys()
for key in keys:
if isinstance(data_object[key], dict):
if lat := data_object[key].get("lat", None):
data_object[key]["latitude"] = lat
if lon := data_object[key].get("lon", None):
data_object[key]["longitude"] = lon

return data_object
cls.collection = cls.client.collections.get(
WEAVIATE_CLASS_NAME, skip_argument_validation=True
)

@classmethod
def upload_batch(
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
):
# Weaviate introduced the batch_size, so it can handle built-in client's
# multi-threading. That should make the upload faster.
cls.client.batch.configure(
batch_size=100,
timeout_retries=3,
)
objects = []
for i in range(len(ids)):
id = uuid.UUID(int=ids[i])
property = metadata[i] or {}
objects.append(DataObject(properties=property, vector=vectors[i], uuid=id))
if len(objects) > 0:
cls.collection.data.insert_many(objects)

with cls.client.batch as batch:
for id_, vector, data_object in zip(ids, vectors, metadata):
data_object = cls._update_geo_data(data_object or {})
batch.add_data_object(
data_object=data_object,
class_name=WEAVIATE_CLASS_NAME,
uuid=uuid.UUID(int=id_).hex,
vector=vector,
)

batch.create_objects()
@classmethod
def delete_client(cls):
if cls.client is not None:
cls.client.close()
2 changes: 1 addition & 1 deletion engine/servers/qdrant-single-node/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ version: '3.7'

services:
qdrant_bench:
image: ${CONTAINER_REGISTRY:-docker.io}/qdrant/qdrant:v1.7.3
image: ${CONTAINER_REGISTRY:-docker.io}/qdrant/qdrant:v1.8.1
network_mode: host
logging:
driver: "json-file"
Expand Down
6 changes: 2 additions & 4 deletions engine/servers/weaviate-single-node/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ services:
- '8090'
- --scheme
- http
image: semitechnologies/weaviate:1.21.5
ports:
- "8090:8090"
image: semitechnologies/weaviate:1.24.1
network_mode: host
logging:
driver: "json-file"
options:
Expand All @@ -24,7 +23,6 @@ services:
ENABLE_MODULES: ''
CLUSTER_HOSTNAME: 'node1'
GOMEMLIMIT: 25GiB # https://weaviate.io/blog/gomemlimit-a-game-changer-for-high-memory-applications
GOGC: 50
deploy:
resources:
limits:
Expand Down
Loading

0 comments on commit 3825c98

Please sign in to comment.