Clean up query-building logic
uogbuji committed Jun 20, 2024
1 parent 2d47b3c commit e5685b0
Showing 3 changed files with 20 additions and 15 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -156,6 +156,10 @@ pytest test

 If you want to make contributions to the project, please [read these notes](https://github.com/OoriData/OgbujiPT/wiki/Notes-for-contributors).

+# Resources
+
+* [Against mixing environment setup with code](https://huggingface.co/blog/ucheog/separate-env-setup-from-code)
+
 # License

 Apache 2. For tha culture!
@@ -189,7 +193,7 @@ I mentioned the bias to software engineering, but what does this mean?

 ## Does this support GPU for locally-hosted models

-Yes, but you have to make sure you set up your back end LLm server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it. If you can use the webui to query your model and get GPU usage, that will also apply here in OgbujiPT.
+Yes, but you have to make sure you set up your back end LLM server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it.

 Many install guides I've found for Mac, Linux and Windows touch on enabling GPU, but the ecosystem is still in its early days, and helpful resources can feel scattered.

7 changes: 4 additions & 3 deletions pylib/embedding/pgvector.py
@@ -93,9 +93,6 @@ def __init__(self, embedding_model, table_name: str, pool):
         else:
             raise ValueError('embedding_model must be a SentenceTransformer object or None')

-        self.table_name = table_name
-        self.pool = pool
-
     @classmethod
     async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password) -> 'PGVectorHelper': # noqa: E501
         '''
@@ -117,6 +114,10 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name
     async def init_pool(conn):
         '''
         Initialize vector extension for a connection from a pool
+        Can be invoked from upstream if they're managing the connection pool themselves
+        If they choose to have us create a connection pool (e.g. from_conn_params), it will use this
         '''
         await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
         await register_vector(conn)
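The `init_pool` docstring says callers who manage their own connection pool can apply the same per-connection setup. A minimal sketch of how that might look, assuming asyncpg's `init` hook on `create_pool` and `register_vector` from `pgvector.asyncpg`. The names `init_conn` and `make_pool` are hypothetical and are not part of OgbujiPT.

```python
async def init_conn(conn):
    '''
    Per-connection setup, mirroring what the diff's init_pool does.
    Hypothetical name; the real helper lives on the project's class.
    '''
    # Imported lazily so the sketch can be loaded without pgvector installed
    from pgvector.asyncpg import register_vector
    await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
    await register_vector(conn)


async def make_pool(host, port, db_name, user, password):
    '''
    Create a pool whose connections each run init_conn once when opened.
    Hypothetical helper for an upstream caller that owns the pool.
    '''
    import asyncpg  # lazy import, same reason as above
    return await asyncpg.create_pool(
        host=host, port=port, database=db_name,
        user=user, password=password,
        init=init_conn,  # asyncpg awaits this for every new connection
    )
```

A pool built this way could then be passed to the helper's constructor as its `pool` argument, which is the "managing the connection pool themselves" case the docstring describes.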
22 changes: 11 additions & 11 deletions pylib/embedding/pgvector_data.py
@@ -184,22 +184,17 @@ async def search(
         query_embedding = list(self._embedding_model.encode(text))

         # Build where clauses
-        if threshold is None:
-            # No where clauses, so don't bother with the WHERE keyword
-            where_clauses = []
-            query_args = [query_embedding]
-        else: # construct where clauses
-            where_clauses = []
-            query_args = [query_embedding]
-            if threshold is not None:
-                query_args.append(threshold)
-                where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)+1}'))
+        query_args = [query_embedding]
+        where_clauses = []
+        if threshold is not None:
+            query_args.append(threshold)
+            where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)}'))

         for mf in meta_filter:
             assert callable(mf), 'All meta_filter items must be callable'
             clause, pval = mf()
-            where_clauses.append(clause.format(len(query_args)+1))
             query_args.append(pval)
+            where_clauses.append(clause.format(len(query_args)))

         where_clauses_str = 'WHERE\n' + 'AND\n'.join(where_clauses) if where_clauses else ''

@@ -208,6 +203,11 @@ async def search(
         else:
             limit_clause = ''

+        # print(QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str,
+        # limit_clause=limit_clause,
+        # ))
+        # print(query_args)
+
         # Execute the search via SQL
         async with self.pool.acquire() as conn:
             # Uncomment to debug
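The placeholder numbering in the `search` hunk is easier to follow in isolation. The sketch below mirrors the new logic as a standalone function: each value is appended to `query_args` first, so `len(query_args)` is already the 1-based index for its `$n` placeholder. `build_where`, `match_exact`, and the text of `THRESHOLD_WHERE_CLAUSE` used here are illustrative stand-ins, not the project's actual definitions.

```python
# Stand-in for the module constant; the real SQL template may differ (assumption)
THRESHOLD_WHERE_CLAUSE = '(1 - (embedding <=> $1)) >= {query_threshold}\n'


def match_exact(key, val):
    '''Hypothetical meta_filter factory: returns a callable giving (clause_template, value)'''
    def apply():
        # The '{}' left in the template is filled later with the positional index
        return f"(metadata ->> '{key}') = ${{}}\n", val
    return apply


def build_where(query_embedding, threshold=None, meta_filter=()):
    '''Build the WHERE text and the matching positional argument list'''
    query_args = [query_embedding]  # always bound to $1
    where_clauses = []
    if threshold is not None:
        query_args.append(threshold)
        # Append first, so the list length is this value's placeholder index
        where_clauses.append(
            THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)}'))

    for mf in meta_filter:
        assert callable(mf), 'All meta_filter items must be callable'
        clause, pval = mf()
        query_args.append(pval)
        where_clauses.append(clause.format(len(query_args)))

    # No clauses means no WHERE keyword at all
    where_clauses_str = 'WHERE\n' + 'AND\n'.join(where_clauses) if where_clauses else ''
    return where_clauses_str, query_args


if __name__ == '__main__':
    where, args = build_where([0.1, 0.2], threshold=0.5,
                              meta_filter=[match_exact('source', 'wiki')])
    print(where)
    print(args)
```

With a threshold and one filter, the threshold binds to `$2` and the filter value to `$3`. The removed code appended the threshold and then used `len(query_args)+1`, which pointed the threshold clause at `$3`, one past its actual argument.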
