Skip to content

Commit

Permalink
Merge pull request #67 from OoriData/messages-since
Browse files Browse the repository at this point in the history
More MessageDB improvements, including timestamp constraints in getting chatlog (messages since timestamp)
  • Loading branch information
uogbuji authored Jan 21, 2024
2 parents ac984b4 + 58549b4 commit c25f6d6
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 266 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
## [Unreleased]
-->

## [0.7.1] - 20240122

### Added

- MessageDB.get_messages() options: `since` (for retrieving messages after a timestamp) and `limit` (for limiting the number of messages returned, selecting the most recent)

### Changed

- Better modularization of embeddings test cases; using `conftest.py` more
- `pgvector_message.py` PG table timestamp column no longer a primary key

## [0.7.0] - 20240110

### Added
Expand Down
6 changes: 3 additions & 3 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def count_items(self) -> int:
# Count the number of documents in the table
count = await conn.fetchval(f'SELECT COUNT(*) FROM {self.table_name}')
return count

async def table_exists(self) -> bool:
'''
Check if the table exists
Expand All @@ -197,11 +197,11 @@ async def table_exists(self) -> bool:
'''
# Check if the table exists
async with (await self.connection_pool()).acquire() as conn:
table_exists = await conn.fetchval(
exists = await conn.fetchval(
CHECK_TABLE_EXISTS,
self.table_name
)
return table_exists
return exists

async def drop_table(self) -> None:
'''
Expand Down
87 changes: 60 additions & 27 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import UUID
from datetime import datetime, timezone
from typing import Iterable
import warnings

from ogbujipt.config import attr_dict
from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response
Expand All @@ -21,7 +22,7 @@

CREATE_MESSAGE_TABLE = '''-- Create a table to hold individual messages (e.g. from a chatlog) and their metadata
CREATE TABLE IF NOT EXISTS {table_name} (
ts TIMESTAMP WITH TIME ZONE PRIMARY KEY, -- timestamp of the message
ts TIMESTAMP WITH TIME ZONE, -- timestamp of the message
history_key UUID, -- unique identifier for contextual message history
role TEXT, -- role of the message (meta ID such as 'system' or user,
-- or an ID associated with the sender)
Expand All @@ -39,11 +40,7 @@
embedding,
ts,
metadata
) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (ts) DO UPDATE SET -- Update the content, embedding, and metadata of the message if it already exists
content = EXCLUDED.content,
embedding = EXCLUDED.embedding,
metadata = EXCLUDED.metadata;
) VALUES ($1, $2, $3, $4, $5, $6);
'''

CLEAR_MESSAGE = '''-- Deletes matching messages
Expand All @@ -52,7 +49,7 @@
history_key = $1
'''

RETURN_MESSAGE_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
RETURN_MESSAGES_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
SELECT
ts,
role,
Expand All @@ -62,8 +59,10 @@
{table_name}
WHERE
history_key = $1
{since_clause}
ORDER BY
ts;
ts DESC
{limit_clause};
'''

SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity
Expand Down Expand Up @@ -179,54 +178,88 @@ async def clear(
),
history_key
)

# XXX: Change to a generator
async def get_table(
async def get_messages(
self,
history_key: UUID
) -> list[asyncpg.Record]:
history_key: UUID | str,
since: datetime | None = None,
limit: int = 0
): # -> list[asyncpg.Record]:
'''
Retrieve all entries in a message history
Retrieve entries in a message history
Args:
history_key (str): history key (unique identifier) to match
history_key (str): history key (unique identifier) to match; string or object
since (datetime, optional): only return messages after this timestamp
limit (int, optional): maximum number of messages to return. Default is all messages
Returns:
list[asyncpg.Record]: list of message entries
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
generates asyncpg.Record instances of resulting messages
'''
qparams = [history_key]
async with (await self.connection_pool()).acquire() as conn:
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
# msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
# raise TypeError(msg)
if not isinstance(since, datetime) and since is not None:
msg = 'since must be a datetime or None'
raise TypeError(msg)
if since:
# Really don't need the ${len(qparams) + 1} trick here, but used for consistency
since_clause = f' AND ts > ${len(qparams) + 1}'
qparams.append(since)
else:
since_clause = ''
if limit:
limit_clause = f'LIMIT ${len(qparams) + 1}'
qparams.append(limit)
else:
limit_clause = ''
message_records = await conn.fetch(
RETURN_MESSAGE_BY_HISTORY_KEY.format(
RETURN_MESSAGES_BY_HISTORY_KEY.format(
table_name=self.table_name,
since_clause=since_clause,
limit_clause=limit_clause
),
history_key
*qparams
)

messages = [
attr_dict({
return (attr_dict({
'ts': record['ts'],
'role': record['role'],
'content': record['content'],
'metadata': record['metadata']
})
for record in message_records
]

return messages
}) for record in message_records)

# async for record in message_records:
# yield attr_dict({
# 'ts': record['ts'],
# 'role': record['role'],
# 'content': record['content'],
# 'metadata': record['metadata']
# })

async def get_table(self, history_key):
    '''
    Deprecated alias for get_messages()

    Retained for backward compatibility; unlike get_messages(), which
    produces a generator, this returns a fully materialized list

    Args:
        history_key: history key (unique identifier) to match

    Returns:
        list of message entries
    '''
    warnings.warn('pgvector_message.get_table() is deprecated. Use get_messages().', DeprecationWarning)
    messages = await self.get_messages(history_key)
    return list(messages)

async def search(
self,
history_key: UUID,
text: str,
since: datetime | None = None,
limit: int = 1
) -> list[asyncpg.Record]:
'''
Similarity search documents using a query string
Args:
history_key (str): history key for the conversation to query
text (str): string to compare against items in the table
k (int, optional): maximum number of results to return (useful for top-k query)
since (datetime, optional): only return results after this timestamp
limit (int, optional): maximum number of messages to return; for top-k type query. Default is 1
Returns:
list[asyncpg.Record]: list of search results
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
Expand Down
68 changes: 68 additions & 0 deletions test/embedding/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# test/embedding/conftest.py
'''
Fixtures/setup/teardown for embedding tests
'''

import sys
import os
import pytest
import pytest_asyncio
from unittest.mock import MagicMock, DEFAULT # noqa: F401

import numpy as np
from ogbujipt.embedding.pgvector import MessageDB, DataDB, DocDB


@pytest.fixture
def CORRECT_STRING():
Expand Down Expand Up @@ -67,3 +81,57 @@ def HITHERE_all_MiniLM_L12_v2():
-0.071223356, 0.0019683593, 0.032683503, -0.08899012, 0.10160039, 0.04948275, 0.048017487, -0.046223965,
0.032460734, -0.043729845, 0.030224336, -0.019220904, 0.08223829, 0.03851222, -0.016376046, 0.041965306,
0.0445879, -0.03780432, -0.024826797, 0.014669102, 0.057102628, -0.031820614, 0.0027352672, 0.052658144])

# XXX: This stanza to go away once mocking is complete - Kai
# Postgres connection parameters for the live-DB test fixture, read from the
# environment; the defaults apply when the corresponding env var is unset
HOST = os.environ.get('PG_HOST', 'localhost')
DB_NAME = os.environ.get('PG_DATABASE', 'mock_db')
USER = os.environ.get('PG_USER', 'mock_user')
PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password')
# NOTE(review): default is an int but an env-supplied value is a str;
# the DB fixture normalizes with int(PORT) before connecting
PORT = os.environ.get('PG_PORT', 5432)


# XXX: Move to a fixture?
# Definitely don't want to even import SentenceTransformer class due to massive side-effects
class SentenceTransformer:
    '''
    Lightweight stand-in for the real SentenceTransformer class, whose import
    carries heavy side effects we want to keep out of the test process.
    Exposes only an `encode` attribute, as a MagicMock, so tests can set
    `encode.return_value` to a canned embedding.
    '''
    def __init__(self, model_name_or_path):
        # Accept the model name for signature compatibility; it is ignored
        self.encode = MagicMock()


# Maps a test module path (as reported by request.node.location[0]) to the
# DB helper class its tests exercise; the DB fixture uses this to construct
# the right helper for whichever test file requested it
DB_CLASS = {
    'test/embedding/test_pgvector_message.py': MessageDB,
    'test/embedding/test_pgvector_data.py': DataDB,
    'test/embedding/test_pgvector_doc.py': DocDB,
}


@pytest_asyncio.fixture  # Notice the async aware fixture declaration
async def DB(request):
    '''
    Async fixture yielding a freshly created, empty DB helper table for the
    requesting test, then dropping the table on teardown.

    The helper class is chosen from DB_CLASS by the test's file path, and the
    table is named after the test itself so concurrent tests don't collide.
    Skips the test when no Postgres instance is reachable.
    '''
    testname = request.node.name
    testfile = request.node.location[0]
    # Table per test, named after the test, so runs are isolated
    table_name = testname.lower()
    print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr)
    # Mocked embedding model: every encode() call returns the same vector
    dummy_model = SentenceTransformer('mock_transformer')
    dummy_model.encode.return_value = np.array([1, 2, 3])
    try:
        vDB = await DB_CLASS[testfile].from_conn_params(
            embedding_model=dummy_model,
            table_name=table_name,
            db_name=DB_NAME,
            host=HOST,
            port=int(PORT),
            user=USER,
            password=PASSWORD)
    except ConnectionRefusedError:
        # No server listening at all — skip rather than fail the suite
        pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
    if vDB is None:
        # Connection attempt yielded no helper — treat the same as a refusal
        pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)

    # Create table
    # Drop first in case a previous run left the table behind
    await vDB.drop_table()
    assert not await vDB.table_exists(), Exception("Table exists before creation")
    await vDB.create_table()
    assert await vDB.table_exists(), Exception("Table does not exist after creation")
    # The test will take control upon the yield
    yield vDB
    # Teardown: Drop table
    await vDB.drop_table()
Loading

0 comments on commit c25f6d6

Please sign in to comment.