From 2d47b3c65afac6fe6aeac13930cdc650311a999d Mon Sep 17 00:00:00 2001 From: Uche Ogbuji Date: Fri, 14 Jun 2024 15:16:01 -0600 Subject: [PATCH] [#79] Implement embedding.pgvector.match_oneof() and fix bug in embedding.pgvector.match_exact() --- demo/chat_web_selects.py | 3 +- pylib/embedding/pgvector.py | 26 ++++++++++- pylib/embedding/pgvector_data.py | 9 ++++ test/embedding/test_pgvector_data.py | 65 +++++++++++++++++++++++++--- 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/demo/chat_web_selects.py b/demo/chat_web_selects.py index 57dfb58..c72d265 100644 --- a/demo/chat_web_selects.py +++ b/demo/chat_web_selects.py @@ -139,8 +139,7 @@ async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, que top_p=1, # AKA nucleus sampling; can increase generated text diversity frequency_penalty=0, # Favor more or less frequent tokens presence_penalty=1, # Prefer new, previously unused tokens - temperature=0.1 - ) + temperature=0.1) indicator_task = asyncio.create_task(indicate_progress()) llm_task = asyncio.Task(oapi(messages, **model_params)) diff --git a/pylib/embedding/pgvector.py b/pylib/embedding/pgvector.py index 668bba7..86563d0 100644 --- a/pylib/embedding/pgvector.py +++ b/pylib/embedding/pgvector.py @@ -200,9 +200,8 @@ def match_exact(key, val): Filter specifier to only return rows where the given top-level key exists in metadata, and matches the given value ''' assert key.isalnum() - cast = '' if isinstance(val, str): - assert val.isalnum() + cast = '' elif isinstance(val, bool): cast = '::boolean' elif isinstance(val, int): @@ -214,6 +213,29 @@ def apply(): return apply +def match_oneof(key, options: tuple[str]): + ''' + Filter specifier to only return rows where the given top-level key exists in metadata, + and matches one of the given values + ''' + options = tuple(options) + assert options + assert key.isalnum() + option1 = options[0] + if isinstance(option1, str): + cast = '' + if isinstance(option1, bool): + cast = '::boolean' + elif isinstance(option1, int): + cast = '::int' + elif isinstance(option1, float): + cast = '::float' + def apply(): + # return f'(metadata ->> \'{key}\'){cast} IN ${{}}', options + return f'(metadata ->> \'{key}\'){cast} = ANY(${{}})', options + return apply + + # Down here to avoid circular imports from ogbujipt.embedding.pgvector_data import DataDB # noqa: E402 F401 from ogbujipt.embedding.pgvector_message import MessageDB # noqa: E402 F401 diff --git a/pylib/embedding/pgvector_data.py b/pylib/embedding/pgvector_data.py index fac213a..9063d44 100644 --- a/pylib/embedding/pgvector_data.py +++ b/pylib/embedding/pgvector_data.py @@ -210,6 +210,15 @@ async def search( # Execute the search via SQL async with self.pool.acquire() as conn: + # Uncomment to debug + # from asyncpg import utils + # print(await utils._mogrify( + # conn, + # QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str, + # limit_clause=limit_clause, + # ), + # query_args + # )) search_results = await conn.fetch( QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str, limit_clause=limit_clause, diff --git a/test/embedding/test_pgvector_data.py b/test/embedding/test_pgvector_data.py index 8c07ea7..d5ab71c 100644 --- a/test/embedding/test_pgvector_data.py +++ b/test/embedding/test_pgvector_data.py @@ -18,15 +18,20 @@ import numpy as np -from ogbujipt.embedding.pgvector import match_exact +from ogbujipt.embedding.pgvector import match_exact, match_oneof KG_STATEMENTS = [ # Demo data - ("👤 Alikiba `releases_single` 💿 'Yalaiti'", {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': True}), - ("👤 Sabah Salum `featured_in` 💿 'Yalaiti'", {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': False}), - ('👤 Kukbeatz `collab_with` 👤 Ruger', {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': True}), - ('💿 All of Us `song_by` 👤 Kukbeatz & Ruger', {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': False}), - ('👤 Blaqbonez `collab_with` 👤 Fireboy DML', {'url': 'https://njok.com/fireboy-dml-collab/', 'primary': True}) + ("👤 Alikiba `releases` 💿 'Yalaiti'", + {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': True, 'when': '2023-11-29'}), + ("👤 Sabah Salum `featured_in` 💿 'Yalaiti'", + {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': False, 'when': '2023-11-29'}), + ('👤 Kukbeatz `collab_with` 👤 Ruger', + {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': True, 'when': '2023-11-25'}), + ('💿 All of Us `song_by` 👤 Kukbeatz & Ruger', + {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': False, 'when': '2023-11-25'}), + ('👤 Blaqbonez `collab_with` 👤 Fireboy DML', + {'url': 'https://njok.com/fireboy-dml-collab/', 'primary': True, 'when': '2023-11-19'}) ] @@ -112,5 +117,53 @@ async def test_search_with_filter(DB): assert len(result) == 1 +@pytest.mark.asyncio +async def test_search_with_date_filter(DB): + dummy_model = SentenceTransformer('mock_transformer') + dummy_model.encode.return_value = np.array([1, 2, 3]) + + # item1_text = KG_STATEMENTS[0][0] + + # Insert data using insert_many() + # dataset = ((text, metadata) for (text, metadata) in KG_STATEMENTS) + + await DB.insert_many(KG_STATEMENTS) + + # search table with perfect match, but only where primary is set to True + primary_filt = match_exact('when', '2023-11-29') + result = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=primary_filt)) + assert len(result) == 2 + + primary_filt = match_exact('when', '2023-11-19') + result = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=primary_filt)) + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_search_with_date_filter_match_oneof(DB): + dummy_model = SentenceTransformer('mock_transformer') + dummy_model.encode.return_value = np.array([1, 2, 3]) + + # item1_text = KG_STATEMENTS[0][0] + + # Insert data using insert_many() + # dataset = ((text, metadata) for (text, metadata) in KG_STATEMENTS) + + await DB.insert_many(KG_STATEMENTS) + + # search table with perfect match, but only where primary is set to True + primary_filt = match_oneof('when', ('2023-11-29',)) + result = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=primary_filt)) + assert len(result) == 2 + + primary_filt = match_oneof('when', ('2023-11-29', '2023-11-19')) + result = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=primary_filt)) + assert len(result) == 3 + + primary_filt = match_oneof('when', ('2023-11-29', '2023-11-25', '2023-11-19')) + result = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=primary_filt)) + assert len(result) == 5 + + if __name__ == '__main__': raise SystemExit("Attention! Run with pytest")