Skip to content

Commit

Permalink
Lint error fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Dec 4, 2024
1 parent 0507893 commit 75576cf
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions dicee/scripts/index_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct

from fastapi import FastAPI
import uvicorn

# from qdrant_client.http.models import Filter, FieldCondition, MatchValue

def get_default_arguments():
parser = argparse.ArgumentParser(add_help=False)
Expand Down Expand Up @@ -69,45 +70,33 @@ def index(args):
print("Completed!")


import argparse
from fastapi import FastAPI
import uvicorn
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
import pandas as pd
from typing import List

app = FastAPI()
# Create a neural searcher instance
neural_searcher = None

class NeuralSearcher:
def __init__(self, args):
self.collection_name = args.collection
self.collection_name = args.collection
assert os.path.exists(args.path + "/entity_to_idx.csv"), f"{args.path + '/entity_to_idx.csv'} does not exist!"
self.entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0)
assert self.entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}"
self.entity_to_idx = {name: idx for idx, name in enumerate(self.entity_to_idx["entity"].tolist())}
# initialize Qdrant client
self.qdrant_client = QdrantClient(host=args.vdb_host,port=args.vdb_port)
# semantic search
self.topk=5

def get(self,entity:str|List[str]=None):
def get(self,entity:str=None):
if entity is None:
return {"Input {entity} cannot be None"}
elif self.entity_to_idx.get(entity,None) is None:
return {f"Input {entity} not found"}
else:
if isinstance(entity,str):
ids=[self.entity_to_idx[entity]]
else:
if isinstance(entity,list):
ids = [self.entity_to_idx[ent] for ent in entity]
else:
return {"Error":f"Input must be a string or a list of strings!Given:{entity} with type {type(entity)}"}
return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids,with_vectors=True)
ids=[self.entity_to_idx[entity]]
return self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)

def search(self, entity: str):
return self.qdrant_client.query_points(collection_name=self.collection_name,query=self.entity_to_idx[entity],limit=5)
return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk)

@app.get("/")
async def root():
Expand Down

0 comments on commit 75576cf

Please sign in to comment.