diff --git a/pylib/embedding/pgvector.py b/pylib/embedding/pgvector.py index dba1e09..ee9ef63 100644 --- a/pylib/embedding/pgvector.py +++ b/pylib/embedding/pgvector.py @@ -26,7 +26,7 @@ register_vector = object() # Set up a dummy to satisfy the type hints POOL_TYPE = object -# See: +# See: https://github.com/OoriData/OgbujiPT/issues/87 DEFAULT_USER_SCHEMA = 'public' DEFAULT_SYSTEM_SCHEMA = 'pg_catalog' @@ -84,7 +84,7 @@ def __init__(self, embedding_model, table_name: str, pool, sys_schema=DEFAULT_SY using multiple schemata, you can run into `ERROR: type "vector" does not exist` unless a schema with the extension is in the search path (via `SET SCHEMA`) - half_precision (bool) - if True, use halfvec type to store half-precision vectors (pgvector 0.7.0 & up only). + half_precision (bool) - if True, use halfvec type to store half-precision vectors (pgvector 0.7.0+ only). Default is False (full precision) itypes (list ) - Index types (or empty for no indexing). @@ -138,8 +138,9 @@ def __init__(self, embedding_model, table_name: str, pool, sys_schema=DEFAULT_SY @classmethod async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, - sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, - half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper': # noqa: E501 + sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, + pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False, itypes=None, ifuncs=None, + i_max_conn=16, ef_construction=64) -> 'PGVectorHelper': # noqa: E501 ''' Create PGVectorHelper instance from connection/pool parameters @@ -165,8 +166,9 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name @classmethod async def from_conn_string(cls, conn_string, embedding_model, table_name, - sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, - half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper': # noqa: E501 + sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, + pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False, + itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper': # noqa: E501 ''' Create PGVectorHelper instance from a connection string AKA DSN @@ -201,7 +203,7 @@ async def init_pool(conn, schema=DEFAULT_SYSTEM_SCHEMA): try: await register_vector(conn, schema=schema) except ValueError as e: - raise RuntimeError(f'Unable to find the vector type in the database or schema. You might need to enable it: {e}') + raise RuntimeError(f'Unable to find vector type in the DB/schema. You might need to enable it: {e}') try: await conn.set_type_codec( # Register a codec for JSON 'jsonb', encoder=json.dumps, decoder=json.loads, schema=schema) diff --git a/pylib/embedding/pgvector_data.py b/pylib/embedding/pgvector_data.py index 62d588b..edc8c68 100644 --- a/pylib/embedding/pgvector_data.py +++ b/pylib/embedding/pgvector_data.py @@ -7,8 +7,7 @@ from typing import Iterable, Callable, List, Sequence -from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response, - DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA, DEFAULT_USER_SCHEMA) +from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response) __all__ = ['DataDB'] diff --git a/pylib/embedding/pgvector_message.py b/pylib/embedding/pgvector_message.py index 1893e47..12e9d85 100644 --- a/pylib/embedding/pgvector_message.py +++ b/pylib/embedding/pgvector_message.py @@ -12,7 +12,7 @@ from ogbujipt.config import attr_dict from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response, - DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA, DEFAULT_USER_SCHEMA) + DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA) __all__ = ['MessageDB'] @@ -146,8 +146,9 @@ def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, wi @classmethod async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0, - sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, - half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB': # noqa: E501 + sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, + pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False, + itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB': # noqa: E501 obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password, sys_schema, pool_min, pool_max) obj.window = window @@ -155,8 +156,9 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name @classmethod async def from_conn_string(cls, conn_string, embedding_model, table_name, window=0, - sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, - half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB': # noqa: E501 + sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, + pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False, + itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB': # noqa: E501 obj = await super().from_conn_string(conn_string, embedding_model, table_name, sys_schema, pool_min, pool_max) obj.window = window diff --git a/test/embedding/test_pgvector_data.py b/test/embedding/test_pgvector_data.py index 9d2e331..ffb3ae9 100644 --- a/test/embedding/test_pgvector_data.py +++ b/test/embedding/test_pgvector_data.py @@ -207,7 +207,8 @@ async def test_data_vector_half_index_half(DB_HALF_INDEX_HALF): metadata=meta, # Tag metadata ) - assert await DB_HALF_INDEX_HALF.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion') + assert await DB_HALF_INDEX_HALF.count_items() == len(KG_STATEMENTS), \ + Exception('Incorrect number of documents after insertion') # search table with perfect match result = await DB_HALF_INDEX_HALF.search(text=item1_text, limit=3)