From ff62150c70c020ca775f3073d43cbbc370e06eea Mon Sep 17 00:00:00 2001 From: Uche Ogbuji Date: Mon, 17 Jul 2023 15:32:34 -0600 Subject: [PATCH] [#16] Eliminate qdrant_collection.add() and consolidate into update() --- demo/chat_web_selects.py | 2 +- pylib/embedding_helper.py | 87 ++++++++++++++------------------------- 2 files changed, 31 insertions(+), 58 deletions(-) diff --git a/demo/chat_web_selects.py b/demo/chat_web_selects.py index ccf18a4..d8e2f97 100644 --- a/demo/chat_web_selects.py +++ b/demo/chat_web_selects.py @@ -78,7 +78,7 @@ async def read_site(url, collection): # Crude—for demo. Set URL metadata for all chunks to doc URL metas = [{'url': url}]*len(chunks) # Add the text to the collection. Blocks, so no reentrancy concern - collection.add(texts=chunks, metas=metas) + collection.update(texts=chunks, metas=metas) async def async_main(sites, api_params): diff --git a/pylib/embedding_helper.py b/pylib/embedding_helper.py index 1abb65b..80b717b 100644 --- a/pylib/embedding_helper.py +++ b/pylib/embedding_helper.py @@ -46,7 +46,8 @@ class qdrant_collection: - def __init__(self, name, embedding_model, db=None, **conn_params): + def __init__(self, name, embedding_model, db=None, + distance_function=None, **conn_params): ''' Initialize a Qdrant client @@ -58,6 +59,8 @@ def __init__(self, name, embedding_model, db=None, **conn_params): db (optional QdrantClient): existing DB/client to use + distance_function (str): Distance function by which vectors will be compared + conn_params (mapping): keyword parameters for setting up QdrantClient See the main docstring (or run `help(QdrantClient)`) https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L12 @@ -87,68 +90,29 @@ def __init__(self, name, embedding_model, db=None, **conn_params): if not conn_params: conn_params = MEMORY_QDRANT_CONNECTION_PARAMS self.db = QdrantClient(**conn_params) - self._vector_size = -1 + self._distance_function = distance_function or models.Distance.COSINE + self._db_initialized = False - def _determine_vector_size(self, text): + def _first_update_prep(self, text): + # Make sure we have a vector size set; use a sample embedding if need be partial_embeddings = self._embedding_model.encode(text) self._vector_size = len(partial_embeddings) - def add(self, texts, distance_function='Cosine', metas=None): - ''' - Add a collection to a Qdrant client, and add some strings (chunks) to that collection - - Args: - chunks (List[str]): List of similar length strings to embed - - distance_function (str): Distance function by which vectors will be compared - - qdrant_conn_params (mapping): keyword parameters for setting up QdrantClient - See the main docstring (or run `help(QdrantClient)`) - https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L12 - ''' - if len(texts) == 0: - warnings.warn(f'Empty sequence of texts provided. No action will be taken.') - return - - if metas is None: - metas = [] - else: - if len(texts) > len(metas): - warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.') - metas = itertools.chain(metas, [{}]*(len(texts)-len(texts))) - elif len(metas) > len(texts): - warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). ' - 'The extra text will be given empty metadata.') - metas = itertools.islice(metas, len(texts)) - - # meta is a list of dicts - # Find the size of the first chunk's embedding - if self._vector_size == -1: - self._determine_vector_size(texts[0]) - - # Set the default distance function, giving grace to capitalization - distance_function = distance_function.lower().capitalize() - distance_function = models.Distance.COSINE - # Create a collection in the Qdrant client, and configure its vectors # Using REcreate_collection ensures overwrite self.db.recreate_collection( collection_name=self.name, vectors_config=models.VectorParams( size=self._vector_size, - distance=distance_function + distance=self._distance_function ) ) - # Put the items in the collection - self.upsert(texts=texts, metas=metas) - - # current_count = int(str(self.db.count(self.name)).partition('=')[-1]) - # print('COLLECTION COUNT:', current_count) + self._db_initialized = True - def upsert(self, texts, metas=None): + def update(self, texts, metas=None): ''' - Update/insert a Qdrant client's collection with the some chunks of text + Update/insert into a Qdrant client's collection with the some chunks of text Args: texts (List[str]): Strings to be stored and indexed. For best results these should be of similar length. @@ -157,17 +121,26 @@ def upsert(self, texts, metas=None): metas (List[dict]): Optional metadata per text, stored with the text and included whenever the text is retrieved via search/query ''' - metas = metas or [] + if len(texts) == 0: + warnings.warn('Empty sequence of texts provided. No action will be taken.') + return - if len(texts) > len(metas): - warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.') - metas = itertools.chain(metas, [{}]*(len(texts)-len(texts))) - elif len(metas) > len(texts): - warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). ' - 'The extra text will be given empty metadata.') - metas = itertools.islice(metas, len(texts)) + if metas is None: + metas = [] + else: + if len(texts) > len(metas): + warnings.warn(f'More texts ({len(texts)} provided than metadata {len(metas)}). Extra metadata items will be ignored.') + metas = itertools.chain(metas, [{}]*(len(texts)-len(texts))) + elif len(metas) > len(texts): + warnings.warn(f'Fewer texts ({len(texts)} provided than metadata {len(metas)}). ' + 'The extra text will be given empty metadata.') + metas = itertools.islice(metas, len(texts)) - before_count = self.count() + if not self._db_initialized: + self._first_update_prep(texts[0]) + before_count = 0 + else: + before_count = self.count() for ix, (text, meta) in enumerate(zip(texts, metas)): # Embeddings as float/vectors