diff --git a/pylib/embedding/pgvector_doc.py b/pylib/embedding/pgvector_doc.py index 0a1993a..c190344 100644 --- a/pylib/embedding/pgvector_doc.py +++ b/pylib/embedding/pgvector_doc.py @@ -6,6 +6,7 @@ Vector databases embeddings using PGVector ''' +import warnings from typing import Iterable from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg @@ -13,57 +14,220 @@ __all__ = ['DocDB'] -# Generic SQL template for creating a table to hold embedded documents -CREATE_DOC_TABLE = '''-- Create a table to hold embedded documents -CREATE TABLE IF NOT EXISTS {table_name} ( +CREATE_TABLE_BASE = '''-- Create a table to hold embedded documents or data +CREATE TABLE IF NOT EXISTS {{table_name}} ( id BIGSERIAL PRIMARY KEY, - embedding VECTOR({embed_dimension}), -- embedding vectors (array dimension) + embedding VECTOR({{embed_dimension}}), -- embedding vectors (array dimension) content TEXT NOT NULL, -- text content of the chunk - title TEXT, -- title of file - page_numbers INTEGER[], -- page number of the document that the chunk is found in tags TEXT[] -- tags associated with the chunk -); +{extra_fields}); ''' -INSERT_DOCS = '''-- Insert a document into a table -INSERT INTO {table_name} ( +CREATE_DOC_TABLE = CREATE_TABLE_BASE.format(extra_fields='''\ +, title TEXT, -- title of file + page_numbers INTEGER[] -- page number of the document that the chunk is found in +''') + +CREATE_DATA_TABLE = CREATE_TABLE_BASE.format(extra_fields='') + +INSERT_BASE = '''-- Insert a document into a table +INSERT INTO {{table_name}} ( embedding, content, - title, - page_numbers, tags -) VALUES ($1, $2, $3, $4, $5); + {extra_fields}) VALUES ($1, $2, $3{extra_vals}); ''' -QUERY_DOC_TABLE = '''-- Semantic search a document +INSERT_DOCS = INSERT_BASE.format(extra_fields='''\ +, title, + page_numbers +''', extra_vals=', $4, $5') + +INSERT_DATA = INSERT_BASE.format(extra_fields='', extra_vals='') + +QUERY_TABLE_BASE = '''-- Semantic search a document SELECT * FROM -- Subquery to calculate 
cosine similarity, required to use the alias in the WHERE clause ( SELECT 1 - (embedding <=> $1) AS cosine_similarity, - title, content, - page_numbers, tags + {extra_fields} FROM - {table_name} + {{table_name}} ) subquery -{where_clauses} +{{where_clauses}} ORDER BY cosine_similarity DESC -{limit_clause} +{{limit_clause}}; ''' +QUERY_DOC_TABLE = QUERY_TABLE_BASE.format(extra_fields='''\ +, title, + page_numbers +''') + +QUERY_DATA_TABLE = QUERY_TABLE_BASE.format(extra_fields='') + TITLE_WHERE_CLAUSE = 'title = {query_title} -- Equals operator \n' PAGE_NUMBERS_WHERE_CLAUSE = 'page_numbers && {query_page_numbers} -- Overlap operator \n' -TAGS_WHERE_CLAUSE_CONJ = 'tags @> {query_tags} -- Contains operator \n' -TAGS_WHERE_CLAUSE_DISJ = 'tags && {query_tags} -- Overlap operator \n' +TAGS_WHERE_CLAUSE_CONJ = 'tags @> {tags} -- Contains operator \n' +TAGS_WHERE_CLAUSE_DISJ = 'tags && {tags} -- Overlap operator \n' THRESHOLD_WHERE_CLAUSE = '{query_threshold} >= cosine_similarity\n' +# XXX: Data vs doc DB can probably be modularized further, but this will do for now +class DataDB(PGVectorHelper): + ''' Specialize PGvectorHelper for data (snippets) ''' + async def create_table(self) -> None: + ''' + Create the table to hold embedded data (snippets) + ''' + await self.conn.execute( + CREATE_DATA_TABLE.format( + table_name=self.table_name, + embed_dimension=self._embed_dimension) + ) + + async def insert( + self, + content: str, + tags: list[str] = [] + ) -> None: + ''' + Update a table with one embedded document + + Args: + content (str): text content of the document + + tags (list[str], optional): tags associated with the document + ''' + # Get the embedding of the content as a PGvector compatible list + content_embedding = self._embedding_model.encode(content) + + await self.conn.execute( + 
INSERT_DATA.format(table_name=self.table_name), + content_embedding.tolist(), + content, + tags + ) + + async def insert_many( + self, + content_list: Iterable[tuple[str, list[str]]] + ) -> None: + ''' + Update a table with one or more embedded documents + + Semantically equivalent to multiple insert calls, but uses executemany for efficiency + + Args: + content_list: List of tuples, each of the form: (content, tags) + ''' + await self.conn.executemany( + INSERT_DATA.format(table_name=self.table_name), + ( + (self._embedding_model.encode(content), content, tags) + for content, tags in content_list + ) + ) + + async def search( + self, + text: str, + tags: list[str] | None = None, + threshold: float | None = None, + limit: int = 0, + conjunctive: bool = True, + query_tags: list[str] | None = None, + ) -> list[asyncpg.Record]: + ''' + Similarity search documents using a query string + + Args: + text (str): string to compare against items in the table. + This will be a vector/fuzzy/nearest-neighbor type search. + + tags (list[str], optional): tags associated with the document to compare against items in the table. + Each individual tag must match exactly, but see the conjunctive param + for how multiple tags are interpreted. + + limit (int, optional): maximum number of results to return (useful for top-k query) + Default is no limit + + conjunctive (bool, optional): whether to use conjunctive (AND) or disjunctive (OR) matching + in the case of multiple tags. Defaults to True. + Returns: + list[asyncpg.Record]: list of search results + (asyncpg.Record objects are similar to dicts, but allow for attribute-style access) + ''' + if query_tags is not None: + warnings.warn('query_tags is deprecated. 
Use tags instead.', DeprecationWarning) + tags = query_tags + # else: + # if not isinstance(query_tags, list): + # raise TypeError('query_tags must be a list of strings') + # if not all(isinstance(tag, str) for tag in query_tags): + # raise TypeError('query_tags must be a list of strings') + if threshold is not None: + if not isinstance(threshold, float): + raise TypeError('threshold must be a float') + if (threshold < 0) or (threshold > 1): + raise ValueError('threshold must be between 0 and 1') + + if not isinstance(limit, int): + raise TypeError('limit must be an integer') # Guard against injection + + # Get the embedding of the query string as a PGvector compatible list + query_embedding = list(self._embedding_model.encode(text)) + + tags_where_clause = TAGS_WHERE_CLAUSE_CONJ if conjunctive else TAGS_WHERE_CLAUSE_DISJ + + # Build where clauses + if (tags is None) and (threshold is None): + # No where clauses, so don't bother with the WHERE keyword + where_clauses = '' + query_args = [query_embedding] + else: # construct where clauses + param_count = 1 + clauses = [] + query_args = [query_embedding] + if tags is not None: + param_count += 1 + query_args.append(tags) + clauses.append(tags_where_clause.format(tags=f'${param_count}')) + if threshold is not None: + param_count += 1 + query_args.append(threshold) + clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${param_count}')) + clauses = 'AND\n'.join(clauses) # TODO: move this into the fstring below after py3.12 + where_clauses = f'WHERE\n{clauses}' + + if limit: + limit_clause = f'LIMIT {limit}\n' + else: + limit_clause = '' + + # Execute the search via SQL + search_results = await self.conn.fetch( + QUERY_DATA_TABLE.format( + table_name=self.table_name, + where_clauses=where_clauses, + limit_clause=limit_clause, + ), + *query_args + ) + return search_results + + class DocDB(PGVectorHelper): ''' Specialize PGvectorHelper for documents ''' async def create_table(self) -> None: @@ -107,14 +271,14 @@ 
async def insert( INSERT_DOCS.format(table_name=self.table_name), content_embedding.tolist(), content, + tags, title, - page_numbers, - tags + page_numbers ) async def insert_many( self, - content_list: Iterable[tuple[str, str | None, list[int], list[str]]] + content_list: Iterable[tuple[str, list[str], str | None, list[int]]] ) -> None: ''' Update a table with one or more embedded documents @@ -122,38 +286,39 @@ async def insert_many( Semantically equivalent to multiple insert_doc calls, but uses executemany for efficiency Args: - content_list: List of tuples, each of the form: (content, title, page_numbers, tags) + content_list: List of tuples, each of the form: (content, tags, title, page_numbers) ''' await self.conn.executemany( INSERT_DOCS.format(table_name=self.table_name), ( - (self._embedding_model.encode(content), content, title, page_numbers, tags) - for content, title, page_numbers, tags in content_list + (self._embedding_model.encode(content), content, tags, title, page_numbers) + for content, tags, title, page_numbers in content_list ) ) async def search( self, - query_string: str, + text: str, query_title: str | None = None, query_page_numbers: list[int] | None = None, - query_tags: list[str] | None = None, + tags: list[str] | None = None, threshold: float | None = None, limit: int = 0, - conjunctive: bool = True + conjunctive: bool = True, + query_tags: list[str] | None = None, ) -> list[asyncpg.Record]: ''' Similarity search documents using a query string Args: - query_string (str): string to compare against items in the table. + text (str): string to compare against items in the table. This will be a vector/fuzzy/nearest-neighbor type search. query_title (str, optional): title of the document to compare against items in the table. query_page_numbers (list[int], optional): target page number in the document for query string comparison. - query_tags (list[str], optional): tags associated with the document to compare against items in the table. 
+ tags (list[str], optional): tags associated with the document to compare against items in the table. Each individual tag must match exactly, but see the conjunctive param for how multiple tags are interpreted. @@ -166,6 +331,9 @@ async def search( list[asyncpg.Record]: list of search results (asyncpg.Record objects are similar to dicts, but allow for attribute-style access) ''' + if query_tags is not None: + warnings.warn('query_tags is deprecated. Use tags instead.', DeprecationWarning) + tags = query_tags if threshold is not None: if not isinstance(threshold, float): raise TypeError('threshold must be a float') @@ -176,12 +344,12 @@ async def search( raise TypeError('limit must be an integer') # Guard against injection # Get the embedding of the query string as a PGvector compatible list - query_embedding = list(self._embedding_model.encode(query_string)) + query_embedding = list(self._embedding_model.encode(text)) tags_where_clause = TAGS_WHERE_CLAUSE_CONJ if conjunctive else TAGS_WHERE_CLAUSE_DISJ # Build where clauses - if (query_title is None) and (query_page_numbers is None) and (query_tags is None) and (threshold is None): + if (query_title is None) and (query_page_numbers is None) and (tags is None) and (threshold is None): # No where clauses, so don't bother with the WHERE keyword where_clauses = '' query_args = [query_embedding] @@ -197,10 +365,10 @@ async def search( param_count += 1 query_args.append(query_page_numbers) clauses.append(PAGE_NUMBERS_WHERE_CLAUSE.format(query_page_numbers=f'${param_count}')) - if query_tags is not None: + if tags is not None: param_count += 1 - query_args.append(query_tags) - clauses.append(tags_where_clause.format(query_tags=f'${param_count}')) + query_args.append(tags) + clauses.append(tags_where_clause.format(tags=f'${param_count}')) if threshold is not None: param_count += 1 query_args.append(threshold) diff --git a/test/embedding/test_pgvector.py b/test/embedding/test_pgvector.py index d5d4966..3a1be04 100644 --- 
a/test/embedding/test_pgvector.py +++ b/test/embedding/test_pgvector.py @@ -88,7 +88,7 @@ async def test_PGv_embed_pacer(): # search table with perfect match search_string = '[beep] A single lap should be completed each time you hear this sound.' - sim_search = await vDB.search(query_string=search_string, limit=3) + sim_search = await vDB.search(text=search_string, limit=3) assert sim_search is not None, Exception("No results returned from perfect search") await vDB.drop_table() @@ -123,9 +123,9 @@ async def test_PGv_embed_many_pacer(): documents = ( ( text, + ['fitness', 'pacer', 'copypasta'], f'Pacer Copypasta line {index}', - [1, 2, 3], - ['fitness', 'pacer', 'copypasta'] + [1, 2, 3] ) for index, text in enumerate(pacer_copypasta) ) @@ -135,7 +135,7 @@ async def test_PGv_embed_many_pacer(): # Search table with perfect match search_string = '[beep] A single lap should be completed each time you hear this sound.' - sim_search = await vDB.search(query_string=search_string, limit=3) + sim_search = await vDB.search(text=search_string, limit=3) assert sim_search is not None, Exception("No results returned from perfect search") await vDB.drop_table() @@ -187,10 +187,10 @@ def encode_tweaker(*args, **kwargs): # search table with filtered match search_string = '[beep] A single lap should be completed each time you hear this sound.' 
sim_search = await vDB.search( - query_string=search_string, + text=search_string, query_title='Pacer Copypasta', query_page_numbers=[3], - query_tags=['pacer'], + tags=['pacer'], conjunctive=False ) assert sim_search is not None, Exception("No results returned from filtered search") @@ -201,11 +201,11 @@ def encode_tweaker(*args, **kwargs): await vDB.insert(content='Text', title='Even mo text', page_numbers=[1], tags=['tag3']) # Using limit default - sim_search = await vDB.search(query_string='Text', query_tags=['tag1', 'tag3'], conjunctive=False) + sim_search = await vDB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False) assert sim_search is not None, Exception("No results returned from filtered search") assert len(sim_search) == 3, Exception(f"There should be 3 results, received {sim_search}") - sim_search = await vDB.search(query_string='Text', query_tags=['tag1', 'tag3'], conjunctive=False, limit=1000) + sim_search = await vDB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False, limit=1000) assert sim_search is not None, Exception("No results returned from filtered search") assert len(sim_search) == 3, Exception(f"There should be 3 results, received {sim_search}") @@ -213,14 +213,14 @@ def encode_tweaker(*args, **kwargs): authors = ['Brian Kernighan', 'Louis Armstrong', 'Robert Graves'] metas = [[f'author={a}'] for a in authors] count = len(texts) - records = zip(texts, ['']*count, [None]*count, metas) + records = zip(texts, metas, ['']*count, [None]*count) await vDB.insert_many(records) - sim_search = await vDB.search(query_string='Hi there!', threshold=0.999, limit=0) + sim_search = await vDB.search(text='Hi there!', threshold=0.999, limit=0) assert sim_search is not None, Exception("No results returned from filtered search") assert len(sim_search) == 3, Exception(f"There should be 3 results, received {sim_search}") - sim_search = await vDB.search(query_string='Hi there!', threshold=0.999, limit=2) + sim_search = await vDB.search(text='Hi 
there!', threshold=0.999, limit=2) assert sim_search is not None, Exception("No results returned from filtered search") assert len(sim_search) == 2, Exception(f"There should be 2 results, received {sim_search}")