Skip to content

Commit

Permalink
Merge pull request #71 from OoriData/69-singleton-pg-connection-pool
Browse files Browse the repository at this point in the history
[#69] Cede back connection pool control to user for PGVector
  • Loading branch information
choccccy authored Feb 20, 2024
2 parents 21f891d + 19128df commit 3fd0f71
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 118 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ Notable changes to OgbujiPT. Format based on [Keep a Changelog](https://keepacha
## [Unreleased]
-->

## [0.7.1] - 20240122
## [0.7.1] - 20240205

### Added

- MessageDB.get_messages() options: `since` (for retrieving messages after a timestamp) and `limit` (for limiting the number of messages returned, selecting the most recent)

### Changed

- PGVector users now manage their own connection pool by default
- Better modularization of embeddings test cases; using `conftest.py` more
- `pgvector_message.py` PG table timestamp column no longer a primary key

Expand Down
146 changes: 47 additions & 99 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
'''

import json
import asyncio
# import asyncio
# from typing import ClassVar

from ogbujipt.config import attr_dict

Expand All @@ -16,10 +17,12 @@
import asyncpg
from pgvector.asyncpg import register_vector
PREREQS_AVAILABLE = True
POOL_TYPE = asyncpg.pool.Pool
except ImportError:
PREREQS_AVAILABLE = False
asyncpg = None
register_vector = object() # Set up a dummy to satisfy the type hints
POOL_TYPE = object

# ------ SQL queries ---------------------------------------------------------------------------------------------------
# PG only supports proper query arguments (e.g. $1, $2, etc.) for values, not for table or column names
Expand All @@ -41,29 +44,44 @@


class PGVectorHelper:
# XXX: Should pool_params just be required? Can't really construct without going through *something*
# async such as from_conn_params anyway, which will handle ensuring we've been provided pool_params
def __init__(self, embedding_model, table_name: str, pool_params: dict = None):
'''
Create a PGvector helper from an asyncpg connection
'''
Helper class for PGVector operations
If you don't yet have a connection, but have all the parameters,
you can use the PGvectorHelper.from_conn_params() method instead
Construct using PGVectorHelper.from_conn_params() method
Connection and pool parameters:
    * table_name: PostgreSQL table name. Checked to restrict to alphanumeric characters & underscore
* host: Hostname or IP address of the PostgreSQL server. Defaults to UNIX socket if not provided.
* port: Port number at which the PostgreSQL server is listening. Defaults to 5432 if not provided.
* user: User name used to authenticate.
* password: Password used to authenticate.
* database: Database name to connect to.
* min_max_size: Tuple of minimum and maximum number of connections to maintain in the pool.
Defaults to (10, 20)
'''
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool):
'''
If you don't already have a connection pool, construct using the PGvectorHelper.from_pool_params() method
Args:
embedding (SentenceTransformer): SentenceTransformer object of your choice
https://huggingface.co/sentence-transformers
table_name: PostgresQL table to store the vector embeddings. Will be checked to restrict to
alphanumeric characters and underscore
            table_name: PostgreSQL table. Checked to restrict to alphanumeric characters & underscore
apg_conn: asyncpg connection to the database
pool: asyncpg connection pool instance
'''
if not PREREQS_AVAILABLE:
raise RuntimeError('pgvector not installed, you can run `pip install pgvector asyncpg`')

if not table_name.replace('_', '').isalnum():
raise ValueError('table_name must be alphanumeric, with underscore also allowed')
msg = 'table_name must be alphanumeric, with underscore also allowed'
raise ValueError(msg)

self.table_name = table_name
self.embedding_model = embedding_model
self.pool = pool

# Check if the provided embedding model is a SentenceTransformer
if (embedding_model.__class__.__name__ == 'SentenceTransformer') and (not None):
Expand All @@ -76,104 +94,34 @@ def __init__(self, embedding_model, table_name: str, pool_params: dict = None):
raise ValueError('embedding_model must be a SentenceTransformer object or None')

self.table_name = table_name
self.pool_params = pool_params or {}
# asyncpg doesn't allow use of the same pool in different event loops
self.pool_per_loop = {}
self.pool = pool

@classmethod
async def from_conn_params(
cls,
embedding_model,
table_name,
user,
password,
db_name,
host,
port,
min_max_size=DEFAULT_MIN_MAX_CONNECTION_POOL_SIZE,
**conn_params
) -> 'PGVectorHelper':
'''
Create a PGvector helper from connection parameters
For details on accepted parameters, See the `pgvector_connection` docstring
(e.g. run `help(pgvector_connection)`)
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password) -> 'PGVectorHelper': # noqa: E501
'''
min_size, max_size = min_max_size
# FIXME: Clean up this exception handling
# try:
# import logging
# logging.critical(f'Connecting to {host}:{port} as {user} to {db_name}')
# logging.critical(str(conn_params))
pool_params = dict(
host=host,
port=port,
user=user,
password=password,
database=db_name,
min_size=min_size,
max_size=max_size,
**conn_params
)
# except Exception as e:
# Don't blanket mask the exception. Handle exceptions types in whatever way makes sense
# raise e

obj = cls(embedding_model, table_name, pool_params)
pool = await obj.connection_pool()

# Set up DB extension & type handling
async with pool.acquire() as conn:
# Is this also required per pool? (duplicated from init_pool)
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
# Actually ALSO have to register_vector per pool (duplicated from init_pool)
# https://github.com/pgvector/pgvector-python?tab=readme-ov-file#asyncpg
await register_vector(conn)

await conn.set_type_codec( # Register a codec for JSON
'JSON',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
Create PGVectorHelper instance from connection/pool parameters
# print('PGvector extension created and loaded.')
return obj
Will create a connection pool for you, with JSON type handling initialized,
and set that as a pool attribute on the created object as a user convenience.
async def connection_pool(self):
For details on accepted parameters, See the class docstring
(e.g. run `help(PGVectorHelper)`)
'''
'''
# conn_pool = await asyncpg.create_pool(
# host=host,
# port=port,
# user=user,
# password=password,
# database=db_name,
# min_size=min_size,
# max_size=max_size,
# **conn_params
# )
loop = asyncio.get_event_loop()
if loop in self.pool_per_loop:
pool = self.pool_per_loop[loop]
else:
pool = await asyncpg.create_pool(init=PGVectorHelper.init_pool, **self.pool_params)
self.pool_per_loop[loop] = pool
return pool
pool = await asyncpg.create_pool(init=PGVectorHelper.init_pool, host=host, port=port, user=user,
password=password, database=db_name)

new_obj = cls(embedding_model, table_name, pool)
return new_obj

@staticmethod
async def init_pool(conn):
'''
Initialize the vector extension for a connection from a pool
Initialize vector extension for a connection from a pool
'''
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
await register_vector(conn)
await conn.set_type_codec( # Register a codec for JSON
'JSON',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
'JSON', encoder=json.dumps, decoder=json.loads, schema='pg_catalog')

# Hmm. Just called count in the qdrant version
async def count_items(self) -> int:
Expand All @@ -183,7 +131,7 @@ async def count_items(self) -> int:
Returns:
int: number of documents in the table
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
# Count the number of documents in the table
count = await conn.fetchval(f'SELECT COUNT(*) FROM {self.table_name}')
return count
Expand All @@ -196,7 +144,7 @@ async def table_exists(self) -> bool:
bool: True if the table exists, False otherwise
'''
# Check if the table exists
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
exists = await conn.fetchval(
CHECK_TABLE_EXISTS,
self.table_name
Expand All @@ -210,7 +158,7 @@ async def drop_table(self) -> None:
Exercise caution!
'''
# Delete the table
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
await conn.execute(f'DROP TABLE IF EXISTS {self.table_name};')


Expand Down
24 changes: 12 additions & 12 deletions pylib/embedding/pgvector_data_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@

QUERY_DATA_TABLE = QUERY_TABLE_BASE.format(extra_fields='')

TITLE_WHERE_CLAUSE = 'title = {query_title} -- Equals operator \n'
TITLE_WHERE_CLAUSE = 'title = {query_title} -- Equals operator\n'

PAGE_NUMBERS_WHERE_CLAUSE = 'page_numbers && {query_page_numbers} -- Overlap operator \n'

TAGS_WHERE_CLAUSE_CONJ = 'tags @> {tags} -- Contains operator \n'
TAGS_WHERE_CLAUSE_DISJ = 'tags && {tags} -- Overlap operator \n'
TAGS_WHERE_CLAUSE_CONJ = 'tags @> {tags} -- Contains operator\n'
TAGS_WHERE_CLAUSE_DISJ = 'tags && {tags} -- Overlap operator\n'

THRESHOLD_WHERE_CLAUSE = '{query_threshold} >= cosine_similarity \n'
THRESHOLD_WHERE_CLAUSE = 'cosine_similarity >= {query_threshold}\n'
# ------ SQL queries ---------------------------------------------------------------------------------------------------


Expand All @@ -91,7 +91,7 @@ async def create_table(self) -> None:
'''
Create the table to hold embedded documents
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CREATE_DATA_TABLE.format(
Expand All @@ -115,7 +115,7 @@ async def insert(
# Get the embedding of the content as a PGvector compatible list
content_embedding = self._embedding_model.encode(content)

async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
INSERT_DATA.format(table_name=self.table_name),
Expand All @@ -136,7 +136,7 @@ async def insert_many(
Args:
content_list: List of tuples, each of the form: (content, tags)
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.executemany(
INSERT_DATA.format(table_name=self.table_name),
Expand Down Expand Up @@ -224,7 +224,7 @@ async def search(
limit_clause = ''

# Execute the search via SQL
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
search_results = await conn.fetch(
QUERY_DATA_TABLE.format(
table_name=self.table_name,
Expand All @@ -242,7 +242,7 @@ async def create_table(self) -> None:
'''
Create the table to hold embedded documents
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CREATE_DOC_TABLE.format(
Expand Down Expand Up @@ -272,7 +272,7 @@ async def insert(
# Get the embedding of the content as a PGvector compatible list
content_embedding = self._embedding_model.encode(content)

async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
INSERT_DOCS.format(table_name=self.table_name),
Expand All @@ -295,7 +295,7 @@ async def insert_many(
Args:
content_list: List of tuples, each of the form: (content, tags, title, page_numbers)
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.executemany(
INSERT_DOCS.format(table_name=self.table_name),
Expand Down Expand Up @@ -391,7 +391,7 @@ async def search(
limit_clause = ''

# Execute the search via SQL
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
search_results = await conn.fetch(
QUERY_DOC_TABLE.format(
table_name=self.table_name,
Expand Down
12 changes: 6 additions & 6 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
class MessageDB(PGVectorHelper):
''' Specialize PGvectorHelper for messages, e.g. chatlogs '''
async def create_table(self):
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CREATE_MESSAGE_TABLE.format(
Expand Down Expand Up @@ -126,7 +126,7 @@ async def insert(
# Get the embedding of the content as a PGvector compatible list
content_embedding = self._embedding_model.encode(content)

async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
INSERT_MESSAGE.format(table_name=self.table_name),
Expand All @@ -150,7 +150,7 @@ async def insert_many(
Args:
content_list: List of tuples, each of the form: (history_key, role, text, timestamp, metadata)
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.executemany(
INSERT_MESSAGE.format(table_name=self.table_name),
Expand All @@ -170,7 +170,7 @@ async def clear(
Args:
history_key (str): history key (unique identifier) to match
'''
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CLEAR_MESSAGE.format(
Expand All @@ -197,7 +197,7 @@ async def get_messages(
generates asyncpg.Record instances of resulting messages
'''
qparams = [history_key]
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
if not isinstance(history_key, UUID):
history_key = UUID(history_key)
# msg = f'history_key must be a UUID, not {type(history_key)} ({history_key}))'
Expand Down Expand Up @@ -271,7 +271,7 @@ async def search(
query_embedding = list(self._embedding_model.encode(text))

# Search the table
async with (await self.connection_pool()).acquire() as conn:
async with self.pool.acquire() as conn:
records = await conn.fetch(
SEMANTIC_QUERY_MESSAGE_TABLE.format(
table_name=self.table_name
Expand Down

0 comments on commit 3fd0f71

Please sign in to comment.