diff --git a/demo/PGvector_demo.ipynb b/demo/PGvector_demo.ipynb index 809da4d..13e733b 100644 --- a/demo/PGvector_demo.ipynb +++ b/demo/PGvector_demo.ipynb @@ -999,7 +999,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.6" }, "orig_nbformat": 4 }, diff --git a/pylib/embedding/pgvector_message.py b/pylib/embedding/pgvector_message.py index 326a98b..cbf540f 100644 --- a/pylib/embedding/pgvector_message.py +++ b/pylib/embedding/pgvector_message.py @@ -79,9 +79,55 @@ cosine_similarity DESC LIMIT $3; ''' -# ------ SQL queries --------------------------------------------------------------------------------------------------- + +DELETE_OLDEST_MESSAGES = '''-- Delete oldest messages for given history key, such that only the newest N messages remain +DELETE FROM {table_name} t_outer +WHERE + t_outer.history_key = $1 +AND + t_outer.ctid NOT IN ( + SELECT t_inner.ctid + FROM {table_name} t_inner + WHERE + t_inner.history_key = $1 + ORDER BY + t_inner.ts DESC + LIMIT $2 +); +''' + +# Delete after full comfort with windowed implementation +# TEMPQ = ''' +# SELECT t_inner.ctid +# FROM {table_name} t_inner +# WHERE +# t_inner.history_key = $1 +# ORDER BY +# t_inner.ts DESC +# LIMIT $2; +# ''' + +# ------ Class implementations --------------------------------------------------------------------------------------- class MessageDB(PGVectorHelper): + def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0): + ''' + Helper class for messages/chatlog storage and retrieval + + Args: + embedding (SentenceTransformer): SentenceTransformer object of your choice + https://huggingface.co/sentence-transformers + window (int, optional): number of messages to maintain in the DB. Default is 0 (all messages) + ''' + super().__init__(embedding_model, table_name, pool) + self.window = window + + @classmethod + async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0) -> 'MessageDB': # noqa: E501 + obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password) + obj.window = window + return obj + ''' Specialize PGvectorHelper for messages, e.g. chatlogs ''' async def create_table(self): async with self.pool.acquire() as conn: @@ -135,7 +181,17 @@ async def insert( content_embedding.tolist(), timestamp, metadata - ) + ) + # print(f'{self.window=}, Pre-count: {await self.count_items()}') + async with self.pool.acquire() as conn: + async with conn.transaction(): + if self.window: + await conn.execute( + DELETE_OLDEST_MESSAGES.format(table_name=self.table_name), + history_key, + self.window) + # async with self.pool.acquire() as conn: + # print(f'{self.window=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), history_key, self.window))}') # noqa E501 async def insert_many( self, @@ -158,6 +214,17 @@ async def insert_many( for hk, role, text, ts, metadata in content_list ) ) + # print(f'{self.window=}, Pre-count: {await self.count_items()}') # noqa E501 + async with self.pool.acquire() as conn: + async with conn.transaction(): + if self.window: + # Set uniquifies the history keys + for hk in {hk for hk, _, _, _, _ in content_list}: + await conn.execute( + DELETE_OLDEST_MESSAGES.format(table_name=self.table_name), + hk, self.window) + # async with self.pool.acquire() as conn: + # print(f'{self.window=}, {hk=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), hk, self.window))}') # noqa E501 async def clear( self, diff --git a/test/embedding/README.md b/test/embedding/README.md new file mode 100644 index 0000000..857bef3 --- /dev/null +++ b/test/embedding/README.md @@ -0,0 +1,9 @@ +To run these tests, first set up a mock Postgres instance with the following commands +(make sure you don't have anything running on port 0.0.0.0:5432): + +```sh +docker pull ankane/pgvector +docker run --name mock-postgres -p 5432:5432 \ + -e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \ + -d ankane/pgvector +``` diff --git a/test/embedding/conftest.py b/test/embedding/conftest.py index 90d562e..8ae87f5 100644 --- a/test/embedding/conftest.py +++ b/test/embedding/conftest.py @@ -3,6 +3,14 @@ # test/embedding/conftest.py ''' Fixtures/setup/teardown for embedding tests + +General note: After setup as described in the README.md for this directory, run the tests with: + +pytest test + +or, for just embeddings tests: + +pytest test/embedding/ ''' import sys @@ -103,6 +111,7 @@ def __init__(self, model_name_or_path): 'test/embedding/test_pgvector_doc.py': DocDB, } +# print(HOST, DB_NAME, USER, PASSWORD, PORT) @pytest_asyncio.fixture # Notice the async aware fixture declaration async def DB(request): @@ -123,8 +132,44 @@ async def DB(request): password=PASSWORD) except ConnectionRefusedError: pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) - if vDB is None: + # Actually we want to propagate the error condition, in this case + # if vDB is None: + # pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True) + + # Create table + await vDB.drop_table() + assert not await vDB.table_exists(), Exception("Table exists before creation") + await vDB.create_table() + assert await vDB.table_exists(), Exception("Table does not exist after creation") + # The test will take control upon the yield + yield vDB + # Teardown: Drop table + await vDB.drop_table() + + +# FIXME: Lots of DRY violations +@pytest_asyncio.fixture # Notice the async aware fixture declaration +async def DB_WINDOWED2(request): + testname = request.node.name + table_name = testname.lower() + print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr) + dummy_model = SentenceTransformer('mock_transformer') + dummy_model.encode.return_value = np.array([1, 2, 3]) + try: + vDB = await MessageDB.from_conn_params( + embedding_model=dummy_model, + table_name=table_name, + db_name=DB_NAME, + host=HOST, + port=int(PORT), + user=USER, + password=PASSWORD, + window=2) + except ConnectionRefusedError: pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) + # Actually we want to propagate the error condition, in this case + # if vDB is None: + # pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True) # Create table await vDB.drop_table() diff --git a/test/embedding/test_pgvector_data.py b/test/embedding/test_pgvector_data.py index b9690e4..7a8a356 100644 --- a/test/embedding/test_pgvector_data.py +++ b/test/embedding/test_pgvector_data.py @@ -2,7 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 # test/embedding/test_pgvector_data.py ''' -See test/embedding/test_pgvector.py for important notes on running these tests +After setup as described in the README.md for this directory, run the tests with: + +pytest test + +or, for just this test module: + +pytest test/embedding/test_pgvector_data.py + +Uses fixtures from conftest.py in current & parent directories ''' import pytest diff --git a/test/embedding/test_pgvector_doc.py b/test/embedding/test_pgvector_doc.py index 21180b7..b262f96 100644 --- a/test/embedding/test_pgvector_doc.py +++ b/test/embedding/test_pgvector_doc.py @@ -1,23 +1,16 @@ # SPDX-FileCopyrightText: 2023-present Oori Data # SPDX-License-Identifier: Apache-2.0 -# test/embedding/test_pgvector.py +# test/embedding/test_pgvector_doc.py ''' -Set up a mock Postgres instance with the following commands -(make sure you don't have anything running on port 0.0.0.0:5432))): -docker pull ankane/pgvector -docker run --name mock-postgres -p 5432:5432 \ - -e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \ - -d ankane/pgvector - -Then run the tests with: -pytest test +After setup as described in the README.md for this directory, run the tests with: -or +pytest test -pytest test/embedding/test_pgvector.py +or, for just this test module: -Uses fixtures from ../conftest.py +pytest test/embedding/test_pgvector_doc.py +Uses fixtures from conftest.py in current & parent directories ''' import pytest diff --git a/test/embedding/test_pgvector_message.py b/test/embedding/test_pgvector_message.py index 766f3be..dfc4f71 100644 --- a/test/embedding/test_pgvector_message.py +++ b/test/embedding/test_pgvector_message.py @@ -2,7 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 # test/embedding/test_pgvector_message.py ''' -See test/embedding/test_pgvector.py for important notes on running these tests +After setup as described in the README.md for this directory, run the tests with: + +pytest test + +or, for just this test module: + +pytest test/embedding/test_pgvector_message.py + +Uses fixtures from conftest.py in current & parent directories ''' from datetime import datetime @@ -53,7 +61,30 @@ async def test_insert_message_vector(DB, MESSAGES): # Even though the embedding is mocked, the stored text should be faithful row = next(result) assert row.content == content - assert row.metadata == {'1': 'a'} + assert row.metadata == metadata + + +@pytest.mark.asyncio +async def test_insert_message_vector_windowed(DB_WINDOWED2, MESSAGES): + assert await DB_WINDOWED2.count_items() == 0, Exception('Starting with incorrect number of messages') + # Insert data + for index, (row) in enumerate(MESSAGES): + await DB_WINDOWED2.insert(*row) + + # There should be 2 left from each history key + assert await DB_WINDOWED2.count_items() == 4, Exception('Incorrect number of messages after insertion') + + # In the windowed case, the oldest 4 messages should have been deleted + history_key, role, content, timestamp, metadata = MESSAGES[5] + + # search table with perfect match + result = await DB_WINDOWED2.search(text=content, history_key=history_key, limit=2) + # assert result is not None, Exception('No results returned from perfect search') + + # Even though the embedding is mocked, the stored text should be faithful + row = next(result) + assert row.content == content + assert row.metadata == metadata @pytest.mark.asyncio @@ -72,7 +103,29 @@ async def test_insertmany_message_vector(DB, MESSAGES): # Even though the embedding is mocked, the stored text should be faithful row = next(result) assert row.content == content - assert row.metadata == {'1': 'a'} + assert row.metadata == metadata + + +@pytest.mark.asyncio +async def test_insertmany_message_vector_windowed(DB_WINDOWED2, MESSAGES): + assert await DB_WINDOWED2.count_items() == 0, Exception('Starting with incorrect number of messages') + # Insert data using insert_many() + await DB_WINDOWED2.insert_many(MESSAGES) + + # There should be 2 left from each history key + assert await DB_WINDOWED2.count_items() == 4, Exception('Incorrect number of messages after insertion') + + # In the windowed case, the oldest 4 messages should have been deleted + history_key, role, content, timestamp, metadata = MESSAGES[5] + + # search table with perfect match + result = await DB_WINDOWED2.search(text=content, history_key=history_key, limit=3) + # assert result is not None, Exception('No results returned from perfect search') + + # Even though the embedding is mocked, the stored text should be faithful + row = next(result) + assert row.content == content + assert row.metadata == metadata @pytest.mark.asyncio