Add threshold param to MessageDB.search(). Assorted cleanup.
uogbuji committed Apr 15, 2024
1 parent 80adc65 commit 0a21613
Showing 6 changed files with 136 additions and 56 deletions.
26 changes: 8 additions & 18 deletions pylib/embedding/pgvector_data_doc.py
@@ -11,7 +11,7 @@

from ogbujipt.embedding.pgvector import PGVectorHelper, asyncpg, process_search_response

__all__ = ['DocDB']
__all__ = ['DataDB', 'DocDB']

# ------ SQL queries ---------------------------------------------------------------------------------------------------
# PG only supports proper query arguments (e.g. $1, $2, etc.) for values, not for table or column names
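As an aside on that comment: asyncpg can bind values as numbered parameters, but identifiers such as table names must be interpolated into the SQL text, which is why these templates mix `str.format()` placeholders with `$1`-style arguments. A minimal sketch of the pattern, with a hypothetical table name and an assumed asyncpg connection `conn`:

```py
# Identifiers go through format(); values ride along as $n query arguments.
QUERY = 'SELECT content FROM {table_name} WHERE role = $1 LIMIT $2;'
# rows = await conn.fetch(QUERY.format(table_name='chatlog'), 'user', 10)
```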
@@ -177,11 +177,8 @@ async def search(
generator which yields the rows of the query results as attributable dicts
'''
if threshold is not None:
if not isinstance(threshold, float):
raise TypeError('threshold must be a float')
if (threshold < 0) or (threshold > 1):
raise ValueError('threshold must be between 0 and 1')

if not isinstance(threshold, float) or (threshold < 0) or (threshold > 1):
raise TypeError('threshold must be a float between 0.0 and 1.0')
if not isinstance(limit, int):
raise TypeError('limit must be an integer') # Guard against injection

@@ -218,10 +215,8 @@
# Execute the search via SQL
async with self.pool.acquire() as conn:
search_results = await conn.fetch(
QUERY_DATA_TABLE.format(
table_name=self.table_name,
where_clauses=where_clauses,
limit_clause=limit_clause,
QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses,
limit_clause=limit_clause,
),
*query_args
)
@@ -237,10 +232,7 @@ async def create_table(self) -> None:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CREATE_DOC_TABLE.format(
table_name=self.table_name,
embed_dimension=self._embed_dimension)
)
CREATE_DOC_TABLE.format(table_name=self.table_name, embed_dimension=self._embed_dimension))

async def insert(
self,
@@ -336,10 +328,8 @@ async def search(
warnings.warn('query_tags is deprecated. Use tags instead.', DeprecationWarning)
tags = query_tags
if threshold is not None:
if not isinstance(threshold, float):
raise TypeError('threshold must be a float')
if (threshold < 0) or (threshold > 1):
raise ValueError('threshold must be between 0 and 1')
if not isinstance(threshold, float) or (threshold < 0) or (threshold > 1):
raise TypeError('threshold must be a float between 0.0 and 1.0')

if not isinstance(limit, int):
raise TypeError('limit must be an integer') # Guard against injection
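For orientation, a hedged sketch of how the tightened guard behaves from the caller's side, assuming an already-connected DataDB or DocDB instance `db` (connection setup and full parameter list are not shown in this diff, so names here are illustrative):

```py
async def demo(db):
    # A float in [0.0, 1.0] filters out low-similarity rows
    results = await db.search(text='Hello!', threshold=0.75, limit=5)
    for row in results:  # generator of attributable dicts
        print(row)
    # Both of these now raise TypeError (previously a non-float raised
    # TypeError and an out-of-range float raised ValueError):
    # await db.search(text='Hello!', threshold='0.75', limit=5)
    # await db.search(text='Hello!', threshold=1.5, limit=5)
```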
113 changes: 82 additions & 31 deletions pylib/embedding/pgvector_message.py
@@ -71,15 +71,35 @@
role,
content,
metadata
FROM
{table_name}
WHERE
history_key = $2
FROM {table_name}
{where_clauses}
ORDER BY
cosine_similarity DESC
LIMIT $3;
{limit_clause};
'''

# The cosine_similarity alias is not available in the WHERE clause, so use a nested SELECT
SEMANTIC_QUERY_MESSAGE_TABLE = '''-- Find messages with closest semantic similarity
SELECT
cosine_similarity,
ts,
role,
content,
metadata
FROM
(SELECT
history_key,
1 - (embedding <=> $1) AS cosine_similarity,
ts,
role,
content,
metadata FROM {table_name}) AS main
{where_clauses}
{limit_clause};
'''

THRESHOLD_WHERE_CLAUSE = 'main.cosine_similarity >= {query_threshold}\n'
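To see how the pieces fit, here is a sketch of the template being filled in for a history-key plus threshold search, matching the clause text that the new search() code below assembles (table name hypothetical):

```py
# $1 = query embedding, $2 = history_key, $3 = threshold, $4 = limit
sql = SEMANTIC_QUERY_MESSAGE_TABLE.format(
    table_name='chatlog',
    where_clauses='WHERE\nmain.history_key = $2\nAND\nmain.cosine_similarity >= $3\n',
    limit_clause='LIMIT $4',
)
```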

DELETE_OLDEST_MESSAGES = '''-- Delete oldest messages for given history key, such that only the newest N messages remain
DELETE FROM {table_name} t_outer
WHERE
@@ -262,26 +282,32 @@ async def get_messages(
Returns:
generates asyncpg.Record instances of resulting messages
'''
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
# msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
# raise TypeError(msg)
if not isinstance(since, datetime) and since is not None:
msg = 'since must be a datetime or None'
raise TypeError(msg)
if not isinstance(limit, int):
raise TypeError('limit must be an integer') # Guard against injection

qparams = [history_key]
# Build query
if since:
# The ${len(qparams) + 1} indexing isn't strictly needed for the first optional clause, but keeps things consistent
since_clause = f' AND ts > ${len(qparams) + 1}'
qparams.append(since)
else:
since_clause = ''
if limit:
limit_clause = f'LIMIT ${len(qparams) + 1}'
qparams.append(limit)
else:
limit_clause = ''

# Execute
async with self.pool.acquire() as conn:
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
# msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
# raise TypeError(msg)
if not isinstance(since, datetime) and since is not None:
msg = 'since must be a datetime or None'
raise TypeError(msg)
if since:
# Really don't need the ${len(qparams) + 1} trick here, but used for consistency
since_clause = f' AND ts > ${len(qparams) + 1}'
qparams.append(since)
else:
since_clause = ''
if limit:
limit_clause = f'LIMIT ${len(qparams) + 1}'
qparams.append(limit)
else:
limit_clause = ''
message_records = await conn.fetch(
RETURN_MESSAGES_BY_HISTORY_KEY.format(
table_name=self.table_name,
@@ -303,6 +329,7 @@ async def search(
history_key: UUID,
text: str,
since: datetime | None = None,
threshold: float | None = None,
limit: int = 1
) -> list[asyncpg.Record]:
'''
@@ -317,21 +344,45 @@
list[asyncpg.Record]: list of search results
(asyncpg.Record objects are similar to dicts, but allow for attribute-style access)
'''
# Type checks
if threshold is not None:
if not isinstance(threshold, float) or (threshold < 0) or (threshold > 1):
raise TypeError('threshold must be a float between 0.0 and 1.0')
if not isinstance(limit, int):
raise TypeError('limit must be an integer')
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
if not isinstance(since, datetime) and since is not None:
msg = 'since must be a datetime or None'
raise TypeError(msg)

# Get the embedding of the query string as a PGvector compatible list
# Get embedding of the query string as a PGvector compatible list
query_embedding = list(self._embedding_model.encode(text))

# Search the table
# Build query
clauses = ['main.history_key = $2\n']
qparams = [query_embedding, history_key]
if since is not None:
# The ${len(qparams) + 1} indexing isn't strictly needed for the first optional clause, but keeps things consistent
clauses.append(f'ts > ${len(qparams) + 1}\n')
qparams.append(since)
if threshold is not None:
clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(qparams) + 1}'))
qparams.append(threshold)
clauses = 'AND\n'.join(clauses) # TODO: move this into the fstring below after py3.12
where_clauses = f'WHERE\n{clauses}'
limit_clause = f'LIMIT ${len(qparams) + 1}'
qparams.append(limit)

# Execute
async with self.pool.acquire() as conn:
records = await conn.fetch(
message_records = await conn.fetch(
SEMANTIC_QUERY_MESSAGE_TABLE.format(
table_name=self.table_name
table_name=self.table_name,
where_clauses=where_clauses,
limit_clause=limit_clause
),
query_embedding,
history_key,
limit
*qparams
)

search_results = [
@@ -342,7 +393,7 @@ async def search(
'metadata': record['metadata'],
'cosine_similarity': record['cosine_similarity']
}
for record in records
for record in message_records
]

return process_search_response(search_results)
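A usage sketch of the new parameter, mirroring the new test further below and assuming an already-connected MessageDB `db` and an existing `history_key` (values illustrative):

```py
async def demo(db, history_key):
    # Only matches with cosine similarity >= 0.7 come back, at most 3 of them
    results = list(await db.search(history_key, 'Hi!', threshold=0.7, limit=3))
    for rec in results:
        print(rec)  # carries ts, role, content, metadata, cosine_similarity
```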
4 changes: 2 additions & 2 deletions pylib/llm_wrapper.py
@@ -266,9 +266,9 @@ class openai_chat_api(openai_api):
You need to set an OpenAI API key in your environment, or pass it in, for this next example
>>> from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
>>> import asyncio; from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat
>>> llm_api = openai_chat_api(model='gpt-3.5-turbo')
>>> resp = llm_api(prompt_to_chat('Knock knock!'))
>>> resp = asyncio.run(llm_api(prompt_to_chat('Knock knock!')))
>>> resp.first_choice_text
'''
def call(self, messages, api_func=None, **kwargs):
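The doctest tweak above reflects that the wrapper call is awaitable; as a standalone script, the same flow might look like this (requires an OpenAI API key in the environment):

```py
import asyncio
from ogbujipt.llm_wrapper import openai_chat_api, prompt_to_chat

llm_api = openai_chat_api(model='gpt-3.5-turbo')
resp = asyncio.run(llm_api(prompt_to_chat('Knock knock!')))
print(resp.first_choice_text)
```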
18 changes: 18 additions & 0 deletions test/embedding/README.md
@@ -7,3 +7,21 @@ docker run --name mock-postgres -p 5432:5432 \
-e POSTGRES_USER=mock_user -e POSTGRES_PASSWORD=mock_password -e POSTGRES_DB=mock_db \
-d ankane/pgvector
```

You can also use another PGVector setup, but then you'll need to set the following environment variables:

* `PG_HOST`
* `PG_DATABASE`
* `PG_USER`
* `PG_PASSWORD`
* `PG_PORT`

e.g.:

```sh
PG_HOST="localhost"
PG_PORT="5432"
PG_USER="username"
PG_PASSWORD="passwd"
PG_DATABASE="PeeGeeVee"
```
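For example, you might point the test suite at such a setup like this (values illustrative; run from the repository root):

```sh
PG_HOST=localhost PG_PORT=5432 PG_USER=username PG_PASSWORD=passwd \
  PG_DATABASE=PeeGeeVee pytest test/embedding
```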
4 changes: 0 additions & 4 deletions test/embedding/test_pgvector_doc.py
@@ -39,8 +39,6 @@ def __init__(self, model_name_or_path):

@pytest.mark.asyncio
async def test_PGv_embed_pacer(DB):
dummy_model = SentenceTransformer('mock_transformer')
dummy_model.encode.return_value = np.array([1, 2, 3])
# Insert data
for index, text in enumerate(pacer_copypasta): # For each line in the copypasta
await DB.insert( # Insert the line into the table
@@ -61,8 +59,6 @@

@pytest.mark.asyncio
async def test_PGv_embed_many_pacer(DB):
dummy_model = SentenceTransformer('mock_transformer')
dummy_model.encode.return_value = np.array([1, 2, 3])
# Insert data using insert_many()
documents = (
(
27 changes: 26 additions & 1 deletion test/embedding/test_pgvector_message.py
@@ -18,7 +18,7 @@
import pytest
from unittest.mock import MagicMock, DEFAULT # noqa: F401

# import numpy as np
import numpy as np

# XXX: Move to a fixture?
# Definitely don't want to even import SentenceTransformer class due to massive side-effects
@@ -164,5 +164,30 @@ async def test_get_messages_since(DB, MESSAGES):
assert len(results) == 1, Exception('Incorrect number of messages returned from chatlog')


@pytest.mark.asyncio
async def test_search_threshold(DB, MESSAGES):
dummy_model = SentenceTransformer('mock_transformer')
def encode_tweaker(*args, **kwargs):
# Note: cosine similarity of [1, 2, 3] & [100, 300, 500] appears to be ~ 0.9939
if args[0].startswith('Hi'):
return np.array([100, 300, 500])
else:
return np.array([1, 2, 3])

dummy_model.encode.side_effect = encode_tweaker
# Need to replace the default encoder set up by the fixture
DB._embedding_model = dummy_model

await DB.insert_many(MESSAGES)

history_key, role, content, timestamp, metadata = MESSAGES[0]

results = list(await DB.search(history_key, 'Hi!', threshold=0.999))
assert results is not None and len(results) == 0

results = list(await DB.search(history_key, 'Hi!', threshold=0.5))
assert results is not None and len(results) == 1


if __name__ == '__main__':
raise SystemExit("Attention! Run with pytest")
