-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #67 from OoriData/messages-since
More MessageDB improvements, including timestamp constraints in getting chatlog (messages since timestamp)
- Loading branch information
Showing
7 changed files
with
233 additions
and
266 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,19 @@ | ||
# SPDX-FileCopyrightText: 2023-present Oori Data <[email protected]> | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# test/embedding/conftest.py | ||
''' | ||
Fixtures/setup/teardown for embedding tests | ||
''' | ||
|
||
import sys | ||
import os | ||
import pytest | ||
import pytest_asyncio | ||
from unittest.mock import MagicMock, DEFAULT # noqa: F401 | ||
|
||
import numpy as np | ||
from ogbujipt.embedding.pgvector import MessageDB, DataDB, DocDB | ||
|
||
|
||
@pytest.fixture | ||
def CORRECT_STRING(): | ||
|
@@ -67,3 +81,57 @@ def HITHERE_all_MiniLM_L12_v2(): | |
-0.071223356, 0.0019683593, 0.032683503, -0.08899012, 0.10160039, 0.04948275, 0.048017487, -0.046223965, | ||
0.032460734, -0.043729845, 0.030224336, -0.019220904, 0.08223829, 0.03851222, -0.016376046, 0.041965306, | ||
0.0445879, -0.03780432, -0.024826797, 0.014669102, 0.057102628, -0.031820614, 0.0027352672, 0.052658144]) | ||
|
||
# XXX: This stanza to go away once mocking is complete - Kai | ||
HOST = os.environ.get('PG_HOST', 'localhost') | ||
DB_NAME = os.environ.get('PG_DATABASE', 'mock_db') | ||
USER = os.environ.get('PG_USER', 'mock_user') | ||
PASSWORD = os.environ.get('PG_PASSWORD', 'mock_password') | ||
PORT = os.environ.get('PG_PORT', 5432) | ||
|
||
|
||
# XXX: Move to a fixture? | ||
# Definitely don't want to even import SentenceTransformer class due to massive side-effects | ||
class SentenceTransformer(object): | ||
def __init__(self, model_name_or_path): | ||
self.encode = MagicMock() | ||
|
||
|
||
DB_CLASS = { | ||
'test/embedding/test_pgvector_message.py': MessageDB, | ||
'test/embedding/test_pgvector_data.py': DataDB, | ||
'test/embedding/test_pgvector_doc.py': DocDB, | ||
} | ||
|
||
|
||
@pytest_asyncio.fixture # Notice the async aware fixture declaration | ||
async def DB(request): | ||
testname = request.node.name | ||
testfile = request.node.location[0] | ||
table_name = testname.lower() | ||
print(f'DB setup for test: {testname}. Table name {table_name}', file=sys.stderr) | ||
dummy_model = SentenceTransformer('mock_transformer') | ||
dummy_model.encode.return_value = np.array([1, 2, 3]) | ||
try: | ||
vDB = await DB_CLASS[testfile].from_conn_params( | ||
embedding_model=dummy_model, | ||
table_name=table_name, | ||
db_name=DB_NAME, | ||
host=HOST, | ||
port=int(PORT), | ||
user=USER, | ||
password=PASSWORD) | ||
except ConnectionRefusedError: | ||
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) | ||
if vDB is None: | ||
pytest.skip("No Postgres instance made available for test. Skipping.", allow_module_level=True) | ||
|
||
# Create table | ||
await vDB.drop_table() | ||
assert not await vDB.table_exists(), Exception("Table exists before creation") | ||
await vDB.create_table() | ||
assert await vDB.table_exists(), Exception("Table does not exist after creation") | ||
# The test will take control upon the yield | ||
yield vDB | ||
# Teardown: Drop table | ||
await vDB.drop_table() |
Oops, something went wrong.