Skip to content

Commit

Permalink
[#16] Eliminate qdrant_collection.add() and consolidate into update()
Browse files Browse the repository at this point in the history
  • Loading branch information
uogbuji committed Jul 17, 2023
1 parent 7c49d48 commit ff62150
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 58 deletions.
2 changes: 1 addition & 1 deletion demo/chat_web_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def read_site(url, collection):
# Crude—for demo. Set URL metadata for all chunks to doc URL
metas = [{'url': url}]*len(chunks)
# Add the text to the collection. Blocks, so no reentrancy concern
collection.add(texts=chunks, metas=metas)
collection.update(texts=chunks, metas=metas)


async def async_main(sites, api_params):
Expand Down
87 changes: 30 additions & 57 deletions pylib/embedding_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@


class qdrant_collection:
def __init__(self, name, embedding_model, db=None, **conn_params):
def __init__(self, name, embedding_model, db=None,
distance_function=None, **conn_params):
'''
Initialize a Qdrant client
Expand All @@ -58,6 +59,8 @@ def __init__(self, name, embedding_model, db=None, **conn_params):
db (optional QdrantClient): existing DB/client to use
distance_function (str): Distance function by which vectors will be compared
conn_params (mapping): keyword parameters for setting up QdrantClient
See the main docstring (or run `help(QdrantClient)`)
https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L12
Expand Down Expand Up @@ -87,68 +90,29 @@ def __init__(self, name, embedding_model, db=None, **conn_params):
if not conn_params:
conn_params = MEMORY_QDRANT_CONNECTION_PARAMS
self.db = QdrantClient(**conn_params)
self._vector_size = -1
self._distance_function = distance_function or models.Distance.COSINE
self._db_initialized = False

def _determine_vector_size(self, text):
def _first_update_prep(self, text):
# Make sure we have a vector size set; use a sample embedding if need be
partial_embeddings = self._embedding_model.encode(text)
self._vector_size = len(partial_embeddings)

def add(self, texts, distance_function='Cosine', metas=None):
'''
Add a collection to a Qdrant client, and add some strings (chunks) to that collection
Args:
chunks (List[str]): List of similar length strings to embed
distance_function (str): Distance function by which vectors will be compared
qdrant_conn_params (mapping): keyword parameters for setting up QdrantClient
See the main docstring (or run `help(QdrantClient)`)
https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L12
'''
if len(texts) == 0:
warnings.warn(f'Empty sequence of texts provided. No action will be taken.')
return

if metas is None:
metas = []
else:
if len(texts) > len(metas):
warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.')
metas = itertools.chain(metas, [{}]*(len(texts)-len(texts)))
elif len(metas) > len(texts):
warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). '
'The extra text will be given empty metadata.')
metas = itertools.islice(metas, len(texts))

# meta is a list of dicts
# Find the size of the first chunk's embedding
if self._vector_size == -1:
self._determine_vector_size(texts[0])

# Set the default distance function, giving grace to capitalization
distance_function = distance_function.lower().capitalize()
distance_function = models.Distance.COSINE

# Create a collection in the Qdrant client, and configure its vectors
# Using REcreate_collection ensures overwrite
self.db.recreate_collection(
collection_name=self.name,
vectors_config=models.VectorParams(
size=self._vector_size,
distance=distance_function
distance=self._distance_function
)
)

# Put the items in the collection
self.upsert(texts=texts, metas=metas)

# current_count = int(str(self.db.count(self.name)).partition('=')[-1])
# print('COLLECTION COUNT:', current_count)
self._db_initialized = True

def upsert(self, texts, metas=None):
def update(self, texts, metas=None):
'''
Update/insert a Qdrant client's collection with the some chunks of text
Update/insert into a Qdrant client's collection with the some chunks of text
Args:
texts (List[str]): Strings to be stored and indexed. For best results these should be of similar length.
Expand All @@ -157,17 +121,26 @@ def upsert(self, texts, metas=None):
metas (List[dict]): Optional metadata per text, stored with the text and included whenever the text is
retrieved via search/query
'''
metas = metas or []
if len(texts) == 0:
warnings.warn('Empty sequence of texts provided. No action will be taken.')
return

if len(texts) > len(metas):
warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.')
metas = itertools.chain(metas, [{}]*(len(texts)-len(texts)))
elif len(metas) > len(texts):
warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). '
'The extra text will be given empty metadata.')
metas = itertools.islice(metas, len(texts))
if metas is None:
metas = []
else:
if len(texts) > len(metas):
warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.')
metas = itertools.chain(metas, [{}]*(len(texts)-len(texts)))
elif len(metas) > len(texts):
warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). '
'The extra text will be given empty metadata.')
metas = itertools.islice(metas, len(texts))

before_count = self.count()
if not self._db_initialized:
self._first_update_prep(texts[0])
before_count = 0
else:
before_count = self.count()

for ix, (text, meta) in enumerate(zip(texts, metas)):
# Embeddings as float/vectors
Expand Down

0 comments on commit ff62150

Please sign in to comment.