From 4eb7975ea6cc75da254b477ded37624e5121c7a4 Mon Sep 17 00:00:00 2001 From: Uche Ogbuji Date: Sun, 14 Jan 2024 20:54:01 -0700 Subject: [PATCH] Implement pgvector_message.get_chatlog(), with support for date constraint & limit, while deprecating pgvector_message.get_table(). Tidy up embedding tests with more use of fixtures & teardown --- pylib/embedding/pgvector.py | 6 +- pylib/embedding/pgvector_message.py | 87 +++++++++++------ test/embedding/conftest.py | 68 +++++++++++++ test/embedding/test_pgvector_data.py | 69 ++----------- test/embedding/test_pgvector_doc.py | 122 +++++------------------ test/embedding/test_pgvector_message.py | 124 +++++++++--------------- 6 files changed, 212 insertions(+), 264 deletions(-) diff --git a/pylib/embedding/pgvector.py b/pylib/embedding/pgvector.py index 5852dcc..8005d1a 100644 --- a/pylib/embedding/pgvector.py +++ b/pylib/embedding/pgvector.py @@ -187,7 +187,7 @@ async def count_items(self) -> int: # Count the number of documents in the table count = await conn.fetchval(f'SELECT COUNT(*) FROM {self.table_name}') return count - + async def table_exists(self) -> bool: ''' Check if the table exists @@ -197,11 +197,11 @@ async def table_exists(self) -> bool: ''' # Check if the table exists async with (await self.connection_pool()).acquire() as conn: - table_exists = await conn.fetchval( + exists = await conn.fetchval( CHECK_TABLE_EXISTS, self.table_name ) - return table_exists + return exists async def drop_table(self) -> None: ''' diff --git a/pylib/embedding/pgvector_message.py b/pylib/embedding/pgvector_message.py index 980acbb..7d427a5 100644 --- a/pylib/embedding/pgvector_message.py +++ b/pylib/embedding/pgvector_message.py @@ -9,6 +9,7 @@ from uuid import UUID from datetime import datetime, timezone from typing import Iterable +import warnings from ogbujipt.config import attr_dict from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response @@ -21,7 +22,7 @@ CREATE_MESSAGE_TABLE = '''-- Create a 
table to hold individual messages (e.g. from a chatlog) and their metadata CREATE TABLE IF NOT EXISTS {table_name} ( - ts TIMESTAMP WITH TIME ZONE PRIMARY KEY, -- timestamp of the message + ts TIMESTAMP WITH TIME ZONE, -- timestamp of the message history_key UUID, -- uunique identifier for contextual message history role TEXT, -- role of the message (meta ID such as 'system' or user, -- or an ID associated with the sender) @@ -39,11 +40,7 @@ embedding, ts, metadata -) VALUES ($1, $2, $3, $4, $5, $6) -ON CONFLICT (ts) DO UPDATE SET -- Update the content, embedding, and metadata of the message if it already exists - content = EXCLUDED.content, - embedding = EXCLUDED.embedding, - metadata = EXCLUDED.metadata; +) VALUES ($1, $2, $3, $4, $5, $6); ''' CLEAR_MESSAGE = '''-- Deletes matching messages @@ -52,7 +49,7 @@ history_key = $1 ''' -RETURN_MESSAGE_BY_HISTORY_KEY = '''-- Get entire chatlog by history key +RETURN_MESSAGES_BY_HISTORY_KEY = '''-- Get entire chatlog by history key SELECT ts, role, @@ -62,8 +59,10 @@ {table_name} WHERE history_key = $1 +{since_clause} ORDER BY - ts; + ts +{limit_clause}; ''' SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity @@ -179,54 +178,88 @@ async def clear( ), history_key ) - + # XXX: Change to a generator - async def get_table( + async def get_chatlog( self, - history_key: UUID - ) -> list[asyncpg.Record]: + history_key: UUID | str, + since: datetime | None = None, + limit: int = 0 + ): # -> list[asyncpg.Record]: ''' - Retrieve all entries in a message history + Retrieve entries in a message history Args: - history_key (str): history key (unique identifier) to match + history_key (str): history key (unique identifier) to match; string or object + since (datetime, optional): only return messages after this timestamp + limit (int, optional): maximum number of messages to return. 
Default is all messages Returns: - list[asyncpg.Record]: list of message entries - (asyncpg.Record objects are similar to dicts, but allow for attribute-style access) + generates asyncpg.Record instances of resulting messages ''' + qparams = [history_key] async with (await self.connection_pool()).acquire() as conn: + if not isinstance(history_key, UUID): + history_key = UUID(history_key) + # msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))' + # raise TypeError(msg) + if not isinstance(since, datetime) and since is not None: + msg = 'since must be a datetime or None' + raise TypeError(msg) + if since: + # Really don't need the ${len(qparams) + 1} trick here, but used for consistency + since_clause = f' AND ts > ${len(qparams) + 1}' + qparams.append(since) + else: + since_clause = '' + if limit: + limit_clause = f'LIMIT ${len(qparams) + 1}' + qparams.append(limit) + else: + limit_clause = '' message_records = await conn.fetch( - RETURN_MESSAGE_BY_HISTORY_KEY.format( + RETURN_MESSAGES_BY_HISTORY_KEY.format( table_name=self.table_name, + since_clause=since_clause, + limit_clause=limit_clause ), - history_key + *qparams ) - messages = [ - attr_dict({ + return (attr_dict({ 'ts': record['ts'], 'role': record['role'], 'content': record['content'], 'metadata': record['metadata'] - }) - for record in message_records - ] - - return messages + }) for record in message_records) + + # async for record in message_records: + # yield attr_dict({ + # 'ts': record['ts'], + # 'role': record['role'], + # 'content': record['content'], + # 'metadata': record['metadata'] + # }) + async def get_table(self, history_key): + # Deprecated + warnings.warn('pgvector_message.get_table() is deprecated. 
Use get_chatlog().', DeprecationWarning) + return list(await self.get_chatlog(history_key)) + async def search( self, history_key: UUID, text: str, + since: datetime | None = None, limit: int = 1 ) -> list[asyncpg.Record]: ''' Similarity search documents using a query string Args: + history_key (str): history key for the conversation to query text (str): string to compare against items in the table - - k (int, optional): maximum number of results to return (useful for top-k query) + since (datetime, optional): only return results after this timestamp + limit (int, optional): maximum number of messages to return; for top-k type query. Default is 1 Returns: list[asyncpg.Record]: list of search results (asyncpg.Record objects are similar to dicts, but allow for attribute-style access) diff --git a/test/embedding/conftest.py b/test/embedding/conftest.py index 53b6356..90d562e 100644 --- a/test/embedding/conftest.py +++ b/test/embedding/conftest.py @@ -1,5 +1,19 @@ +# SPDX-FileCopyrightText: 2023-present Oori Data +# SPDX-License-Identifier: Apache-2.0 +# test/embedding/conftest.py +''' +Fixtures/setup/teardown for embedding tests +''' + +import sys +import os import pytest +import pytest_asyncio +from unittest.mock import MagicMock, DEFAULT # noqa: F401 + import numpy as np +from ogbujipt.embedding.pgvector import MessageDB, DataDB, DocDB + @pytest.fixture def CORRECT_STRING(): @@ -67,3 +81,57 @@ def HITHERE_all_MiniLM_L12_v2(): -0.071223356, 0.0019683593, 0.032683503, -0.08899012, 0.10160039, 0.04948275, 0.048017487, -0.046223965, 0.032460734, -0.043729845, 0.030224336, -0.019220904, 0.08223829, 0.03851222, -0.016376046, 0.041965306, 0.0445879, -0.03780432, -0.024826797, 0.014669102, 0.057102628, -0.031820614, 0.0027352672, 0.052658144]) + +# XXX: This stanza to go away once mocking is complete - Kai +HOST = os.environ.get('PG_HOST', 'localhost') +DB_NAME = os.environ.get('PG_DATABASE', 'mock_db') +USER = os.environ.get('PG_USER', 'mock_user') +PASSWORD = 
os.environ.get('PG_PASSWORD', 'mock_password') +PORT = os.environ.get('PG_PORT', 5432) + + +# XXX: Move to a fixture? +# Definitely don't want to even import SentenceTransformer class due to massive side-effects +class SentenceTransformer(object): + def __init__(self, model_name_or_path): + self.encode = MagicMock() + + +DB_CLASS = { + 'test/embedding/test_pgvector_message.py': MessageDB, + 'test/embedding/test_pgvector_data.py': DataDB, + 'test/embedding/test_pgvector_doc.py': DocDB, +} + + +@pytest_asyncio.fixture # Notice the async aware fixture declaration +async def DB(request): + testname = request.node.name + testfile = request.node.location[0] + table_name = testname.lower() + print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr) + dummy_model = SentenceTransformer('mock_transformer') + dummy_model.encode.return_value = np.array([1, 2, 3]) + try: + vDB = await DB_CLASS[testfile].from_conn_params( + embedding_model=dummy_model, + table_name=table_name, + db_name=DB_NAME, + host=HOST, + port=int(PORT), + user=USER, + password=PASSWORD) + except ConnectionRefusedError: + pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) + if vDB is None: + pytest.skip("No Postgres instance made available for test. 
Skipping.", allow_module_level=True) + + # Create table + await vDB.drop_table() + assert not await vDB.table_exists(), Exception("Table exists before creation") + await vDB.create_table() + assert await vDB.table_exists(), Exception("Table does not exist after creation") + # The test will take control upon the yield + yield vDB + # Teardown: Drop table + await vDB.drop_table() diff --git a/test/embedding/test_pgvector_data.py b/test/embedding/test_pgvector_data.py index 2897c0c..b9690e4 100644 --- a/test/embedding/test_pgvector_data.py +++ b/test/embedding/test_pgvector_data.py @@ -8,16 +8,8 @@ import pytest from unittest.mock import MagicMock, DEFAULT # noqa: F401 -import os -from ogbujipt.embedding.pgvector import DataDB import numpy as np -# XXX: This stanza to go away once mocking is complete - Kai -HOST = os.environ.get('PG_HOST', 'localhost') -DB_NAME = os.environ.get('PG_DATABASE', 'mock_db') -USER = os.environ.get('PG_USER', 'mock_user') -PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password') -PORT = os.environ.get('PG_PORT', 5432) KG_STATEMENTS = [ # Demo data ("👤 Alikiba `releases_single` 💿 'Yalaiti'", {'url': 'https://notjustok.com/lyrics/yalaiti-lyrics-by-alikiba-ft-sabah-salum/'}), @@ -36,44 +28,24 @@ def __init__(self, model_name_or_path): @pytest.mark.asyncio -async def test_insert_data_vector(): +async def test_insert_data_vector(DB): dummy_model = SentenceTransformer('mock_transformer') dummy_model.encode.return_value = np.array([1, 2, 3]) - TABLE_NAME = 'embedding_data_test' - try: - vDB = await DataDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - if vDB is None: - pytest.skip("No Postgres instance made available for test. 
Skipping.", allow_module_level=True) - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") item1_text = KG_STATEMENTS[0][0] item1_meta = KG_STATEMENTS[0][1] # Insert data for index, (text, meta) in enumerate(KG_STATEMENTS): - await vDB.insert( # Insert the row into the table + await DB.insert( # Insert the row into the table content=text, # text to be embedded tags=[f'{k}={v}' for (k, v) in meta.items()], # Tag metadata ) - assert await vDB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion') + assert await DB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion') # search table with perfect match - result = await vDB.search(text=item1_text, limit=3) + result = await DB.search(text=item1_text, limit=3) # assert result is not None, Exception('No results returned from perfect search') # Even though the embedding is mocked, the stored text should be faithful @@ -81,47 +53,26 @@ async def test_insert_data_vector(): assert row.content == item1_text assert row.tags == [f'{k}={v}' for (k, v) in item1_meta.items()] - await vDB.drop_table() + await DB.drop_table() @pytest.mark.asyncio -async def test_insertmany_data_vector(): +async def test_insertmany_data_vector(DB): dummy_model = SentenceTransformer('mock_transformer') dummy_model.encode.return_value = np.array([1, 2, 3]) - # print(f'EMODEL: {dummy_model}') - TABLE_NAME = 'embedding_test' - try: - vDB = await DataDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. 
Skipping.", allow_module_level=True) - if vDB is None: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) item1_text = KG_STATEMENTS[0][0] item1_meta = KG_STATEMENTS[0][1] - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - # Insert data using insert_many() dataset = ((text, [f'{k}={v}' for (k, v) in tags.items()]) for (text, tags) in KG_STATEMENTS) - await vDB.insert_many(dataset) + await DB.insert_many(dataset) - assert await vDB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion') + assert await DB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion') # search table with perfect match - result = await vDB.search(text=item1_text, limit=3) + result = await DB.search(text=item1_text, limit=3) # assert result is not None, Exception('No results returned from perfect search') # Even though the embedding is mocked, the stored text should be faithful @@ -129,7 +80,7 @@ async def test_insertmany_data_vector(): assert row.content == item1_text assert row.tags == [f'{k}={v}' for (k, v) in item1_meta.items()] - await vDB.drop_table() + await DB.drop_table() if __name__ == '__main__': diff --git a/test/embedding/test_pgvector_doc.py b/test/embedding/test_pgvector_doc.py index 37b615d..6e47f6a 100644 --- a/test/embedding/test_pgvector_doc.py +++ b/test/embedding/test_pgvector_doc.py @@ -24,17 +24,9 @@ # from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, DEFAULT # noqa: F401 -import os from ogbujipt.embedding.pgvector import DocDB import numpy as np -# XXX: This stanza to go away once mocking is complete - Kai -HOST = os.environ.get('PG_HOST', 'localhost') -DB_NAME = os.environ.get('PG_DATABASE', 'mock_db') -USER = 
os.environ.get('PG_USER', 'mock_user') -PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password') -PORT = os.environ.get('PG_PORT', 5432) - pacer_copypasta = [ # Demo document ('The FitnessGram™ Pacer Test is a multistage aerobic capacity test that progressively gets more difficult as it' ' continues.'), @@ -53,75 +45,31 @@ def __init__(self, model_name_or_path): @pytest.mark.asyncio -async def test_PGv_embed_pacer(): +async def test_PGv_embed_pacer(DB): dummy_model = SentenceTransformer('mock_transformer') dummy_model.encode.return_value = np.array([1, 2, 3]) - # print(f'EMODEL: {dummy_model}') - TABLE_NAME = 'embedding_test' - try: - vDB = await DocDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - - assert vDB is not None, ConnectionError("Postgres docker instance not available for testing PG code") - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - # Insert data for index, text in enumerate(pacer_copypasta): # For each line in the copypasta - await vDB.insert( # Insert the line into the table + await DB.insert( # Insert the line into the table content=text, # The text to be embedded title=f'Pacer Copypasta line {index}', # Title metadata page_numbers=[1, 2, 3], # Page number metadata tags=['fitness', 'pacer', 'copypasta'], # Tag metadata ) - assert await vDB.count_items() == len(pacer_copypasta), Exception("Not all documents inserted") + assert await DB.count_items() == len(pacer_copypasta), Exception("Not all documents inserted") # search table with perfect match search_string = '[beep] A single lap should be completed each 
time you hear this sound.' - sim_search = await vDB.search(text=search_string, limit=3) + sim_search = await DB.search(text=search_string, limit=3) assert sim_search is not None, Exception("No results returned from perfect search") - await vDB.drop_table() + await DB.drop_table() @pytest.mark.asyncio -async def test_PGv_embed_many_pacer(): +async def test_PGv_embed_many_pacer(DB): dummy_model = SentenceTransformer('mock_transformer') dummy_model.encode.return_value = np.array([1, 2, 3]) - # print(f'EMODEL: {dummy_model}') - TABLE_NAME = 'embedding_test' - try: - vDB = await DocDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - - assert vDB is not None, ConnectionError("Postgres docker instance not available for testing PG code") - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - # Insert data using insert_many() documents = ( ( @@ -132,20 +80,20 @@ async def test_PGv_embed_many_pacer(): ) for index, text in enumerate(pacer_copypasta) ) - await vDB.insert_many(documents) + await DB.insert_many(documents) - assert await vDB.count_items() == len(pacer_copypasta), Exception("Not all documents inserted") + assert await DB.count_items() == len(pacer_copypasta), Exception("Not all documents inserted") # Search table with perfect match search_string = '[beep] A single lap should be completed each time you hear this sound.' 
- sim_search = await vDB.search(text=search_string, limit=3) + sim_search = await DB.search(text=search_string, limit=3) assert sim_search is not None, Exception("No results returned from perfect search") - await vDB.drop_table() + await DB.drop_table() @pytest.mark.asyncio -async def test_PGv_search_filtered(): +async def test_PGv_search_filtered(DB): dummy_model = SentenceTransformer('mock_transformer') def encode_tweaker(*args, **kwargs): if args[0].startswith('Text'): @@ -154,42 +102,22 @@ def encode_tweaker(*args, **kwargs): return np.array([100, 300, 500]) dummy_model.encode.side_effect = encode_tweaker - # print(f'EMODEL: {dummy_model}') - TABLE_NAME = 'embedding_test' - try: - vDB = await DocDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - - assert vDB is not None, ConnectionError("Postgres docker instance not available for testing PG code") - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - + # Need to replace the default encoder set up by the fixture + DB._embedding_model = dummy_model # Insert data for index, text in enumerate(pacer_copypasta): # For each line in the copypasta - await vDB.insert( # Insert the line into the table + await DB.insert( # Insert the line into the table content=text, # The text to be embedded title='Pacer Copypasta', # Title metadata page_numbers=[index], # Page number metadata tags=['fitness', 'pacer', 'copypasta'], # Tag metadata ) - assert await vDB.count_items() == len(pacer_copypasta), Exception("Not all documents inserted") + assert await DB.count_items() == len(pacer_copypasta), 
Exception("Not all documents inserted") # search table with filtered match search_string = '[beep] A single lap should be completed each time you hear this sound.' - sim_search = await vDB.search( + sim_search = await DB.search( text=search_string, query_title='Pacer Copypasta', query_page_numbers=[3], @@ -199,16 +127,16 @@ def encode_tweaker(*args, **kwargs): assert sim_search is not None, Exception("No results returned from filtered search") #Test conjunctive semantics - await vDB.insert(content='Text', title='Some text', page_numbers=[1], tags=['tag1']) - await vDB.insert(content='Text', title='Some mo text', page_numbers=[1], tags=['tag2', 'tag3']) - await vDB.insert(content='Text', title='Even mo text', page_numbers=[1], tags=['tag3']) + await DB.insert(content='Text', title='Some text', page_numbers=[1], tags=['tag1']) + await DB.insert(content='Text', title='Some mo text', page_numbers=[1], tags=['tag2', 'tag3']) + await DB.insert(content='Text', title='Even mo text', page_numbers=[1], tags=['tag3']) # Using limit default - sim_search = await vDB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False) + sim_search = await DB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False) assert sim_search is not None, Exception("No results returned from filtered search") assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}") - sim_search = await vDB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False, limit=1000) + sim_search = await DB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False, limit=1000) assert sim_search is not None, Exception("No results returned from filtered search") assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}") @@ -217,17 +145,17 @@ def encode_tweaker(*args, **kwargs): metas = [[f'author={a}'] for a in authors] count = len(texts) records = zip(texts, metas, ['']*count, [None]*count) - await vDB.insert_many(records) + await 
DB.insert_many(records) - sim_search = await vDB.search(text='Hi there!', threshold=0.999, limit=0) + sim_search = await DB.search(text='Hi there!', threshold=0.999, limit=0) assert sim_search is not None, Exception("No results returned from filtered search") assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}") - sim_search = await vDB.search(text='Hi there!', threshold=0.999, limit=2) + sim_search = await DB.search(text='Hi there!', threshold=0.999, limit=2) assert sim_search is not None, Exception("No results returned from filtered search") assert len(list(sim_search)) == 2, Exception(f"There should be 2 results, received {sim_search}") - await vDB.drop_table() + await DB.drop_table() if __name__ == '__main__': diff --git a/test/embedding/test_pgvector_message.py b/test/embedding/test_pgvector_message.py index 0dc05dd..d777df0 100644 --- a/test/embedding/test_pgvector_message.py +++ b/test/embedding/test_pgvector_message.py @@ -5,7 +5,6 @@ See test/embedding/test_pgvector.py for important notes on running these tests ''' -import os from datetime import datetime import pytest @@ -13,17 +12,16 @@ import numpy as np -from ogbujipt.embedding.pgvector import MessageDB - -# XXX: This stanza to go away once mocking is complete - Kai -HOST = os.environ.get('PG_HOST', 'localhost') -DB_NAME = os.environ.get('PG_DATABASE', 'mock_db') -USER = os.environ.get('PG_USER', 'mock_user') -PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password') -PORT = os.environ.get('PG_PORT', 5432) +# XXX: Move to a fixture? 
+# Definitely don't want to even import SentenceTransformer class due to massive side-effects +class SentenceTransformer(object): + def __init__(self, model_name_or_path): + self.encode = MagicMock() -MESSAGES = [ # Test data: history_key, role, content, timestamp, metadata +@pytest.fixture +def MESSAGES(): + messages = [ # Test data: history_key, role, content, timestamp, metadata ('00000000-0000-0000-0000-000000000000', 'ama', 'Hello Eme!', '2021-10-01 00:00:00+00:00', {'1': 'a'}), ('00000000-0000-0000-0000-000000000001', 'ugo', 'Greetings Ego', '2021-10-01 00:00:01+00:00', {'2': 'b'}), ('00000000-0000-0000-0000-000000000000', 'eme', 'How you dey, Ama!', '2021-10-01 00:00:02+00:00', {'3': 'c'}), @@ -33,54 +31,23 @@ ('00000000-0000-0000-0000-000000000000', 'eme', 'Very good. Say hello to your family for me.', '2021-10-01 00:00:06+00:00', {'7': 'g'}), # noqa: E501 ('00000000-0000-0000-0000-000000000001', 'ugo', 'An even better surprise, I hope!', '2021-10-01 00:00:07+00:00', {'8': 'h'}) # noqa: E501 ] -MESSAGES = [ - (history_key, role, content, datetime.fromisoformat(timestamp), metadata) - for history_key, role, content, timestamp, metadata in MESSAGES -] - - -# XXX: Move to a fixture? 
-# Definitely don't want to even import SentenceTransformer class due to massive side-effects -class SentenceTransformer(object): - def __init__(self, model_name_or_path): - self.encode = MagicMock() + return [(history_key, role, content, datetime.fromisoformat(timestamp), metadata) + for history_key, role, content, timestamp, metadata in messages + ] @pytest.mark.asyncio -async def test_insert_message_vector(): - dummy_model = SentenceTransformer('mock_transformer') - dummy_model.encode.return_value = np.array([1, 2, 3]) - TABLE_NAME = 'embedding_msg_test' - try: - vDB = await MessageDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - if vDB is None: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - +async def test_insert_message_vector(DB, MESSAGES): # Insert data for index, (row) in enumerate(MESSAGES): - await vDB.insert(*row) + await DB.insert(*row) - assert await vDB.count_items() == len(MESSAGES), Exception('Incorrect number of messages after insertion') + assert await DB.count_items() == len(MESSAGES), Exception('Incorrect number of messages after insertion') history_key, role, content, timestamp, metadata = MESSAGES[0] # search table with perfect match - result = await vDB.search(text=content, history_key=history_key, limit=3) + result = await DB.search(text=content, history_key=history_key, limit=3) # assert result is not None, Exception('No results returned from perfect search') # Even though the embedding is mocked, the 
stored text should be faithful @@ -88,43 +55,18 @@ async def test_insert_message_vector(): assert row.content == content assert row.metadata == {'1': 'a'} - await vDB.drop_table() - @pytest.mark.asyncio -async def test_insertmany_message_vector(): - dummy_model = SentenceTransformer('mock_transformer') - dummy_model.encode.return_value = np.array([1, 2, 3]) - TABLE_NAME = 'embedding_msg_test' - try: - vDB = await MessageDB.from_conn_params( - embedding_model=dummy_model, - table_name=TABLE_NAME, - db_name=DB_NAME, - host=HOST, - port=int(PORT), - user=USER, - password=PASSWORD) - except ConnectionRefusedError: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - if vDB is None: - pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - - # Create tables - await vDB.drop_table() - assert await vDB.table_exists() is False, Exception("Table exists before creation") - await vDB.create_table() - assert await vDB.table_exists() is True, Exception("Table does not exist after creation") - +async def test_insertmany_message_vector(DB, MESSAGES): # Insert data using insert_many() - await vDB.insert_many(MESSAGES) + await DB.insert_many(MESSAGES) - assert await vDB.count_items() == len(MESSAGES), Exception('Incorrect number of messages after insertion') + assert await DB.count_items() == len(MESSAGES), Exception('Incorrect number of messages after insertion') history_key, role, content, timestamp, metadata = MESSAGES[0] # search table with perfect match - result = await vDB.search(text=content, history_key=history_key, limit=3) + result = await DB.search(text=content, history_key=history_key, limit=3) # assert result is not None, Exception('No results returned from perfect search') # Even though the embedding is mocked, the stored text should be faithful @@ -132,8 +74,34 @@ async def test_insertmany_message_vector(): assert row.content == content assert row.metadata == {'1': 'a'} - await 
vDB.drop_table() +@pytest.mark.asyncio +async def test_get_chatlog_all_limit(DB, MESSAGES): + # Insert data using insert_many() + await DB.insert_many(MESSAGES) + + history_key, role, content, timestamp, metadata = MESSAGES[0] + + results = await DB.get_chatlog(history_key=history_key) + assert len(list(results)) == 4, Exception('Incorrect number of messages returned from chatlog') + + results = await DB.get_chatlog(history_key=history_key, limit=3) + assert len(list(results)) == 3, Exception('Incorrect number of messages returned from chatlog') + + +@pytest.mark.asyncio +async def test_get_chatlog_since(DB, MESSAGES): + await DB.insert_many(MESSAGES) + + history_key, role, content, timestamp, metadata = MESSAGES[0] + + since_ts = datetime.fromisoformat('2021-10-01 00:00:03+00:00') + results = list(await DB.get_chatlog(history_key=history_key, since=since_ts)) + assert len(results) == 2, Exception('Incorrect number of messages returned from chatlog') + + since_ts = datetime.fromisoformat('2021-10-01 00:00:04+00:00') + results = list(await DB.get_chatlog(history_key=history_key, since=since_ts)) + assert len(results) == 1, Exception('Incorrect number of messages returned from chatlog') if __name__ == '__main__': raise SystemExit("Attention! Run with pytest")