Skip to content

Commit

Permalink
[#79] Implement embedding.pgvector.match_oneof() and fix bug in embed…
Browse files Browse the repository at this point in the history
…ding.pgvector.match_exact()
  • Loading branch information
uogbuji committed Jun 14, 2024
1 parent 561f060 commit 2d47b3c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 10 deletions.
3 changes: 1 addition & 2 deletions demo/chat_web_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ async def async_main(oapi, sites, verbose, limit, chunk_size, chunk_overlap, que
top_p=1, # AKA nucleus sampling; can increase generated text diversity
frequency_penalty=0, # Favor more or less frequent tokens
presence_penalty=1, # Prefer new, previously unused tokens
temperature=0.1
)
temperature=0.1)

indicator_task = asyncio.create_task(indicate_progress())
llm_task = asyncio.Task(oapi(messages, **model_params))
Expand Down
26 changes: 24 additions & 2 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ def match_exact(key, val):
Filter specifier to only return rows where the given top-level key exists in metadata, and matches the given value
'''
assert key.isalnum()
cast = ''
if isinstance(val, str):
assert val.isalnum()
cast = ''
elif isinstance(val, bool):
cast = '::boolean'
elif isinstance(val, int):
Expand All @@ -214,6 +213,29 @@ def apply():
return apply


def match_oneof(key, options: tuple[str]):
    '''
    Filter specifier to only return rows where the given top-level key exists in metadata,
    and matches one of the given values

    key - top-level metadata key to test. Must be alphanumeric, because it is interpolated
        directly into the SQL text rather than passed as a bound parameter
    options - non-empty iterable of candidate values (str, bool, int or float). It is converted
        to a tuple, and the type of its first item selects the SQL cast applied to the stored
        value. NOTE(review): all options are presumably expected to share one type - confirm

    Returns a zero-argument callable which, when invoked, gives a pair of
    (SQL condition text, tuple of options). The SQL text carries a literal `${}` placeholder,
    presumably filled in by the caller with the positional parameter number

    Raises TypeError if the first option is not a str, bool, int or float
    '''
    options = tuple(options)
    # NOTE: assert statements are stripped when Python runs with -O, in which case these
    # checks (including the one guarding the interpolated key) do not run
    assert options
    assert key.isalnum()
    first = options[0]
    # bool has to be tested ahead of int, since bool is a subclass of int
    if isinstance(first, str):
        cast = ''
    elif isinstance(first, bool):
        cast = '::boolean'
    elif isinstance(first, int):
        cast = '::int'
    elif isinstance(first, float):
        cast = '::float'
    else:
        # Fail here, rather than with a NameError on the unbound cast once apply() is called
        raise TypeError(
            f'Unsupported option type for match_oneof: {type(first).__name__}'
            ' (expected str, bool, int or float)')

    def apply():
        # `= ANY(...)` rather than `IN ...`: all the options travel as one array-valued
        # bind parameter, and SQL's IN requires a parenthesized list of scalar expressions
        return f"(metadata ->> '{key}'){cast} = ANY(${{}})", options
    return apply


# Down here to avoid circular imports
from ogbujipt.embedding.pgvector_data import DataDB # noqa: E402 F401
from ogbujipt.embedding.pgvector_message import MessageDB # noqa: E402 F401
9 changes: 9 additions & 0 deletions pylib/embedding/pgvector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ async def search(

# Execute the search via SQL
async with self.pool.acquire() as conn:
# Uncomment to debug
# from asyncpg import utils
# print(await utils._mogrify(
# conn,
# QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str,
# limit_clause=limit_clause,
# ),
# query_args
# ))
search_results = await conn.fetch(
QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str,
limit_clause=limit_clause,
Expand Down
65 changes: 59 additions & 6 deletions test/embedding/test_pgvector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@

import numpy as np

from ogbujipt.embedding.pgvector import match_exact
from ogbujipt.embedding.pgvector import match_exact, match_oneof


KG_STATEMENTS = [ # Demo data
("👤 Alikiba `releases_single` 💿 'Yalaiti'", {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': True}),
("👤 Sabah Salum `featured_in` 💿 'Yalaiti'", {'url': 'https://njok.com/yalaiti-lyrics/', 'primary': False}),
('👤 Kukbeatz `collab_with` 👤 Ruger', {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': True}),
('💿 All of Us `song_by` 👤 Kukbeatz & Ruger', {'url': 'https://njok.com/all-of-us-lyrics/', 'primary': False}),
('👤 Blaqbonez `collab_with` 👤 Fireboy DML', {'url': 'https://njok.com/fireboy-dml-collab/', 'primary': True})
("👤 Alikiba `releases` 💿 'Yalaiti'",
{'url': 'https://njok.com/yalaiti-lyrics/', 'primary': True, 'when': '2023-11-29'}),
("👤 Sabah Salum `featured_in` 💿 'Yalaiti'",
{'url': 'https://njok.com/yalaiti-lyrics/', 'primary': False, 'when': '2023-11-29'}),
('👤 Kukbeatz `collab_with` 👤 Ruger',
{'url': 'https://njok.com/all-of-us-lyrics/', 'primary': True, 'when': '2023-11-25'}),
('💿 All of Us `song_by` 👤 Kukbeatz & Ruger',
{'url': 'https://njok.com/all-of-us-lyrics/', 'primary': False, 'when': '2023-11-25'}),
('👤 Blaqbonez `collab_with` 👤 Fireboy DML',
{'url': 'https://njok.com/fireboy-dml-collab/', 'primary': True, 'when': '2023-11-19'})
]


Expand Down Expand Up @@ -112,5 +117,53 @@ async def test_search_with_filter(DB):
assert len(result) == 1


@pytest.mark.asyncio
async def test_search_with_date_filter(DB):
    '''
    Search narrowed by an exact match on the `when` metadata field
    '''
    # Embedding model stand-in (presumably a mock), set to always encode to the same vector
    dummy_model = SentenceTransformer('mock_transformer')
    dummy_model.encode.return_value = np.array([1, 2, 3])

    await DB.insert_many(KG_STATEMENTS)

    # Pairs of (date to match exactly, expected number of rows returned)
    for when, expected_count in (('2023-11-29', 2), ('2023-11-19', 1)):
        date_filt = match_exact('when', when)
        hits = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=date_filt))
        assert len(hits) == expected_count


@pytest.mark.asyncio
async def test_search_with_date_filter_match_oneof(DB):
    '''
    Search narrowed to rows whose `when` metadata field equals any one of several dates
    '''
    # Embedding model stand-in (presumably a mock), set to always encode to the same vector
    dummy_model = SentenceTransformer('mock_transformer')
    dummy_model.encode.return_value = np.array([1, 2, 3])

    await DB.insert_many(KG_STATEMENTS)

    # Pairs of (acceptable dates, expected number of rows returned)
    cases = (
        (('2023-11-29',), 2),
        (('2023-11-29', '2023-11-19'), 3),
        (('2023-11-29', '2023-11-25', '2023-11-19'), 5),
    )
    for dates, expected_count in cases:
        date_filt = match_oneof('when', dates)
        hits = list(await DB.search(text='Kukbeatz and Ruger', meta_filter=date_filt))
        assert len(hits) == expected_count


# Executing this file directly as a script just exits with a pointer to pytest
if __name__ == '__main__':
    raise SystemExit("Attention! Run with pytest")

0 comments on commit 2d47b3c

Please sign in to comment.