Skip to content

Commit

Permalink
Testing doc tweaks. Implement windowing size param for for messages_d…
Browse files Browse the repository at this point in the history
…b, to limit message storage per history key.
  • Loading branch information
uogbuji committed Mar 22, 2024
1 parent 836e6e4 commit e0bb5e6
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 21 deletions.
2 changes: 1 addition & 1 deletion demo/PGvector_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.6"
},
"orig_nbformat": 4
},
Expand Down
71 changes: 69 additions & 2 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,55 @@
cosine_similarity DESC
LIMIT $3;
'''
# ------ SQL queries ---------------------------------------------------------------------------------------------------

DELETE_OLDEST_MESSAGES = '''-- Delete oldest messages for given history key, such that only the newest N messages remain
DELETE FROM {table_name} t_outer
WHERE
t_outer.history_key = $1
AND
t_outer.ctid NOT IN (
SELECT t_inner.ctid
FROM {table_name} t_inner
WHERE
t_inner.history_key = $1
ORDER BY
t_inner.ts DESC
LIMIT $2
);
'''

# TODO: Delete once fully comfortable with the windowed implementation
# TEMPQ = '''
# SELECT t_inner.ctid
# FROM {table_name} t_inner
# WHERE
# t_inner.history_key = $1
# ORDER BY
# t_inner.ts DESC
# LIMIT $2;
# '''

# ------ Class implementations ---------------------------------------------------------------------------------------

class MessageDB(PGVectorHelper):
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window: int = 0):
    '''
    Helper class for messages/chatlog storage and retrieval

    Args:
        embedding_model (SentenceTransformer): SentenceTransformer object of your choice
            https://huggingface.co/sentence-transformers

        table_name (str): name of the database table in which to store messages

        pool (asyncpg.pool.Pool): asyncpg connection pool to use for DB access

        window (int, optional): number of messages to maintain in the DB per
            history key. Default is 0 (keep all messages)
    '''
    super().__init__(embedding_model, table_name, pool)
    # 0 disables windowing; when set, insert paths trim older rows per history key
    self.window = window

@classmethod
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0) -> 'MessageDB':  # noqa: E501
    '''
    Alternate constructor: create a MessageDB from Postgres connection parameters.

    Args:
        window (int, optional): number of messages to maintain in the DB per
            history key. Default is 0 (keep all messages). Remaining args are
            as in the parent class's from_conn_params.

    Returns:
        MessageDB: initialized instance with the given window setting
    '''
    obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password)
    # Parent classmethod doesn't accept window, so set it after construction
    obj.window = window
    return obj

''' Specialize PGvectorHelper for messages, e.g. chatlogs '''
async def create_table(self):
async with self.pool.acquire() as conn:
Expand Down Expand Up @@ -135,7 +181,17 @@ async def insert(
content_embedding.tolist(),
timestamp,
metadata
)
)
# print(f'{self.window=}, Pre-count: {await self.count_items()}')
async with self.pool.acquire() as conn:
async with conn.transaction():
if self.window:
await conn.execute(
DELETE_OLDEST_MESSAGES.format(table_name=self.table_name),
history_key,
self.window)
# async with self.pool.acquire() as conn:
# print(f'{self.window=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), history_key, self.window))}') # noqa E501

async def insert_many(
self,
Expand All @@ -158,6 +214,17 @@ async def insert_many(
for hk, role, text, ts, metadata in content_list
)
)
# print(f'{self.window=}, Pre-count: {await self.count_items()}') # noqa E501
async with self.pool.acquire() as conn:
async with conn.transaction():
if self.window:
# Set uniquifies the history keys
for hk in {hk for hk, _, _, _, _ in content_list}:
await conn.execute(
DELETE_OLDEST_MESSAGES.format(table_name=self.table_name),
hk, self.window)
# async with self.pool.acquire() as conn:
# print(f'{self.window=}, {hk=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), hk, self.window))}') # noqa E501

async def clear(
self,
Expand Down
9 changes: 9 additions & 0 deletions test/embedding/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
To run these tests, first set up a mock Postgres instance with the following commands
(make sure nothing else is already listening on 0.0.0.0:5432):

```sh
docker pull ankane/pgvector
docker run --name mock-postgres -p 5432:5432 \
-e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
-d ankane/pgvector
```
47 changes: 46 additions & 1 deletion test/embedding/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
# test/embedding/conftest.py
'''
Fixtures/setup/teardown for embedding tests
General note: After setup as described in the README.md for this directory, run the tests with:
pytest test
or, for just embeddings tests:
pytest test/embedding/
'''

import sys
Expand Down Expand Up @@ -103,6 +111,7 @@ def __init__(self, model_name_or_path):
'test/embedding/test_pgvector_doc.py': DocDB,
}

# print(HOST, DB_NAME, USER, PASSWORD, PORT)

@pytest_asyncio.fixture # Notice the async aware fixture declaration
async def DB(request):
Expand All @@ -123,8 +132,44 @@ async def DB(request):
password=PASSWORD)
except ConnectionRefusedError:
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
if vDB is None:
# Actually we want to propagate the error condition, in this case
# if vDB is None:
# pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True)

# Create table
await vDB.drop_table()
assert not await vDB.table_exists(), Exception("Table exists before creation")
await vDB.create_table()
assert await vDB.table_exists(), Exception("Table does not exist after creation")
# The test will take control upon the yield
yield vDB
# Teardown: Drop table
await vDB.drop_table()


# FIXME: Lots of DRY violations
@pytest_asyncio.fixture # Notice the async aware fixture declaration
async def DB_WINDOWED2(request):
testname = request.node.name
table_name = testname.lower()
print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr)
dummy_model = SentenceTransformer('mock_transformer')
dummy_model.encode.return_value = np.array([1, 2, 3])
try:
vDB = await MessageDB.from_conn_params(
embedding_model=dummy_model,
table_name=table_name,
db_name=DB_NAME,
host=HOST,
port=int(PORT),
user=USER,
password=PASSWORD,
window=2)
except ConnectionRefusedError:
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
# Actually we want to propagate the error condition, in this case
# if vDB is None:
# pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True)

# Create table
await vDB.drop_table()
Expand Down
10 changes: 9 additions & 1 deletion test/embedding/test_pgvector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
# test/embedding/test_pgvector_data.py
'''
See test/embedding/test_pgvector.py for important notes on running these tests
After setup as described in the README.md for this directory, run the tests with:
pytest test
or, for just this test module:
pytest test/embedding/test_pgvector_data.py
Uses fixtures from conftest.py in current & parent directories
'''

import pytest
Expand Down
19 changes: 6 additions & 13 deletions test/embedding/test_pgvector_doc.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# test/embedding/test_pgvector.py
# test/embedding/test_pgvector_doc.py
'''
Set up a mock Postgres instance with the following commands
(make sure you don't have anything running on port 0.0.0.0:5432))):
docker pull ankane/pgvector
docker run --name mock-postgres -p 5432:5432 \
-e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
-d ankane/pgvector
Then run the tests with:
pytest test
After setup as described in the README.md for this directory, run the tests with:
or
pytest test
pytest test/embedding/test_pgvector.py
or, for just this test module:
Uses fixtures from ../conftest.py
pytest test/embedding/test_pgvector_doc.py
Uses fixtures from conftest.py in current & parent directories
'''

import pytest
Expand Down
59 changes: 56 additions & 3 deletions test/embedding/test_pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
# test/embedding/test_pgvector_message.py
'''
See test/embedding/test_pgvector.py for important notes on running these tests
After setup as described in the README.md for this directory, run the tests with:
pytest test
or, for just this test module:
pytest test/embedding/test_pgvector_message.py
Uses fixtures from conftest.py in current & parent directories
'''

from datetime import datetime
Expand Down Expand Up @@ -53,7 +61,30 @@ async def test_insert_message_vector(DB, MESSAGES):
# Even though the embedding is mocked, the stored text should be faithful
row = next(result)
assert row.content == content
assert row.metadata == {'1': 'a'}
assert row.metadata == metadata


@pytest.mark.asyncio
async def test_insert_message_vector_windowed(DB_WINDOWED2, MESSAGES):
    '''
    Insert messages one at a time into a MessageDB with window=2, then verify
    that only the newest 2 messages per history key are retained, and that a
    retained message's content & metadata round-trip faithfully.
    '''
    assert await DB_WINDOWED2.count_items() == 0, Exception('Starting with incorrect number of messages')
    # Insert data one row at a time, so the window trim runs on each insert
    for row in MESSAGES:
        await DB_WINDOWED2.insert(*row)

    # There should be 2 left from each history key
    assert await DB_WINDOWED2.count_items() == 4, Exception('Incorrect number of messages after insertion')

    # In the windowed case, the oldest 4 messages should have been deleted;
    # MESSAGES[5] is presumably among the newest 2 for its history key — confirm against fixture
    history_key, role, content, timestamp, metadata = MESSAGES[5]

    # Search table with perfect match
    result = await DB_WINDOWED2.search(text=content, history_key=history_key, limit=2)

    # Even though the embedding is mocked, the stored text should be faithful
    row = next(result)
    assert row.content == content
    assert row.metadata == metadata


@pytest.mark.asyncio
Expand All @@ -72,7 +103,29 @@ async def test_insertmany_message_vector(DB, MESSAGES):
# Even though the embedding is mocked, the stored text should be faithful
row = next(result)
assert row.content == content
assert row.metadata == {'1': 'a'}
assert row.metadata == metadata


@pytest.mark.asyncio
async def test_insertmany_message_vector_windowed(DB_WINDOWED2, MESSAGES):
    '''
    Bulk-insert messages via insert_many() into a MessageDB with window=2, then
    verify that only the newest 2 messages per history key are retained, and
    that a retained message's content & metadata round-trip faithfully.
    '''
    assert await DB_WINDOWED2.count_items() == 0, Exception('Starting with incorrect number of messages')
    # Insert data using insert_many()
    await DB_WINDOWED2.insert_many(MESSAGES)

    # There should be 2 left from each history key
    assert await DB_WINDOWED2.count_items() == 4, Exception('Incorrect number of messages after insertion')

    # In the windowed case, the oldest 4 messages should have been deleted;
    # MESSAGES[5] is presumably among the newest 2 for its history key — confirm against fixture
    history_key, role, content, timestamp, metadata = MESSAGES[5]

    # Search table with perfect match
    result = await DB_WINDOWED2.search(text=content, history_key=history_key, limit=3)

    # Even though the embedding is mocked, the stored text should be faithful
    row = next(result)
    assert row.content == content
    assert row.metadata == metadata


@pytest.mark.asyncio
Expand Down

0 comments on commit e0bb5e6

Please sign in to comment.