Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More MessageDB improvements, including timestamp constraints in getting chatlog (messages since timestamp) #67

Merged
merged 6 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
## [Unreleased]
-->

## [0.7.1] - 20240122

### Added

- MessageDB.get_messages() options: `since` (for retrieving messages after a timestamp) and `limit` (for limiting the number of messages returned, selecting the most recent)

### Changed

- Better modularization of embeddings test cases; using `conftest.py` more
- `pgvector_message.py` PG table timestamp column no longer a primary key

## [0.7.0] - 20240110

### Added
Expand Down
6 changes: 3 additions & 3 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def count_items(self) -> int:
# Count the number of documents in the table
count = await conn.fetchval(f'SELECT COUNT(*) FROM {self.table_name}')
return count

async def table_exists(self) -> bool:
'''
Check if the table exists
Expand All @@ -197,11 +197,11 @@ async def table_exists(self) -> bool:
'''
# Check if the table exists
async with (await self.connection_pool()).acquire() as conn:
table_exists = await conn.fetchval(
exists = await conn.fetchval(
CHECK_TABLE_EXISTS,
self.table_name
)
return table_exists
return exists

async def drop_table(self) -> None:
'''
Expand Down
87 changes: 60 additions & 27 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import UUID
from datetime import datetime, timezone
from typing import Iterable
import warnings

from ogbujipt.config import attr_dict
from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response
Expand All @@ -21,7 +22,7 @@

CREATE_MESSAGE_TABLE = '''-- Create a table to hold individual messages (e.g. from a chatlog) and their metadata
CREATE TABLE IF NOT EXISTS {table_name} (
ts TIMESTAMP WITH TIME ZONE PRIMARY KEY, -- timestamp of the message
ts TIMESTAMP WITH TIME ZONE, -- timestamp of the message
history_key UUID, -- unique identifier for contextual message history
role TEXT, -- role of the message (meta ID such as 'system' or user,
-- or an ID associated with the sender)
Expand All @@ -39,11 +40,7 @@
embedding,
ts,
metadata
) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (ts) DO UPDATE SET -- Update the content, embedding, and metadata of the message if it already exists
content = EXCLUDED.content,
embedding = EXCLUDED.embedding,
metadata = EXCLUDED.metadata;
) VALUES ($1, $2, $3, $4, $5, $6);
'''

CLEAR_MESSAGE = '''-- Deletes matching messages
Expand All @@ -52,7 +49,7 @@
history_key = $1
'''

RETURN_MESSAGE_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
RETURN_MESSAGES_BY_HISTORY_KEY = '''-- Get entire chatlog by history key
SELECT
ts,
role,
Expand All @@ -62,8 +59,10 @@
{table_name}
WHERE
history_key = $1
{since_clause}
ORDER BY
ts;
ts DESC
{limit_clause};
'''

SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity
Expand Down Expand Up @@ -179,54 +178,88 @@ async def clear(
),
history_key
)

# XXX: Change to a generator
async def get_table(
async def get_messages(
self,
history_key: UUID
) -> list[asyncpg.Record]:
history_key: UUID | str,
since: datetime | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh, the since arg is cool

limit: int = 0
): # -> list[asyncpg.Record]:
'''
Retrieve all entries in a message history
Retrieve entries in a message history

Args:
history_key (str): history key (unique identifier) to match
history_key (str): history key (unique identifier) to match; string or object
since (datetime, optional): only return messages after this timestamp
limit (int, optional): maximum number of messages to return. Default is all messages
Returns:
list[asyncpg.Record]: list of message entries
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
generates asyncpg.Record instances of resulting messages
'''
qparams = [history_key]
async with (await self.connection_pool()).acquire() as conn:
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
# msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
# raise TypeError(msg)
if not isinstance(since, datetime) and since is not None:
msg = 'since must be a datetime or None'
raise TypeError(msg)
if since:
# Really don't need the ${len(qparams) + 1} trick here, but used for consistency
since_clause = f' AND ts > ${len(qparams) + 1}'
qparams.append(since)
else:
since_clause = ''
if limit:
limit_clause = f'LIMIT ${len(qparams) + 1}'
qparams.append(limit)
else:
limit_clause = ''
message_records = await conn.fetch(
RETURN_MESSAGE_BY_HISTORY_KEY.format(
RETURN_MESSAGES_BY_HISTORY_KEY.format(
table_name=self.table_name,
since_clause=since_clause,
limit_clause=limit_clause
),
history_key
*qparams
)

messages = [
attr_dict({
return (attr_dict({
'ts': record['ts'],
'role': record['role'],
'content': record['content'],
'metadata': record['metadata']
})
for record in message_records
]

return messages
}) for record in message_records)

# async for record in message_records:
# yield attr_dict({
# 'ts': record['ts'],
# 'role': record['role'],
# 'content': record['content'],
# 'metadata': record['metadata']
# })

async def get_table(self, history_key: UUID | str) -> list:
    '''
    Deprecated alias for get_messages(), kept for backward compatibility

    Eagerly materializes the generator returned by get_messages() into a
    list, matching the return shape of the old get_table() API

    Args:
        history_key (UUID|str): history key (unique identifier) to match
    Returns:
        list: message records for the given history key
    '''
    # DeprecationWarning steers callers toward get_messages()
    warnings.warn('pgvector_message.get_table() is deprecated. Use get_messages().', DeprecationWarning)
    return list(await self.get_messages(history_key))

async def search(
self,
history_key: UUID,
text: str,
since: datetime | None = None,
limit: int = 1
) -> list[asyncpg.Record]:
'''
Similarity search documents using a query string

Args:
history_key (str): history key for the conversation to query
text (str): string to compare against items in the table

k (int, optional): maximum number of results to return (useful for top-k query)
since (datetime, optional): only return results after this timestamp
limit (int, optional): maximum number of messages to return; for top-k type query. Default is 1
Returns:
list[asyncpg.Record]: list of search results
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
Expand Down
68 changes: 68 additions & 0 deletions test/embedding/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# test/embedding/conftest.py
'''
Fixtures/setup/teardown for embedding tests
'''

import sys
import os
import pytest
import pytest_asyncio
from unittest.mock import MagicMock, DEFAULT # noqa: F401

import numpy as np
from ogbujipt.embedding.pgvector import MessageDB, DataDB, DocDB


@pytest.fixture
def CORRECT_STRING():
Expand Down Expand Up @@ -67,3 +81,57 @@ def HITHERE_all_MiniLM_L12_v2():
-0.071223356, 0.0019683593, 0.032683503, -0.08899012, 0.10160039, 0.04948275, 0.048017487, -0.046223965,
0.032460734, -0.043729845, 0.030224336, -0.019220904, 0.08223829, 0.03851222, -0.016376046, 0.041965306,
0.0445879, -0.03780432, -0.024826797, 0.014669102, 0.057102628, -0.031820614, 0.0027352672, 0.052658144])

# XXX: This stanza to go away once mocking is complete - Kai
# Postgres connection parameters, overridable via environment variables;
# the mock_* defaults allow test collection to proceed with no DB configured
HOST = os.environ.get('PG_HOST', 'localhost')
DB_NAME = os.environ.get('PG_DATABASE', 'mock_db')
USER = os.environ.get('PG_USER', 'mock_user')
PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password')
# May be a str (from the environment) or int (default 5432); cast at point of use
PORT = os.environ.get('PG_PORT', 5432)


# XXX: Move to a fixture?
# Definitely don't want to even import SentenceTransformer class due to massive side-effects
class SentenceTransformer:
    '''
    Lightweight test double for sentence_transformers.SentenceTransformer

    Avoids importing the real class, which has heavy model-loading side
    effects. Each instance carries its own MagicMock encode(), so individual
    tests can program return values independently.
    '''
    def __init__(self, model_name_or_path):
        # Model path accepted only for signature compatibility; it is ignored
        self.encode = MagicMock()


# Maps each test module path to the PGVector helper class it exercises;
# consumed by the DB fixture to construct the right helper per test
DB_CLASS = {
    'test/embedding/test_pgvector_message.py': MessageDB,
    'test/embedding/test_pgvector_data.py': DataDB,
    'test/embedding/test_pgvector_doc.py': DocDB,
}


@pytest_asyncio.fixture  # Notice the async aware fixture declaration
async def DB(request):
    '''
    Per-test Postgres vector DB fixture

    Selects the helper class (MessageDB/DataDB/DocDB) according to which test
    module is running, connects using the PG_* environment variables, and
    yields a freshly created, empty table helper. Skips the test when no
    Postgres instance is reachable. Drops the table on teardown.
    '''
    testname = request.node.name
    testfile = request.node.location[0]
    # One table per test keeps tests isolated from each other
    table_name = testname.lower()
    print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr)
    dummy_model = SentenceTransformer('mock_transformer')
    dummy_model.encode.return_value = np.array([1, 2, 3])
    try:
        vDB = await DB_CLASS[testfile].from_conn_params(
            embedding_model=dummy_model,
            table_name=table_name,
            db_name=DB_NAME,
            host=HOST,
            port=int(PORT),
            user=USER,
            password=PASSWORD)
    except ConnectionRefusedError:
        pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
    if vDB is None:
        pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)

    # Create table from a clean slate
    await vDB.drop_table()
    # Plain string messages: passing Exception(...) as the assert message only
    # embeds the exception's repr in the failure output, never raises it
    assert not await vDB.table_exists(), "Table exists before creation"
    await vDB.create_table()
    assert await vDB.table_exists(), "Table does not exist after creation"
    # The test will take control upon the yield
    yield vDB
    # Teardown: Drop table
    await vDB.drop_table()
Loading