Add an optional similarity threshold to search() in pgvector_message.py, move
argument validation in get_messages() ahead of the connection block, export
DataDB from pgvector_data_doc.py, and update the doctest, README and tests.

Review fixes folded into this revision of the patch:
  * threshold range errors raise ValueError again; TypeError is kept for wrong types
  * the nested semantic query orders by cosine_similarity, so LIMIT keeps the closest rows
  * the "since" WHERE clause ends with a newline, so the AND join cannot yield "$3AND"

diff --git a/pylib/embedding/pgvector_data_doc.py b/pylib/embedding/pgvector_data_doc.py
index fefeca3..e16b778 100644
--- a/pylib/embedding/pgvector_data_doc.py
+++ b/pylib/embedding/pgvector_data_doc.py
@@ -11,7 +11,7 @@
 
 from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response
 
-__all__ = ['DocDB']
+__all__ = ['DataDB', 'DocDB']
 
 # ------ SQL queries ---------------------------------------------------------------------------------------------------
 # PG only supports proper query arguments (e.g. $1, $2, etc.) for values, not for table or column names
@@ -177,11 +177,10 @@ async def search(
             generator which yields the rows os the query results ass attributable dicts
         '''
         if threshold is not None:
             if not isinstance(threshold, float):
-                raise TypeError('threshold must be a float')
+                raise TypeError('threshold must be a float between 0.0 and 1.0')
             if (threshold < 0) or (threshold > 1):
-                raise ValueError('threshold must be between 0 and 1')
-
+                raise ValueError('threshold must be a float between 0.0 and 1.0')
         if not isinstance(limit, int):
             raise TypeError('limit must be an integer')  # Guard against injection
 
@@ -218,10 +217,8 @@ async def search(
         # Execute the search via SQL
         async with self.pool.acquire() as conn:
             search_results = await conn.fetch(
-                QUERY_DATA_TABLE.format(
-                    table_name=self.table_name,
-                    where_clauses=where_clauses,
-                    limit_clause=limit_clause,
+                QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses,
+                                        limit_clause=limit_clause,
                 ),
                 *query_args
             )
@@ -237,10 +234,7 @@ async def create_table(self) -> None:
         async with self.pool.acquire() as conn:
             async with conn.transaction():
                 await conn.execute(
-                    CREATE_DOC_TABLE.format(
-                        table_name=self.table_name,
-                        embed_dimension=self._embed_dimension)
-                    )
+                    CREATE_DOC_TABLE.format(table_name=self.table_name, embed_dimension=self._embed_dimension))
 
     async def insert(
         self,
@@ -336,10 +330,10 @@ async def search(
             warnings.warn('query_tags is deprecated. Use tags instead.', DeprecationWarning)
             tags = query_tags
         if threshold is not None:
             if not isinstance(threshold, float):
-                raise TypeError('threshold must be a float')
+                raise TypeError('threshold must be a float between 0.0 and 1.0')
             if (threshold < 0) or (threshold > 1):
-                raise ValueError('threshold must be between 0 and 1')
+                raise ValueError('threshold must be a float between 0.0 and 1.0')
         if not isinstance(limit, int):
             raise TypeError('limit must be an integer')  # Guard against injection
 
diff --git a/pylib/embedding/pgvector_message.py b/pylib/embedding/pgvector_message.py
index cbf540f..00d3cfd 100644
--- a/pylib/embedding/pgvector_message.py
+++ b/pylib/embedding/pgvector_message.py
@@ -71,15 +71,38 @@
     role,
     content,
     metadata
-FROM
-    {table_name}
-WHERE
-    history_key = $2
+FROM {table_name}
+{where_clauses}
 ORDER BY
     cosine_similarity DESC
-LIMIT $3;
+{limit_clause};
 '''
 
+# The cosine_similarity alias is not available in the WHERE clause, so use a nested SELECT
+# ORDER BY is required so that LIMIT keeps the closest matches rather than arbitrary rows
+SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity
+SELECT
+    cosine_similarity,
+    ts,
+    role,
+    content,
+    metadata
+FROM
+    (SELECT
+        history_key,
+        1 - (embedding <=> $1) AS cosine_similarity,
+        ts,
+        role,
+        content,
+        metadata FROM {table_name}) AS main
+{where_clauses}
+ORDER BY
+    cosine_similarity DESC
+{limit_clause};
+'''
+
+THRESHOLD_WHERE_CLAUSE = 'main.cosine_similarity >= {query_threshold}\n'
+
 DELETE_OLDEST_MESSAGES = '''-- Delete oldest messages for given history key, such that only the newest N messages remain
 DELETE FROM {table_name} t_outer
 WHERE
@@ -262,26 +285,32 @@ async def get_messages(
         Returns:
             generates asyncpg.Record instances of resulting messages
         '''
+        if not isinstance(history_key, UUID):
+            history_key = UUID(history_key)
+            # msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
+            # raise TypeError(msg)
+        if not isinstance(since, datetime) and since is not None:
+            msg = 'since must be a datetime or None'
+            raise TypeError(msg)
+        if not isinstance(limit, int):
+            raise TypeError('limit must be an integer')  # Guard against injection
+
         qparams = [history_key]
+        # Build query
+        if since:
+            # Don't really need the ${len(qparams) + N} thing here (first optional), but used for consistency
+            since_clause = f' AND ts > ${len(qparams) + 1}'
+            qparams.append(since)
+        else:
+            since_clause = ''
+        if limit:
+            limit_clause = f'LIMIT ${len(qparams) + 1}'
+            qparams.append(limit)
+        else:
+            limit_clause = ''
+
+        # Execute
         async with self.pool.acquire() as conn:
-            if not isinstance(history_key, UUID):
-                history_key = UUID(history_key)
-                # msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
-                # raise TypeError(msg)
-            if not isinstance(since, datetime) and since is not None:
-                msg = 'since must be a datetime or None'
-                raise TypeError(msg)
-            if since:
-                # Really don't need the ${len(qparams) + 1} trick here, but used for consistency
-                since_clause = f' AND ts > ${len(qparams) + 1}'
-                qparams.append(since)
-            else:
-                since_clause = ''
-            if limit:
-                limit_clause = f'LIMIT ${len(qparams) + 1}'
-                qparams.append(limit)
-            else:
-                limit_clause = ''
             message_records = await conn.fetch(
                 RETURN_MESSAGES_BY_HISTORY_KEY.format(
                     table_name=self.table_name,
@@ -303,6 +332,7 @@ async def search(
         history_key: UUID,
         text: str,
         since: datetime | None = None,
+        threshold: float | None = None,
         limit: int = 1
     ) -> list[asyncpg.Record]:
         '''
@@ -317,21 +347,48 @@ async def search(
             list[asyncpg.Record]: list of search results
                 (asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
         '''
+        # Type checks
+        if threshold is not None:
+            if not isinstance(threshold, float):
+                raise TypeError('threshold must be a float between 0.0 and 1.0')
+            if (threshold < 0) or (threshold > 1):
+                raise ValueError('threshold must be a float between 0.0 and 1.0')
         if not isinstance(limit, int):
             raise TypeError('limit must be an integer')
+        if not isinstance(history_key, UUID):
+            history_key = UUID(history_key)
+        if not isinstance(since, datetime) and since is not None:
+            msg = 'since must be a datetime or None'
+            raise TypeError(msg)
 
-        # Get the embedding of the query string as a PGvector compatible list
+        # Get embedding of the query string as a PGvector compatible list
         query_embedding = list(self._embedding_model.encode(text))
 
-        # Search the table
+        # Build query
+        # Every clause ends with a newline so the join below never runs AND into a $N placeholder
+        clauses = ['main.history_key = $2\n']
+        qparams = [query_embedding, history_key]
+        if since is not None:
+            # Don't really need the ${len(qparams) + N} thing here (first optional), but used for consistency
+            clauses.append(f'ts > ${len(qparams) + 1}\n')
+            qparams.append(since)
+        if threshold is not None:
+            clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(qparams) + 1}'))
+            qparams.append(threshold)
+        clauses = 'AND\n'.join(clauses)  # TODO: move this into the fstring below after py3.12
+        where_clauses = f'WHERE\n{clauses}'
+        limit_clause = f'LIMIT ${len(qparams) + 1}'
+        qparams.append(limit)
+
+        # Execute
         async with self.pool.acquire() as conn:
-            records = await conn.fetch(
+            message_records = await conn.fetch(
                 SEMANTIC_QUERY_MESSAGE_TABLE.format(
-                    table_name=self.table_name
+                    table_name=self.table_name,
+                    where_clauses=where_clauses,
+                    limit_clause=limit_clause
                 ),
-                query_embedding,
-                history_key,
-                limit
+                *qparams
             )
 
         search_results = [
@@ -342,7 +399,7 @@ async def search(
                 'metadata': record['metadata'],
                 'cosine_similarity': record['cosine_similarity']
             }
-            for record in records
+            for record in message_records
         ]
 
         return process_search_response(search_results)
diff --git a/pylib/llm_wrapper.py b/pylib/llm_wrapper.py
index bb44db4..9f2f9ef 100644
--- a/pylib/llm_wrapper.py
+++ b/pylib/llm_wrapper.py
@@ -266,9 +266,9 @@ class openai_chat_api(openai_api):
 
     You need to set an OpenAI API key in your environment, or pass it in, for this next example
 
-    >>> from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
+    >>> import asyncio; from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
     >>> llm_api = openai_chat_api(model='gpt-3.5-turbo')
-    >>> resp = llm_api(prompt_to_chat('Knock knock!'))
+    >>> resp = asyncio.run(llm_api(prompt_to_chat('Knock knock!')))
     >>> resp.first_choice_text
     '''
     def call(self, messages, api_func=None, **kwargs):
diff --git a/test/embedding/README.md b/test/embedding/README.md
index 857bef3..61dadbf 100644
--- a/test/embedding/README.md
+++ b/test/embedding/README.md
@@ -7,3 +7,21 @@ docker run --name mock-postgres -p 5432:5432 \
     -e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
     -d ankane/pgvector
 ```
+
+You can also use another PGVector setup, but then you need to set the following environment variables:
+
+* `PG_HOST`
+* `PG_DATABASE`
+* `PG_USER`
+* `PG_PASSWORD`
+* `PG_PORT`
+
+e.g.:
+
+```sh
+PG_HOST="localhost"
+PG_PORT="5432"
+PG_USER="username"
+PG_PASSWORD="passwd"
+PG_DATABASE="PeeGeeVee"
+```
diff --git a/test/embedding/test_pgvector_doc.py b/test/embedding/test_pgvector_doc.py
index 7751fc4..d3cbbc6 100644
--- a/test/embedding/test_pgvector_doc.py
+++ b/test/embedding/test_pgvector_doc.py
@@ -39,8 +39,6 @@ def __init__(self, model_name_or_path):
 
 @pytest.mark.asyncio
 async def test_PGv_embed_pacer(DB):
-    dummy_model = SentenceTransformer('mock_transformer')
-    dummy_model.encode.return_value = np.array([1, 2, 3])
     # Insert data
     for index, text in enumerate(pacer_copypasta):  # For each line in the copypasta
         await DB.insert(  # Insert the line into the table
@@ -61,8 +59,6 @@ async def test_PGv_embed_pacer(DB):
 
 @pytest.mark.asyncio
 async def test_PGv_embed_many_pacer(DB):
-    dummy_model = SentenceTransformer('mock_transformer')
-    dummy_model.encode.return_value = np.array([1, 2, 3])
     # Insert data using insert_many()
     documents = (
         (
diff --git a/test/embedding/test_pgvector_message.py b/test/embedding/test_pgvector_message.py
index dfc4f71..f7ac6f8 100644
--- a/test/embedding/test_pgvector_message.py
+++ b/test/embedding/test_pgvector_message.py
@@ -18,7 +18,7 @@
 import pytest
 from unittest.mock import MagicMock, DEFAULT  # noqa: F401
 
-# import numpy as np
+import numpy as np
 
 # XXX: Move to a fixture?
 # Definitely don't want to even import SentenceTransformer class due to massive side-effects
@@ -164,5 +164,30 @@ async def test_get_messages_since(DB, MESSAGES):
     assert len(results) == 1, Exception('Incorrect number of messages returned from chatlog')
 
 
+@pytest.mark.asyncio
+async def test_search_threshold(DB, MESSAGES):
+    dummy_model = SentenceTransformer('mock_transformer')
+    def encode_tweaker(*args, **kwargs):
+        # Note: cosine similarity of [1, 2, 3] & [100, 300, 500] appears to be ~ 0.9939
+        if args[0].startswith('Hi'):
+            return np.array([100, 300, 500])
+        else:
+            return np.array([1, 2, 3])
+
+    dummy_model.encode.side_effect = encode_tweaker
+    # Need to replace the default encoder set up by the fixture
+    DB._embedding_model = dummy_model
+
+    await DB.insert_many(MESSAGES)
+
+    history_key, role, content, timestamp, metadata = MESSAGES[0]
+
+    results = list(await DB.search(history_key, 'Hi!', threshold=0.999))
+    assert results is not None and len(results) == 0
+
+    results = list(await DB.search(history_key, 'Hi!', threshold=0.5))
+    assert results is not None and len(results) == 1
+
+
 if __name__ == '__main__':
     raise SystemExit("Attention! Run with pytest")