From 75576cf23a257a427bf88058afd7896166b6d31d Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Wed, 4 Dec 2024 13:41:09 +0100 Subject: [PATCH] Lint error fixed --- dicee/scripts/index_serve.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/dicee/scripts/index_serve.py b/dicee/scripts/index_serve.py index 84b332fd..7430aab5 100644 --- a/dicee/scripts/index_serve.py +++ b/dicee/scripts/index_serve.py @@ -11,8 +11,9 @@ from qdrant_client.http.models import Distance, VectorParams from qdrant_client.http.models import PointStruct +from fastapi import FastAPI +import uvicorn -# from qdrant_client.http.models import Filter, FieldCondition, MatchValue def get_default_arguments(): parser = argparse.ArgumentParser(add_help=False) @@ -69,14 +70,6 @@ def index(args): print("Completed!") -import argparse -from fastapi import FastAPI -import uvicorn -from qdrant_client import QdrantClient -from qdrant_client.http.models import Filter, FieldCondition, MatchValue -import pandas as pd -from typing import List - app = FastAPI() # Create a neural searcher instance neural_searcher = None @@ -84,30 +77,26 @@ def index(args): class NeuralSearcher: def __init__(self, args): self.collection_name = args.collection - self.collection_name = args.collection + assert os.path.exists(args.path + "/entity_to_idx.csv"), f"{args.path + '/entity_to_idx.csv'} does not exist!" self.entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0) assert self.entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}" self.entity_to_idx = {name: idx for idx, name in enumerate(self.entity_to_idx["entity"].tolist())} # initialize Qdrant client self.qdrant_client = QdrantClient(host=args.vdb_host,port=args.vdb_port) + # semantic search + self.topk=5 - def get(self,entity:str|List[str]=None): + def get(self,entity:str=None): if entity is None: return {"Input {entity} cannot be None"} elif self.entity_to_idx.get(entity,None) is None: return {f"Input {entity} not found"} else: - if isinstance(entity,str): - ids=[self.entity_to_idx[entity]] - else: - if isinstance(entity,list): - ids = [self.entity_to_idx[ent] for ent in entity] - else: - return {"Error":f"Input must be a string or a list of strings!Given:{entity} with type {type(entity)}"} - return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids,with_vectors=True) + ids=[self.entity_to_idx[entity]] + return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True) def search(self, entity: str): - return self.qdrant_client.query_points(collection_name=self.collection_name,query=self.entity_to_idx[entity],limit=5) + return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk) @app.get("/") async def root():