diff --git a/pylib/embedding/pgvector.py b/pylib/embedding/pgvector.py index 3b1625e..d8fdfd0 100644 --- a/pylib/embedding/pgvector.py +++ b/pylib/embedding/pgvector.py @@ -14,6 +14,7 @@ import json # import asyncio # from typing import ClassVar +from functools import partial from ogbujipt.config import attr_dict @@ -31,6 +32,8 @@ register_vector = object() # Set up a dummy to satisfy the type hints POOL_TYPE = object +DEFAULT_SCHEMA = 'public' # Think 'pg_catalog' is out of date + # ------ SQL queries --------------------------------------------------------------------------------------------------- # PG only supports proper query arguments (e.g. $1, $2, etc.) for values, not for table or column names # Table names are checked to be legit sequel table names, and embed_dimension is assured to be an integer @@ -59,7 +62,7 @@ class PGVectorHelper: Connection and pool parameters: - * table_name: PostgresQL table name. Checked to restrict to alphanumeric characters & underscore + * table_name: PostgresQL table name. Checked to restrict to alphanumeric characters, underscore & period (for schema qualification) * host: Hostname or IP address of the PostgreSQL server. Defaults to UNIX socket if not provided. * port: Port number at which the PostgreSQL server is listening. Defaults to 5432 if not provided. * user: User name used to authenticate. @@ -68,7 +71,7 @@ class PGVectorHelper: * pool_min: minimum number of connections to maintain in the pool (used as min_size for create_pool). * pool_max: maximum number of connections to maintain in the pool (used as max_size for create_pool). ''' - def __init__(self, embedding_model, table_name: str, pool): + def __init__(self, embedding_model, table_name: str, pool, schema=None): ''' If you don't already have a connection pool, construct using the PGvectorHelper.from_pool_params() method @@ -79,17 +82,22 @@ def __init__(self, embedding_model, table_name: str, pool): table_name: PostgresQL table. Checked to restrict to alphanumeric characters & underscore pool: asyncpg connection pool instance (asyncpg.pool.Pool) + + schema: a schema to which the vector extension has been set. In more sophisticated DB setups + using multiple schemata, you can run into `ERROR: type "vector" does not exist` + unless a schema with the extension is in the search path (via `SET SCHEMA`) ''' if not PREREQS_AVAILABLE: raise RuntimeError('pgvector not installed, you can run `pip install pgvector asyncpg`') - if not table_name.replace('_', '').isalnum(): - msg = 'table_name must be alphanumeric, with underscore also allowed' + if not table_name.replace('_', '').replace('.', '').isalnum(): + msg = f'table_name must be alphanumeric, with optional underscore or periods. Got: {table_name}' raise ValueError(msg) self.table_name = table_name self.embedding_model = embedding_model self.pool = pool + self.schema = schema # Check if the provided embedding model is a SentenceTransformer if (embedding_model.__class__.__name__ == 'SentenceTransformer') and (not None): @@ -103,7 +111,7 @@ def __init__(self, embedding_model, table_name: str, pool): @classmethod async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, - pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper': # noqa: E501 + schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper': # noqa: E501 ''' Create PGVectorHelper instance from connection/pool parameters @@ -113,15 +121,16 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name For details on accepted parameters, See the class docstring (e.g. run `help(PGVectorHelper)`) ''' - pool = await asyncpg.create_pool(init=PGVectorHelper.init_pool, host=host, port=port, user=user, - password=password, database=db_name, min_size=pool_min, max_size=pool_max) + init_pool_ = partial(PGVectorHelper.init_pool, schema=schema) + pool = await asyncpg.create_pool(init=init_pool_, host=host, port=port, user=user, + password=password, database=db_name, min_size=pool_min, max_size=pool_max) - new_obj = cls(embedding_model, table_name, pool) + new_obj = cls(embedding_model, table_name, pool, schema=schema) return new_obj @classmethod async def from_conn_string(cls, conn_string, embedding_model, table_name, - pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper': # noqa: E501 + schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'PGVectorHelper': # noqa: E501 ''' Create PGVectorHelper instance from a connection string AKA DSN @@ -129,14 +138,15 @@ async def from_conn_string(cls, conn_string, embedding_model, table_name, and set that as a pool attribute on the created object as a user convenience. ''' # https://github.com/MagicStack/asyncpg/blob/0a322a2e4ca1c3c3cf6c2cf22b236a6da6c61680/asyncpg/pool.py#L339 - pool = await asyncpg.create_pool(conn_string, init=PGVectorHelper.init_pool, - min_size=pool_min, max_size=pool_max) + init_pool_ = partial(PGVectorHelper.init_pool, schema=schema) + pool = await asyncpg.create_pool( + conn_string, init=init_pool_, min_size=pool_min, max_size=pool_max) - new_obj = cls(embedding_model, table_name, pool) + new_obj = cls(embedding_model, table_name, pool, schema=schema) return new_obj @staticmethod - async def init_pool(conn): + async def init_pool(conn, schema=None): ''' Initialize vector extension for a connection from a pool @@ -144,10 +154,22 @@ async def init_pool(conn): If they choose to have us create a connection pool (e.g. from_conn_params), it will use this ''' + schema = schema or DEFAULT_SCHEMA + # TODO: Clean all this up await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;') - await register_vector(conn) - await conn.set_type_codec( # Register a codec for JSON - 'JSON', encoder=json.dumps, decoder=json.loads, schema='pg_catalog') + try: + await register_vector(conn, schema=schema) + except ValueError: + raise RuntimeError('Unable to find the vector type in the database or schema. You might need to enable it.') + try: + await conn.set_type_codec( # Register a codec for JSON + 'JSON', encoder=json.dumps, decoder=json.loads, schema=schema) + except ValueError: + try: + await conn.set_type_codec( # Register a codec for JSON + 'JSON', encoder=json.dumps, decoder=json.loads) + except ValueError: + pass # Hmm. Just called count in the qdrant version async def count_items(self) -> int: diff --git a/pylib/embedding/pgvector_data.py b/pylib/embedding/pgvector_data.py index ea72023..1a80f85 100644 --- a/pylib/embedding/pgvector_data.py +++ b/pylib/embedding/pgvector_data.py @@ -17,7 +17,7 @@ # Table names are checked to be legit sequel table names, and embed_dimension is assured to be an integer CREATE_TABLE_BASE = '''-- Create a table to hold embedded documents or data -CREATE TABLE IF NOT EXISTS {{table_name}} ( +{{set_schema}}CREATE TABLE IF NOT EXISTS {{table_name}} ( id BIGSERIAL PRIMARY KEY, embedding VECTOR({{embed_dimension}}), -- embedding vectors (array dimension) content TEXT NOT NULL, -- text content of the chunk @@ -78,10 +78,12 @@ async def create_table(self) -> None: ''' Create the table to hold embedded documents ''' + set_schema = f'SET SCHEMA \'{self.schema}\';\n' if self.schema else '' async with self.pool.acquire() as conn: async with conn.transaction(): await conn.execute( CREATE_DATA_TABLE.format( + set_schema=set_schema, table_name=self.table_name, embed_dimension=self._embed_dimension) ) diff --git a/pylib/embedding/pgvector_message.py b/pylib/embedding/pgvector_message.py index dd8f7b7..08e7a1f 100644 --- a/pylib/embedding/pgvector_message.py +++ b/pylib/embedding/pgvector_message.py @@ -131,7 +131,7 @@ # ------ Class implementations --------------------------------------------------------------------------------------- class MessageDB(PGVectorHelper): - def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0): + def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, window=0, schema=None): ''' Helper class for messages/chatlog storage and retrieval @@ -140,22 +140,22 @@ def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, wi https://huggingface.co/sentence-transformers window (int, optional): number of messages to maintain in the DB. Default is 0 (all messages) ''' - super().__init__(embedding_model, table_name, pool) + super().__init__(embedding_model, table_name, pool, schema=schema) self.window = window @classmethod async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0, - pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501 + schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501 obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password, - pool_min, pool_max) + schema, pool_min, pool_max) obj.window = window return obj @classmethod async def from_conn_string(cls, conn_string, embedding_model, table_name, window=0, - pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501 + schema=None, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE) -> 'MessageDB': # noqa: E501 obj = await super().from_conn_string(conn_string, embedding_model, table_name, - pool_min, pool_max) + schema, pool_min, pool_max) obj.window = window return obj diff --git a/pylib/text_helper.py b/pylib/text_helper.py index 5c3c52f..22797d5 100644 --- a/pylib/text_helper.py +++ b/pylib/text_helper.py @@ -57,7 +57,7 @@ def text_split(text: str, chunk_size: int, separator: str='\n\n', joiner=None, l # Split the text by the separator if joiner is None: separator = f'({separator})' - sep_pat = re.compile(separator) + sep_pat = re.compile(separator, flags=re.M) raw_split = re.split(sep_pat, text) # Rapid aid to understanding following logic: diff --git a/test/test_text_splitter.py b/test/test_text_splitter.py index 0cd2071..89b412c 100644 --- a/test/test_text_splitter.py +++ b/test/test_text_splitter.py @@ -70,5 +70,51 @@ def test_zero_overlap(LOREM_IPSUM): assert len(chunk) <= 100 +# TODO: Markup split based on below: + +''' +# Hello + +Goodbye + +## World + +Neighborhood + +### Spam + +Spam spam spam! + +# Eggs + +Green, with ham +''' + +# Check the differences with e.g. +# list(text_split(s, chunk_size=50, separator=r'^(#)')) +# Where chunk_size varies from 5 to 100 & sep is also e.g. r'^#', r'^(#+)', etc. + +# Notice, + +# Based on https://regex101.com/r/cVCCSg/1 This wacky example: + +''' +import re +text= """# Heading 1 +## heading 2 (some text in parentheses) +###Heading 3 + +Don't match the following: + +[Some internal link]( +#foo) +[Some internal link]( +#foo) +[Some internal link]( +#foo +)""" +print( re.sub(r'(\[[^][]*]\([^()]*\))|^(#+)(.*)', lambda x: x.group(1) if x.group(1) else "{0}".format(x.group(3), len(x.group(2))), text, flags=re.M) ) +''' + if __name__ == '__main__': raise SystemExit("Attention! Run with pytest")