Skip to content

Commit

Permalink
Implement pgvector_message.get_chatlog(), with support for date const…
Browse files Browse the repository at this point in the history
…rint & limit, while deprecating pgvector_message.get_table(). Tidy up embedding tests with more use of fixtures & teardown
  • Loading branch information
uogbuji committed Jan 15, 2024
1 parent ac984b4 commit 4eb7975
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 264 deletions.
6 changes: 3 additions & 3 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def count_items(self) -> int:
# Count the number of documents in the table
count = await conn.fetchval(f'SELECT COUNT(*) FROM {self.table_name}')
return count

async def table_exists(self) -> bool:
'''
Check if the table exists
Expand All @@ -197,11 +197,11 @@ async def table_exists(self) -> bool:
'''
# Check if the table exists
async with (await self.connection_pool()).acquire() as conn:
table_exists = await conn.fetchval(
exists = await conn.fetchval(
CHECK_TABLE_EXISTS,
self.table_name
)
return table_exists
return exists

async def drop_table(self) -> None:
'''
Expand Down
87 changes: 60 additions & 27 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import UUID
from datetime import datetime, timezone
from typing import Iterable
import warnings

from ogbujipt.config import attr_dict
from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response
Expand All @@ -21,7 +22,7 @@

CREATE_MESSAGE_TABLE = '''-- Create a table to hold individual messages (e.g. from a chatlog) and their metadata
CREATE TABLE IF NOT EXISTS {table_name} (
ts TIMESTAMP WITH TIME ZONE PRIMARY KEY, -- timestamp of the message
ts TIMESTAMP WITH TIME ZONE, -- timestamp of the message
history_key UUID, -- unique identifier for contextual message history
role TEXT, -- role of the message (meta ID such as 'system' or user,
-- or an ID associated with the sender)
Expand All @@ -39,11 +40,7 @@
embedding,
ts,
metadata
) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (ts) DO UPDATE SET -- Update the content, embedding, and metadata of the message if it already exists
content = EXCLUDED.content,
embedding = EXCLUDED.embedding,
metadata = EXCLUDED.metadata;
) VALUES ($1, $2, $3, $4, $5, $6);
'''

CLEAR_MESSAGE = '''-- Deletes matching messages
Expand All @@ -52,7 +49,7 @@
history_key = $1
'''

RETURN_MESSAGE_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
RETURN_MESSAGES_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
SELECT
ts,
role,
Expand All @@ -62,8 +59,10 @@
{table_name}
WHERE
history_key = $1
{since_clause}
ORDER BY
ts;
ts
{limit_clause};
'''

SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity
Expand Down Expand Up @@ -179,54 +178,88 @@ async def clear(
),
history_key
)

# XXX: Change to a generator
async def get_table(
async def get_chatlog(
        self,
        history_key: UUID | str,
        since: datetime | None = None,
        limit: int = 0
        ):
    '''
    Retrieve entries in a message history

    Args:
        history_key (UUID | str): history key (unique identifier) to match; string or UUID object

        since (datetime, optional): only return messages after this timestamp

        limit (int, optional): maximum number of messages to return. Default (0) is all messages

    Returns:
        generator of attr_dict instances (dict-like, with attribute-style access) for the resulting messages
    '''
    # Normalize & validate inputs up front, before touching the connection pool.
    # Converting BEFORE building qparams ensures the actual UUID (not the raw
    # string) is what gets bound as $1, and bad strings raise ValueError early.
    if not isinstance(history_key, UUID):
        history_key = UUID(history_key)
    if since is not None and not isinstance(since, datetime):
        msg = 'since must be a datetime or None'
        raise TypeError(msg)

    qparams = [history_key]  # $1 is always the history key
    # f'${len(qparams) + 1}' keeps each SQL placeholder number in sync with
    # its position in qparams as optional clauses are appended
    if since is not None:
        since_clause = f' AND ts > ${len(qparams) + 1}'
        qparams.append(since)
    else:
        since_clause = ''
    if limit:
        limit_clause = f'LIMIT ${len(qparams) + 1}'
        qparams.append(limit)
    else:
        limit_clause = ''

    async with (await self.connection_pool()).acquire() as conn:
        message_records = await conn.fetch(
            RETURN_MESSAGES_BY_HISTORY_KEY.format(
                table_name=self.table_name,
                since_clause=since_clause,
                limit_clause=limit_clause
            ),
            *qparams
        )

    # Return a generator so callers can stream messages rather than
    # materializing the whole chatlog at once
    return (attr_dict({
        'ts': record['ts'],
        'role': record['role'],
        'content': record['content'],
        'metadata': record['metadata']
    }) for record in message_records)

async def get_table(self, history_key):
    '''
    Deprecated legacy interface: retrieve the full chatlog for history_key
    as a fully materialized list. Use get_chatlog() instead.
    '''
    # stacklevel=2 makes the warning point at the caller, not at this shim
    warnings.warn('pgvector_message.get_table() is deprecated. Use get_chatlog().',
                  DeprecationWarning, stacklevel=2)
    return list(await self.get_chatlog(history_key))

async def search(
self,
history_key: UUID,
text: str,
since: datetime | None = None,
limit: int = 1
) -> list[asyncpg.Record]:
'''
Similarity search documents using a query string
Args:
history_key (str): history key for the conversation to query
text (str): string to compare against items in the table
k (int, optional): maximum number of results to return (useful for top-k query)
since (datetime, optional): only return results after this timestamp
limit (int, optional): maximum number of messages to return; for top-k type query. Default is 1
Returns:
list[asyncpg.Record]: list of search results
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
Expand Down
68 changes: 68 additions & 0 deletions test/embedding/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# test/embedding/conftest.py
'''
Fixtures/setup/teardown for embedding tests
'''

import sys
import os
import pytest
import pytest_asyncio
from unittest.mock import MagicMock, DEFAULT # noqa: F401

import numpy as np
from ogbujipt.embedding.pgvector import MessageDB, DataDB, DocDB


@pytest.fixture
def CORRECT_STRING():
Expand Down Expand Up @@ -67,3 +81,57 @@ def HITHERE_all_MiniLM_L12_v2():
-0.071223356, 0.0019683593, 0.032683503, -0.08899012, 0.10160039, 0.04948275, 0.048017487, -0.046223965,
0.032460734, -0.043729845, 0.030224336, -0.019220904, 0.08223829, 0.03851222, -0.016376046, 0.041965306,
0.0445879, -0.03780432, -0.024826797, 0.014669102, 0.057102628, -0.031820614, 0.0027352672, 0.052658144])

# XXX: This stanza to go away once mocking is complete - Kai
# Postgres connection parameters, taken from the environment with
# mock-friendly fallbacks so the suite can run without a real DB config
_environ = os.environ
HOST = _environ.get('PG_HOST', 'localhost')
DB_NAME = _environ.get('PG_DATABASE', 'mock_db')
USER = _environ.get('PG_USER', 'mock_user')
PASSWORD = _environ.get('PG_PASSWORD', 'mock_password')
PORT = _environ.get('PG_PORT', 5432)


# XXX: Move to a fixture?
# Definitely don't want to even import SentenceTransformer class due to massive side-effects
class SentenceTransformer:
    '''
    Minimal stand-in for sentence_transformers.SentenceTransformer,
    avoiding that package's heavy import-time side effects
    '''
    def __init__(self, model_name_or_path):
        # encode is a bare MagicMock; tests assign .return_value as needed
        self.encode = MagicMock()


# Map each embedding test module to the vector-DB helper class it exercises;
# the DB fixture keys into this by the running test's file path
DB_CLASS = {
    f'test/embedding/test_pgvector_{suffix}.py': cls
    for suffix, cls in (('message', MessageDB), ('data', DataDB), ('doc', DocDB))
}


@pytest_asyncio.fixture  # Notice the async aware fixture declaration
async def DB(request):
    '''
    Per-test vector DB helper fixture: connects (or skips if no Postgres is
    reachable), creates a fresh table named after the test, yields the helper,
    then drops the table again on teardown
    '''
    testname = request.node.name
    testfile = request.node.location[0]
    table_name = testname.lower()
    print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr)

    # Mocked embedding model: always "embeds" to the same tiny vector
    dummy_model = SentenceTransformer('mock_transformer')
    dummy_model.encode.return_value = np.array([1, 2, 3])

    db = None
    try:
        db = await DB_CLASS[testfile].from_conn_params(
            embedding_model=dummy_model,
            table_name=table_name,
            db_name=DB_NAME,
            host=HOST,
            port=int(PORT),
            user=USER,
            password=PASSWORD)
    except ConnectionRefusedError:
        pass  # Treated the same as a None helper: skip below
    if db is None:
        pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)

    # Start from a clean slate, then create the table under test
    await db.drop_table()
    assert not await db.table_exists(), Exception("Table exists before creation")
    await db.create_table()
    assert await db.table_exists(), Exception("Table does not exist after creation")

    # The test will take control upon the yield
    yield db

    # Teardown: Drop table
    await db.drop_table()
69 changes: 10 additions & 59 deletions test/embedding/test_pgvector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,8 @@
import pytest
from unittest.mock import MagicMock, DEFAULT # noqa: F401

import os
from ogbujipt.embedding.pgvector import DataDB
import numpy as np

# XXX: This stanza to go away once mocking is complete - Kai
HOST = os.environ.get('PG_HOST', 'localhost')
DB_NAME = os.environ.get('PG_DATABASE', 'mock_db')
USER = os.environ.get('PG_USER', 'mock_user')
PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password')
PORT = os.environ.get('PG_PORT', 5432)

KG_STATEMENTS = [ # Demo data
("👤 Alikiba `releases_single` 💿 'Yalaiti'", {'url': 'https://notjustok.com/lyrics/yalaiti-lyrics-by-alikiba-ft-sabah-salum/'}),
Expand All @@ -36,100 +28,59 @@ def __init__(self, model_name_or_path):


@pytest.mark.asyncio
async def test_insert_data_vector(DB):
    '''
    Insert demo KG statements one by one via DB.insert(), then verify the
    count and that a perfect-match search returns faithful content & tags
    (DB fixture supplies a table-per-test helper with a mocked embedding model)
    '''
    item1_text = KG_STATEMENTS[0][0]
    item1_meta = KG_STATEMENTS[0][1]

    # Insert data (index from enumerate was unused, so iterate directly)
    for text, meta in KG_STATEMENTS:
        await DB.insert(  # Insert the row into the table
            content=text,  # text to be embedded
            tags=[f'{k}={v}' for (k, v) in meta.items()],  # Tag metadata
        )

    assert await DB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion')

    # search table with perfect match
    result = await DB.search(text=item1_text, limit=3)

    # Even though the embedding is mocked, the stored text should be faithful
    row = next(result)
    assert row.content == item1_text
    assert row.tags == [f'{k}={v}' for (k, v) in item1_meta.items()]

    await DB.drop_table()


@pytest.mark.asyncio
async def test_insertmany_data_vector(DB):
    '''
    Insert demo KG statements in bulk via DB.insert_many() (fed lazily from a
    generator), then verify the count and that a perfect-match search returns
    faithful content & tags
    '''
    item1_text = KG_STATEMENTS[0][0]
    item1_meta = KG_STATEMENTS[0][1]

    # Insert data using insert_many(); generator keeps the dataset lazy
    dataset = ((text, [f'{k}={v}' for (k, v) in tags.items()]) for (text, tags) in KG_STATEMENTS)

    await DB.insert_many(dataset)

    assert await DB.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion')

    # search table with perfect match
    result = await DB.search(text=item1_text, limit=3)

    # Even though the embedding is mocked, the stored text should be faithful
    row = next(result)
    assert row.content == item1_text
    assert row.tags == [f'{k}={v}' for (k, v) in item1_meta.items()]

    await DB.drop_table()


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 4eb7975

Please sign in to comment.