Sliding MessageDB window #76

Merged: 2 commits, Mar 26, 2024
32 changes: 17 additions & 15 deletions CHANGELOG.md
@@ -1,21 +1,23 @@
# Changelog

Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Project follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

<!--
## [Unreleased]
-->

## [0.8.0] - 20240312
## [0.8.0] - 20240325

### Added

- Implemented `ogbujipt.llm_wrapper.llama_cpp_http_chat` & `ogbujipt.llm_wrapper.llama_cpp_http`; llama.cpp low-level HTTP API support
- Implemented flexible `ogbujipt.llm_wrapper.llama_response` class
- `llm_wrapper.llama_cpp_http_chat` & `llm_wrapper.llama_cpp_http`; llama.cpp low-level HTTP API support
- `llm_wrapper.llama_response` class with flexible handling across API specs
- `window` init param for for `embedding.pgvector.MessageDB`, to limit message storage per history key

### Changed

- Deprecated `first_choice_text` & `first_choice_message` methods in favor of `first_choice_text` attributes on response objects
- Clarify test suite setup docs

## [0.7.1] - 20240229

@@ -31,7 +33,7 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha

### Fixed

- Backward threshold check for ogbujipt.embedding.pgvector_data_doc.DataDB
- Backward threshold check for embedding.pgvector_data_doc.DataDB

## [0.7.0] - 20240110

@@ -40,7 +42,7 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
- Command line options for `demo/chat_web_selects.py`
- Helper for folks installing on Apple Silicon: `constraints-apple-silicon.txt`
- Function calling demo
- `ogbujipt.embedding.pgvector_message.insert_many()`
- `embedding.pgvector_message.insert_many()`

### Changed

@@ -52,8 +54,8 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
- Separated data-style PGVector DBs from doc-style. tags is no longer the final param for PGVector docs helper methods & some params renamed.
- PGVector helper method results now as `attr_dict`
- PGVector helper now uses connection pooling & is more multiprocess safe
- `ogbujipt.embedding.pgvector_chat` renamed to `ogbujipt.embedding.pgvector_message`
- DB MIGRATION REQUIRED - `ogbujipt.embedding.pgvector_message` table schema
- `embedding.pgvector_chat` renamed to `embedding.pgvector_message`
- DB MIGRATION REQUIRED - `embedding.pgvector_message` table schema

### Fixed

@@ -142,9 +144,9 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha

### Added

- `ogbujipt.__version__`
- `__version__`
- chat_web_selects.py demo
- `ogbujipt.async_helper.save_openai_api_params()`
- `async_helper.save_openai_api_params()`

### Fixed

@@ -155,17 +157,17 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha

- Renamed demo alpaca_simple_fix_xml.py → simple_fix_xml.py
- Renamed demo alpaca_multitask_fix_xml.py → multiprocess.py
- Renamed `ogbujipt.oapi_choice1_text()``ogbujipt.oapi_first_choice_text()`
- Renamed `ogbujipt.async_helper.schedule_llm_call()``ogbujipt.async_helper.schedule_callable()`
- Renamed `oapi_choice1_text()``oapi_first_choice_text()`
- Renamed `async_helper.schedule_llm_call()``async_helper.schedule_callable()`

## [0.1.1] - 20230711

### Added

- GitHub CI workflow
- Orca model style
- Convenience function ogbujipt.oapi_choice1_text()
- Additional conveniences in ogbujipt.prompting.model_style
- Convenience function oapi_choice1_text()
- Additional conveniences in prompting.model_style

### Fixed

@@ -174,7 +176,7 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
### Changed

- Qdrant embeddings interface
- Renamed ogbujipt.prompting.context_build() → ogbujipt.prompting.format()
- Renamed prompting.context_build() → prompting.format()

## [0.1.0]

2 changes: 1 addition & 1 deletion demo/PGvector_demo.ipynb
@@ -999,7 +999,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.6"
},
"orig_nbformat": 4
},
71 changes: 69 additions & 2 deletions pylib/embedding/pgvector_message.py
@@ -79,9 +79,55 @@
cosine_similarity DESC
LIMIT $3;
'''
# ------ SQL queries ---------------------------------------------------------------------------------------------------

DELETE_OLDEST_MESSAGES = '''-- Delete oldest messages for given history key, such that only the newest N messages remain
Review comment (Collaborator): we should prob take a pass at moving all sql out of the .pys at some point

Reply (Collaborator, Author): Agreed. Maybe make it a goal for 0.9.0.

DELETE FROM {table_name} t_outer
WHERE
t_outer.history_key = $1
AND
t_outer.ctid NOT IN (
SELECT t_inner.ctid
FROM {table_name} t_inner
WHERE
t_inner.history_key = $1
ORDER BY
t_inner.ts DESC
LIMIT $2
);
'''

# Delete after full comfort with windowed implementation
# TEMPQ = '''
# SELECT t_inner.ctid
# FROM {table_name} t_inner
# WHERE
# t_inner.history_key = $1
# ORDER BY
# t_inner.ts DESC
# LIMIT $2;
# '''

# ------ Class implementations ---------------------------------------------------------------------------------------

class MessageDB(PGVectorHelper):
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0):
'''
Helper class for messages/chatlog storage and retrieval

Args:
embedding (SentenceTransformer): SentenceTransformer object of your choice
https://huggingface.co/sentence-transformers
window (int, optional): number of messages to maintain in the DB. Default is 0 (all messages)
'''
super().__init__(embedding_model, table_name, pool)
self.window = window

@classmethod
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0) -> 'MessageDB': # noqa: E501
obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password)
obj.window = window
return obj

''' Specialize PGvectorHelper for messages, e.g. chatlogs '''
async def create_table(self):
async with self.pool.acquire() as conn:
@@ -135,7 +181,17 @@ async def insert(
content_embedding.tolist(),
timestamp,
metadata
)
)
# print(f'{self.window=}, Pre-count: {await self.count_items()}')
async with self.pool.acquire() as conn:
async with conn.transaction():
if self.window:
await conn.execute(
DELETE_OLDEST_MESSAGES.format(table_name=self.table_name),
history_key,
self.window)
# async with self.pool.acquire() as conn:
# print(f'{self.window=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), history_key, self.window))}') # noqa E501

async def insert_many(
self,
@@ -158,6 +214,17 @@ async def insert_many(
for hk, role, text, ts, metadata in content_list
)
)
# print(f'{self.window=}, Pre-count: {await self.count_items()}') # noqa E501
async with self.pool.acquire() as conn:
async with conn.transaction():
if self.window:
# Set uniquifies the history keys
for hk in {hk for hk, _, _, _, _ in content_list}:
await conn.execute(
DELETE_OLDEST_MESSAGES.format(table_name=self.table_name),
hk, self.window)
# async with self.pool.acquire() as conn:
# print(f'{self.window=}, {hk=}, Post-count: {await self.count_items()}, {list(await conn.fetch(TEMPQ.format(table_name=self.table_name), hk, self.window))}') # noqa E501

async def clear(
self,
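For orientation, a minimal usage sketch of the sliding window added in this file, not taken from the diff: the import path, embedding model name, and connection parameters are assumptions, and `insert(history_key, role, content)` follows the signature implied by the `executemany` call above.

```python
# Hypothetical sketch, not part of this PR: exercise MessageDB with window=2.
# Import path, model name, and credentials are assumptions for illustration.
import asyncio
from uuid import uuid4

from sentence_transformers import SentenceTransformer
from ogbujipt.embedding.pgvector_message import MessageDB


async def main():
    emb_model = SentenceTransformer('all-MiniLM-L6-v2')
    db = await MessageDB.from_conn_params(
        embedding_model=emb_model, table_name='chatlog',
        host='localhost', port=5432, db_name='mock_db',
        user='mock_user', password='mock_password',
        window=2)  # keep only the 2 newest messages per history key
    await db.create_table()
    hk = uuid4()
    for text in ('first', 'second', 'third'):
        await db.insert(history_key=hk, role='user', content=text)
    # With window=2, the trailing DELETE_OLDEST_MESSAGES pass trims 'first';
    # only 'second' and 'third' remain for this history key.

asyncio.run(main())
```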
9 changes: 9 additions & 0 deletions test/embedding/README.md
@@ -0,0 +1,9 @@
To run these tests, first set up a mock Postgres instance with the following commands
(make sure you don't have anything running on port 0.0.0.0:5432):

```sh
docker pull ankane/pgvector
docker run --name mock-postgres -p 5432:5432 \
-e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
-d ankane/pgvector
```
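
As a quick sanity check that the container is reachable before running the tests, here is a small sketch using asyncpg (the driver the pgvector helpers already depend on); the credentials simply mirror the `docker run` command above.

```python
# Sketch: confirm the mock Postgres container above accepts connections.
# Credentials match the docker run command; adjust if you changed them.
import asyncio

import asyncpg


async def check():
    conn = await asyncpg.connect(
        host='localhost', port=5432,
        user='mock_user', password='mock_password', database='mock_db')
    print(await conn.fetchval('SELECT 1'))  # prints 1 if the DB is up
    await conn.close()

asyncio.run(check())
```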
47 changes: 46 additions & 1 deletion test/embedding/conftest.py
@@ -3,6 +3,14 @@
# test/embedding/conftest.py
'''
Fixtures/setup/teardown for embedding tests

General note: After setup as described in the README.md for this directory, run the tests with:

pytest test

or, for just embeddings tests:

pytest test/embedding/
'''

import sys
@@ -103,6 +111,7 @@ def __init__(self, model_name_or_path):
'test/embedding/test_pgvector_doc.py': DocDB,
}

# print(HOST, DB_NAME, USER, PASSWORD, PORT)

@pytest_asyncio.fixture # Notice the async aware fixture declaration
async def DB(request):
@@ -123,8 +132,44 @@ async def DB(request):
password=PASSWORD)
except ConnectionRefusedError:
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
if vDB is None:
# Actually we want to propagate the error condition, in this case
# if vDB is None:
# pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True)

# Create table
await vDB.drop_table()
assert not await vDB.table_exists(), Exception("Table exists before creation")
await vDB.create_table()
assert await vDB.table_exists(), Exception("Table does not exist after creation")
# The test will take control upon the yield
yield vDB
# Teardown: Drop table
await vDB.drop_table()


# FIXME: Lots of DRY violations
@pytest_asyncio.fixture # Notice the async aware fixture declaration
async def DB_WINDOWED2(request):
testname = request.node.name
table_name = testname.lower()
print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr)
dummy_model = SentenceTransformer('mock_transformer')
dummy_model.encode.return_value = np.array([1, 2, 3])
try:
vDB = await MessageDB.from_conn_params(
embedding_model=dummy_model,
table_name=table_name,
db_name=DB_NAME,
host=HOST,
port=int(PORT),
user=USER,
password=PASSWORD,
window=2)
except ConnectionRefusedError:
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True)
# Actually we want to propagate the error condition, in this case
# if vDB is None:
# pytest.skip("Unable to create a valid DB instance. Skipping.", allow_module_level=True)

# Create table
await vDB.drop_table()
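To illustrate how the new fixture might be exercised, here is a hypothetical test sketch (not part of this PR). It assumes `insert(history_key, role, content)` and `count_items()` as referenced in the pgvector_message.py diff, and assumes a UUID history key.

```python
# Hypothetical test sketch, not in this PR: verify the window=2 fixture.
# Assumes MessageDB.insert(history_key, role, content) and count_items().
from uuid import uuid4

import pytest


@pytest.mark.asyncio
async def test_windowed_insert(DB_WINDOWED2):
    hk = uuid4()
    # Three inserts under one history key; with window=2 only the two
    # newest rows should survive the DELETE_OLDEST_MESSAGES pass.
    await DB_WINDOWED2.insert(history_key=hk, role='user', content='one')
    await DB_WINDOWED2.insert(history_key=hk, role='assistant', content='two')
    await DB_WINDOWED2.insert(history_key=hk, role='user', content='three')
    assert await DB_WINDOWED2.count_items() == 2
```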
10 changes: 9 additions & 1 deletion test/embedding/test_pgvector_data.py
@@ -2,7 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
# test/embedding/test_pgvector_data.py
'''
See test/embedding/test_pgvector.py for important notes on running these tests
After setup as described in the README.md for this directory, run the tests with:

pytest test

or, for just this test module:

pytest test/embedding/test_pgvector_data.py

Uses fixtures from conftest.py in current & parent directories
'''

import pytest
44 changes: 21 additions & 23 deletions test/embedding/test_pgvector_doc.py
@@ -1,23 +1,16 @@
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]>
# SPDX-License-Identifier: Apache-2.0
# test/embedding/test_pgvector.py
# test/embedding/test_pgvector_doc.py
'''
Set up a mock Postgres instance with the following commands
(make sure you don't have anything running on port 0.0.0.0:5432))):
docker pull ankane/pgvector
docker run --name mock-postgres -p 5432:5432 \
-e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
-d ankane/pgvector

Then run the tests with:
pytest test
After setup as described in the README.md for this directory, run the tests with:

or
pytest test

pytest test/embedding/test_pgvector.py
or, for just this test module:

Uses fixtures from ../conftest.py
pytest test/embedding/test_pgvector_doc.py

Uses fixtures from conftest.py in current & parent directories
'''

import pytest
@@ -133,12 +126,14 @@ def encode_tweaker(*args, **kwargs):

# Using limit default
sim_search = await DB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False)
assert sim_search is not None, Exception("No results returned from filtered search")
assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}")
assert sim_search is not None, Exception("Null return from filtered search")
sim_search = list(sim_search)
assert len(sim_search) == 3, Exception(f"There should be 3 results, received {sim_search}")

sim_search = await DB.search(text='Text', tags=['tag1', 'tag3'], conjunctive=False, limit=1000)
assert sim_search is not None, Exception("No results returned from filtered search")
assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}")
assert sim_search is not None, Exception("Null return from filtered search")
sim_search = list(sim_search)
assert len(sim_search) == 3

texts = ['Hello world', 'Hello Dolly', 'Good-Bye to All That']
authors = ['Brian Kernighan', 'Louis Armstrong', 'Robert Graves']
@@ -148,12 +143,15 @@ def encode_tweaker(*args, **kwargs):
await DB.insert_many(records)

sim_search = await DB.search(text='Hi there!', threshold=0.999, limit=0)
assert sim_search is not None, Exception("No results returned from filtered search")
assert len(list(sim_search)) == 3, Exception(f"There should be 3 results, received {sim_search}")

sim_search = await DB.search(text='Hi there!', threshold=0.999, limit=2)
assert sim_search is not None, Exception("No results returned from filtered search")
assert len(list(sim_search)) == 2, Exception(f"There should be 2 results, received {sim_search}")
assert sim_search is not None, Exception("Null return from filtered search")
sim_search = list(sim_search)
# FIXME: Double-check this expected result
assert len(sim_search) == 10

sim_search = await DB.search(text='Hi there!', threshold=0.5, limit=2)
assert sim_search is not None, Exception("Null return from filtered search")
sim_search = list(sim_search)
assert len(list(sim_search)) == 2

await DB.drop_table()
