From e5685b0f0abccab077fde69551c36aec5a943db1 Mon Sep 17 00:00:00 2001 From: Uche Ogbuji Date: Thu, 20 Jun 2024 14:16:39 -0600 Subject: [PATCH] Clean up query-building logic --- README.md | 6 +++++- pylib/embedding/pgvector.py | 7 ++++--- pylib/embedding/pgvector_data.py | 22 +++++++++++----------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index e21da6f..9e5d78a 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,10 @@ pytest test If you want to make contributions to the project, please [read these notes](https://github.com/OoriData/OgbujiPT/wiki/Notes-for-contributors). +# Resources + +* [Against mixing environment setup with code](https://huggingface.co/blog/ucheog/separate-env-setup-from-code) + # License Apache 2. For tha culture! @@ -189,7 +193,7 @@ I mentioned the bias to software engineering, but what does this mean? ## Does this support GPU for locally-hosted models -Yes, but you have to make sure you set up your back end LLm server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it. If you can use the webui to query your model and get GPU usage, that will also apply here in OgbujiPT. +Yes, but you have to make sure you set up your back end LLM server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it. Many install guides I've found for Mac, Linux and Windows touch on enabling GPU, but the ecosystem is still in its early days, and helpful resources can feel scattered. 
diff --git a/pylib/embedding/pgvector.py b/pylib/embedding/pgvector.py index 86563d0..c09b7e6 100644 --- a/pylib/embedding/pgvector.py +++ b/pylib/embedding/pgvector.py @@ -93,9 +93,6 @@ def __init__(self, embedding_model, table_name: str, pool): else: raise ValueError('embedding_model must be a SentenceTransformer object or None') - self.table_name = table_name - self.pool = pool - @classmethod async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password) -> 'PGVectorHelper': # noqa: E501 ''' @@ -117,6 +114,10 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name async def init_pool(conn): ''' Initialize vector extension for a connection from a pool + + Can be invoked from upstream if they're managing the connection pool themselves + + If they choose to have us create a connection pool (e.g. from_conn_params), it will use this ''' await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;') await register_vector(conn) diff --git a/pylib/embedding/pgvector_data.py b/pylib/embedding/pgvector_data.py index 9063d44..1cb6672 100644 --- a/pylib/embedding/pgvector_data.py +++ b/pylib/embedding/pgvector_data.py @@ -184,22 +184,17 @@ async def search( query_embedding = list(self._embedding_model.encode(text)) # Build where clauses - if threshold is None: - # No where clauses, so don't bother with the WHERE keyword - where_clauses = [] - query_args = [query_embedding] - else: # construct where clauses - where_clauses = [] - query_args = [query_embedding] - if threshold is not None: - query_args.append(threshold) - where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)+1}')) + query_args = [query_embedding] + where_clauses = [] + if threshold is not None: + query_args.append(threshold) + where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)}')) for mf in meta_filter: assert callable(mf), 'All meta_filter items must be callable' clause, pval = 
mf() - where_clauses.append(clause.format(len(query_args)+1)) query_args.append(pval) + where_clauses.append(clause.format(len(query_args))) where_clauses_str = 'WHERE\n' + 'AND\n'.join(where_clauses) if where_clauses else '' @@ -208,6 +203,11 @@ async def search( else: limit_clause = '' + # print(QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str, + # limit_clause=limit_clause, + # )) + # print(query_args) + # Execute the search via SQL async with self.pool.acquire() as conn: # Uncomment to debug