Commit

Linty things
uogbuji committed Oct 18, 2024
1 parent 59e96e7 commit 3199ddd
Showing 4 changed files with 19 additions and 15 deletions.
16 changes: 9 additions & 7 deletions pylib/embedding/pgvector.py
@@ -26,7 +26,7 @@
    register_vector = object()  # Set up a dummy to satisfy the type hints
    POOL_TYPE = object

-# See:
+# See: https://github.com/OoriData/OgbujiPT/issues/87
DEFAULT_USER_SCHEMA = 'public'
DEFAULT_SYSTEM_SCHEMA = 'pg_catalog'

@@ -84,7 +84,7 @@ def __init__(self, embedding_model, table_name: str, pool, sys_schema=DEFAULT_SY
            using multiple schemata, you can run into `ERROR: type "vector" does not exist`
            unless a schema with the extension is in the search path (via `SET SCHEMA`)
-        half_precision (bool) - if True, use halfvec type to store half-precision vectors (pgvector 0.7.0 & up only).
+        half_precision (bool) - if True, use halfvec type to store half-precision vectors (pgvector 0.7.0+ only).
            Default is False (full precision)
        itypes (list <str>) - Index types (or empty for no indexing).
@@ -138,8 +138,9 @@ def __init__(self, embedding_model, table_name: str, pool, sys_schema=DEFAULT_SY

    @classmethod
    async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password,
-            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE,
-            half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper':  # noqa: E501
+            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE,
+            pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False, itypes=None, ifuncs=None,
+            i_max_conn=16, ef_construction=64) -> 'PGVectorHelper':  # noqa: E501
        '''
        Create PGVectorHelper instance from connection/pool parameters
@@ -165,8 +166,9 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name

    @classmethod
    async def from_conn_string(cls, conn_string, embedding_model, table_name,
-            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE,
-            half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper':  # noqa: E501
+            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE,
+            pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False,
+            itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'PGVectorHelper':  # noqa: E501
        '''
        Create PGVectorHelper instance from a connection string AKA DSN
@@ -201,7 +203,7 @@ async def init_pool(conn, schema=DEFAULT_SYSTEM_SCHEMA):
        try:
            await register_vector(conn, schema=schema)
        except ValueError as e:
-            raise RuntimeError(f'Unable to find the vector type in the database or schema. You might need to enable it: {e}')
+            raise RuntimeError(f'Unable to find vector type in the DB/schema. You might need to enable it: {e}')
        try:
            await conn.set_type_codec(  # Register a codec for JSON
                'jsonb', encoder=json.dumps, decoder=json.loads, schema=schema)
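
The RuntimeError above fires when register_vector cannot find the vector type in the target database/schema. A minimal sketch of enabling the extension up front, assuming a plain asyncpg connection with sufficient privileges (the DSN and the enable_pgvector helper name are illustrative, not part of this commit):

import asyncpg


async def enable_pgvector(dsn, schema='public'):
    # Install the pgvector extension if it is missing, so the vector type
    # becomes visible in the chosen schema before init_pool runs
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute(f'CREATE EXTENSION IF NOT EXISTS vector SCHEMA {schema}')
    finally:
        await conn.close()

# e.g. asyncio.run(enable_pgvector('postgresql://user:pass@localhost:5432/demo'))
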
3 changes: 1 addition & 2 deletions pylib/embedding/pgvector_data.py
@@ -7,8 +7,7 @@

from typing import Iterable, Callable, List, Sequence

-from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response,
-    DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA, DEFAULT_USER_SCHEMA)
+from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response)

__all__ = ['DataDB']

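
The import above now pulls in only the names pgvector_data.py actually uses. For orientation, a sketch of constructing the DataDB helper via the from_conn_string classmethod shown earlier, assuming DataDB inherits it from PGVectorHelper and that a SentenceTransformer works as the embedding_model (both are assumptions; the DSN and table name are placeholders):

from sentence_transformers import SentenceTransformer  # assumed embedding backend

from ogbujipt.embedding.pgvector_data import DataDB


async def connect_data_db():
    emb = SentenceTransformer('all-MiniLM-L6-v2')  # illustrative model choice
    # Placeholder DSN; keyword arguments follow the signature in the diff above
    return await DataDB.from_conn_string(
        'postgresql://user:pass@localhost:5432/demo',
        embedding_model=emb,
        table_name='demo_data')
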
12 changes: 7 additions & 5 deletions pylib/embedding/pgvector_message.py
@@ -12,7 +12,7 @@

from ogbujipt.config import attr_dict
from ogbujipt.embedding.pgvector import (PGVectorHelper, asyncpg, process_search_response,
-    DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA, DEFAULT_USER_SCHEMA)
+    DEFAULT_MIN_CONNECTION_POOL_SIZE, DEFAULT_MAX_CONNECTION_POOL_SIZE, DEFAULT_SYSTEM_SCHEMA)

__all__ = ['MessageDB']

@@ -146,17 +146,19 @@ def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool, wi

    @classmethod
    async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password, window=0,
-            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE,
-            half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB':  # noqa: E501
+            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE,
+            pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False,
+            itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB':  # noqa: E501
        obj = await super().from_conn_params(embedding_model, table_name, host, port, db_name, user, password,
            sys_schema, pool_min, pool_max)
        obj.window = window
        return obj

    @classmethod
    async def from_conn_string(cls, conn_string, embedding_model, table_name, window=0,
-            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE, pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE,
-            half_precision=False, itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB':  # noqa: E501
+            sys_schema=DEFAULT_SYSTEM_SCHEMA, pool_min=DEFAULT_MIN_CONNECTION_POOL_SIZE,
+            pool_max=DEFAULT_MAX_CONNECTION_POOL_SIZE, half_precision=False,
+            itypes=None, ifuncs=None, i_max_conn=16, ef_construction=64) -> 'MessageDB':  # noqa: E501
        obj = await super().from_conn_string(conn_string, embedding_model, table_name,
            sys_schema, pool_min, pool_max)
        obj.window = window
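
For context on the window keyword threaded through both constructors: a sketch of creating a MessageDB from discrete connection parameters, following the signature shown above. The credentials are placeholders, and reading window as a cap on how much chat history is kept or retrieved is an assumption, not something this diff states:

from sentence_transformers import SentenceTransformer  # assumed embedding backend

from ogbujipt.embedding.pgvector_message import MessageDB


async def connect_message_db():
    emb = SentenceTransformer('all-MiniLM-L6-v2')  # illustrative model choice
    return await MessageDB.from_conn_params(
        embedding_model=emb,
        table_name='chat_history',
        host='localhost',
        port=5432,
        db_name='demo',
        user='demo_user',
        password='demo_pass',  # placeholder credentials
        window=20)  # assumption: limits how much history is considered
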
3 changes: 2 additions & 1 deletion test/embedding/test_pgvector_data.py
@@ -207,7 +207,8 @@ async def test_data_vector_half_index_half(DB_HALF_INDEX_HALF):
            metadata=meta,  # Tag metadata
        )

-    assert await DB_HALF_INDEX_HALF.count_items() == len(KG_STATEMENTS), Exception('Incorrect number of documents after insertion')
+    assert await DB_HALF_INDEX_HALF.count_items() == len(KG_STATEMENTS), \
+        Exception('Incorrect number of documents after insertion')

    # search table with perfect match
    result = await DB_HALF_INDEX_HALF.search(text=item1_text, limit=3)
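
The split assertion above exists purely to satisfy the line-length rule (flake8 E501) that the noqa comments elsewhere in this commit reference. A self-contained sketch of the two common ways to wrap such an assert, using generic names rather than the test fixtures:

def check_count(actual, expected):
    # Backslash continuation, as in the diff above
    assert actual == expected, \
        Exception('Incorrect number of documents after insertion')
    # Parenthesized message, which avoids the backslash entirely
    assert actual == expected, (
        'Incorrect number of documents after insertion')


check_count(3, 3)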
