Skip to content

Commit

Permalink
Add support for DB schema use for PGVector, for example, constructors…
Browse files Browse the repository at this point in the history
… now take a schema kwarg.
  • Loading branch information
uogbuji committed Sep 24, 2024
1 parent 28fa336 commit 9061398
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 24 deletions.
54 changes: 38 additions & 16 deletions pylib/embedding/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
# import asyncio
# from typing import ClassVar
from functools import partial

from ogbujipt.config import attr_dict

Expand All @@ -31,6 +32,8 @@
register_vector = object() # Set up a dummy to satisfy the type hints
POOL_TYPE = object

DEFAULT_SCHEMA = 'public' # Think 'pg_catalog' is out of date

# ------ SQL queries ---------------------------------------------------------------------------------------------------
# PG only supports proper query arguments (e.g. $1, $2, etc.) for values, not for table or column names
# Table names are checked to be legitimate SQL table names, and embed_dimension is assured to be an integer
Expand Down Expand Up @@ -59,7 +62,7 @@ class PGVectorHelper:
Connection and pool parameters:
* table_name: PostgreSQL table name. Checked to restrict to alphanumeric characters & underscore
* table_name: PostgreSQL table name. Checked to restrict to alphanumeric characters, underscore & period (for schema qualification)
* host: Hostname or IP address of the PostgreSQL server. Defaults to UNIX socket if not provided.
* port: Port number at which the PostgreSQL server is listening. Defaults to 5432 if not provided.
* user: User name used to authenticate.
Expand All @@ -68,7 +71,7 @@ class PGVectorHelper:
* pool_min: minimum number of connections to maintain in the pool (used as min_size for create_pool).
* pool_max: maximum number of connections to maintain in the pool (used as max_size for create_pool).
'''
def __init__(self, embedding_model, table_name: str, pool):
def __init__(self, embedding_model, table_name: str, pool, schema=None):
'''
If you don't already have a connection pool, construct using the PGvectorHelper.from_pool_params() method
Expand All @@ -79,17 +82,22 @@ def __init__(self, embedding_model, table_name: str, pool):
table_name: PostgresQL table. Checked to restrict to alphanumeric characters & underscore
pool: asyncpg connection pool instance (asyncpg.pool.Pool)
schema: a schema to which the vector extension has been set. In more sophisticated DB setups
using multiple schemata, you can run into `ERROR: type "vector" does not exist`
unless a schema with the extension is in the search path (via `SET SCHEMA`)
'''
if not PREREQS_AVAILABLE:
raise RuntimeError('pgvector not installed, you can run `pip install pgvector asyncpg`')

if not table_name.replace('_', '').isalnum():
msg = 'table_name must be alphanumeric, with underscore also allowed'
if not table_name.replace('_', '').replace('.', '').isalnum():
msg = f'table_name must be alphanumeric, with optional underscore or periods. Got: {table_name}'
raise ValueError(msg)

self.table_name = table_name
self.embedding_model = embedding_model
self.pool = pool
self.schema = schema

# Check if the provided embedding model is a SentenceTransformer
if (embedding_model.__class__.__name__ == 'SentenceTransformer') and (not None):
Expand All @@ -103,7 +111,7 @@ def __init__(self, embedding_model, table_name: str, pool):

@classmethod
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password,
        schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper':  # noqa: E501
    '''
    Create PGVectorHelper instance from connection/pool parameters

    Will create a connection pool, with JSON type handling initialized,
    and set that as a pool attribute on the created object as a user convenience.

    For details on accepted parameters, see the class docstring
    (e.g. run `help(PGVectorHelper)`)
    '''
    # Bind the target schema into the pool initializer; use cls (not PGVectorHelper)
    # so subclasses that override init_pool get their version called
    init_pool_ = partial(cls.init_pool, schema=schema)
    pool = await asyncpg.create_pool(init=init_pool_, host=host, port=port, user=user,
                                     password=password, database=db_name, min_size=pool_min, max_size=pool_max)

    new_obj = cls(embedding_model, table_name, pool, schema=schema)
    return new_obj

@classmethod
async def from_conn_string(cls, conn_string, embedding_model, table_name,
        schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper':  # noqa: E501
    '''
    Create PGVectorHelper instance from a connection string AKA DSN

    Will create a connection pool, with JSON type handling initialized,
    and set that as a pool attribute on the created object as a user convenience.
    '''
    # https://github.com/MagicStack/asyncpg/blob/0a322a2e4ca1c3c3cf6c2cf22b236a6da6c61680/asyncpg/pool.py#L339
    # Bind the target schema into the pool initializer; use cls (not PGVectorHelper)
    # so subclasses that override init_pool get their version called
    init_pool_ = partial(cls.init_pool, schema=schema)
    pool = await asyncpg.create_pool(
        conn_string, init=init_pool_, min_size=pool_min, max_size=pool_max)

    new_obj = cls(embedding_model, table_name, pool, schema=schema)
    return new_obj

@staticmethod
async def init_pool(conn, schema=None):
    '''
    Initialize vector extension and JSON codec for a connection from a pool

    Can be invoked from upstream if they're managing the connection pool themselves
    If they choose to have us create a connection pool (e.g. from_conn_params), it will use this

    Args:
        conn: asyncpg connection to initialize
        schema (str, optional): schema in which the vector extension lives; defaults to DEFAULT_SCHEMA
    '''
    schema = schema or DEFAULT_SCHEMA
    # TODO: Clean all this up
    await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
    try:
        await register_vector(conn, schema=schema)
    except ValueError as e:
        # Chain the original exception so the underlying cause isn't lost from the traceback
        raise RuntimeError('Unable to find the vector type in the database or schema. You might need to enable it.') from e
    try:
        await conn.set_type_codec(  # Register a codec for JSON
            'JSON', encoder=json.dumps, decoder=json.loads, schema=schema)
    except ValueError:
        # Fall back to the default schema lookup if the JSON type isn't in the given schema
        try:
            await conn.set_type_codec(  # Register a codec for JSON
                'JSON', encoder=json.dumps, decoder=json.loads)
        except ValueError:
            # Deliberate best-effort: without a codec, JSON columns round-trip as plain text
            pass

# Hmm. Just called count in the qdrant version
async def count_items(self) -> int:
Expand Down
4 changes: 3 additions & 1 deletion pylib/embedding/pgvector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Table names are checked to be legitimate SQL table names, and embed_dimension is assured to be an integer

CREATE_TABLE_BASE = '''-- Create a table to hold embedded documents or data
CREATE TABLE IF NOT EXISTS {{table_name}} (
{{set_schema}}CREATE TABLE IF NOT EXISTS {{table_name}} (
id BIGSERIAL PRIMARY KEY,
embedding VECTOR({{embed_dimension}}), -- embedding vectors (array dimension)
content TEXT NOT NULL, -- text content of the chunk
Expand Down Expand Up @@ -78,10 +78,12 @@ async def create_table(self) -> None:
'''
Create the table to hold embedded documents
'''
set_schema = f'SET SCHEMA \'{self.schema}\';\n' if self.schema else ''
async with self.pool.acquire() as conn:
async with conn.transaction():
await conn.execute(
CREATE_DATA_TABLE.format(
set_schema=set_schema,
table_name=self.table_name,
embed_dimension=self._embed_dimension)
)
Expand Down
12 changes: 6 additions & 6 deletions pylib/embedding/pgvector_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
# ------ Class implementations ---------------------------------------------------------------------------------------

class MessageDB(PGVectorHelper):
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0):
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0, schema=None):
'''
Helper class for messages/chatlog storage and retrieval
Expand All @@ -140,22 +140,22 @@ def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, wi
https://huggingface.co/sentence-transformers
window (int, optional): number of messages to maintain in the DB. Default is 0 (all messages)
'''
super().__init__(embedding_model, table_name, pool)
super().__init__(embedding_model, table_name, pool, schema=schema)
self.window = window

@classmethod
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0,
pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501
schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501
obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password,
pool_min, pool_max)
schema, pool_min, pool_max)
obj.window = window
return obj

@classmethod
async def from_conn_string(cls, conn_string, embedding_model, table_name, window=0,
pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501
schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501
obj = await super().from_conn_string(conn_string, embedding_model, table_name,
pool_min, pool_max)
schema, pool_min, pool_max)
obj.window = window
return obj

Expand Down
2 changes: 1 addition & 1 deletion pylib/text_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def text_split(text: str, chunk_size: int, separator: str='\n\n', joiner=None, l
# Split the text by the separator
if joiner is None:
separator = f'({separator})'
sep_pat = re.compile(separator)
sep_pat = re.compile(separator, flags=re.M)
raw_split = re.split(sep_pat, text)

# Rapid aid to understanding following logic:
Expand Down
46 changes: 46 additions & 0 deletions test/test_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,51 @@ def test_zero_overlap(LOREM_IPSUM):
assert len(chunk) <= 100


# TODO: Markup split based on below:

'''
# Hello
Goodbye
## World
Neighborhood
### Spam
Spam spam spam!
# Eggs
Green, with ham
'''

# Check the differences with e.g.
# list(text_split(s, chunk_size=50, separator=r'^(#)'))
# Where chunk_size varies from 5 to 100 & sep is also e.g. r'^#', r'^(#+)', etc.

# Notice,

# Based on https://regex101.com/r/cVCCSg/1 This wacky example:

'''
import re
text= """# Heading 1
## heading 2 (some text in parentheses)
###Heading 3
Don't match the following:
[Some internal link](
#foo)
[Some internal link](
#foo)
[Some internal link](
#foo
)"""
print( re.sub(r'(\[[^][]*]\([^()]*\))|^(#+)(.*)', lambda x: x.group(1) if x.group(1) else "<h{1}>{0}</h{1}>".format(x.group(3), len(x.group(2))), text, flags=re.M) )
'''

if __name__ == '__main__':
raise SystemExit("Attention! Run with pytest")

0 comments on commit 9061398

Please sign in to comment.