diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0cb55d2..c72183c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -39,7 +39,7 @@ jobs:
         # ensure this postgres version stays in sync with the prod database
         # and with the postgres version used in docker compose
         # Testing with postgres that has the pgvector extension
-        image: ankane/pgvector
+        image: pgvector/pgvector:pg16
         env:
           # optional (defaults to `postgres`)
           POSTGRES_DB: langchain_test
@@ -89,7 +89,7 @@ jobs:
         shell: bash
         run: |
           echo "Running tests, installing dependencies with poetry..."
-          poetry install --with test,lint,typing,docs
+          poetry install --with test,lint,typing
       - name: Run tests
         run: make test
         env:
diff --git a/.tool-versions b/.tool-versions
new file mode 100644
index 0000000..87803d9
--- /dev/null
+++ b/.tool-versions
@@ -0,0 +1 @@
+python 3.9.21
diff --git a/README.md b/README.md
index 06c0a0f..2905fd1 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,186 @@ pip install -U langchain-postgres
 
 ## Usage
 
+### HNSW index
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.hnsw,
+    embedding_index_ops="vector_cosine_ops",
+)
+```
+
+- Embedding length is required for the HNSW index.
+- Allowed values for `embedding_index_ops` are described in the [pgvector HNSW documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw).
+
+You can also set the `ef_construction` and `m` parameters for the HNSW index.
+Refer to the [pgvector HNSW Index Options](https://github.com/pgvector/pgvector?tab=readme-ov-file#index-options) for details on these parameters.
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.hnsw,
+    embedding_index_ops="vector_cosine_ops",
+    ef_construction=200,
+    m=48,
+)
+```
+
+### IVFFlat index
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.ivfflat,
+    embedding_index_ops="vector_cosine_ops",
+)
+```
+
+- Embedding length is required for the IVFFlat index.
+- Allowed values for `embedding_index_ops` are described in the [pgvector IVFFlat documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#ivfflat).
+
+### Binary Quantization
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.hnsw,
+    embedding_index_ops="bit_hamming_ops",
+    binary_quantization=True,
+    binary_limit=200,
+)
+```
+
+- Works only with an HNSW index using `bit_hamming_ops`.
+- `binary_limit` raises the limit of the inner binary-quantized search. A higher value increases recall at the cost of speed.
+
+Refer to the [pgvector Binary Quantization](https://github.com/pgvector/pgvector?tab=readme-ov-file#binary-quantization) documentation for more details.
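+
+Under the hood this runs a two-stage query: an inner scan over the binary-quantized
+vectors followed by a rerank of the candidates on the full vectors. Roughly the shape
+of the generated SQL (a sketch with placeholder vectors and the values from the
+example above, not the verbatim SQLAlchemy output):
+
+```sql
+SELECT *, embedding <=> '[...]' AS distance              -- rerank on full vectors
+FROM (
+    SELECT * FROM langchain_pg_embedding
+    ORDER BY binary_quantize(embedding)::bit(1536)
+             <~> binary_quantize('[...]'::vector(1536))  -- Hamming distance
+    LIMIT 200                                            -- binary_limit
+) AS binary_result
+ORDER BY distance
+LIMIT 4;                                                 -- k
+```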
+
+### Partitioning
+
+```python
+from langchain_postgres import PGVector
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    enable_partitioning=True,
+)
+```
+
+- Creates partitions of the `langchain_pg_embedding` table by `collection_id`. Useful with a large number of embeddings spread across different collections.
+
+Refer to the [pgvector Partitioning](https://github.com/pgvector/pgvector?tab=readme-ov-file#filtering) section for more information.
+
+### Iterative Scan
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType, IterativeScan
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.hnsw,
+    embedding_index_ops="vector_cosine_ops",
+    iterative_scan=IterativeScan.relaxed_order,
+)
+```
+
+- `iterative_scan` can be set to `IterativeScan.relaxed_order`, `IterativeScan.strict_order`, or disabled with `IterativeScan.off`.
+- Requires an HNSW or IVFFlat index.
+
+Refer to the [pgvector Iterative Scan](https://github.com/pgvector/pgvector?tab=readme-ov-file#iterative-index-scans) documentation for more details.
+
+### Iterative Scan Options for the HNSW index
+
+```python
+from langchain_postgres import PGVector, EmbeddingIndexType, IterativeScan
+
+PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+    embedding_length=1536,
+    embedding_index=EmbeddingIndexType.hnsw,
+    embedding_index_ops="vector_cosine_ops",
+    iterative_scan=IterativeScan.relaxed_order,
+    max_scan_tuples=40000,
+    scan_mem_multiplier=2,
+)
+```
+
+- `max_scan_tuples` controls when the scan ends while `iterative_scan` is enabled.
+- `scan_mem_multiplier` specifies the maximum amount of memory to use for the scan.
+
+Refer to the [pgvector Iterative Scan Options](https://github.com/pgvector/pgvector?tab=readme-ov-file#iterative-scan-options) documentation for more details.
+
+### Full Text Search
+
+Full text search can be used by specifying the `full_text_search` parameter.
+
+```python
+from langchain_postgres import PGVector
+
+vectorstore = PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+)
+
+vectorstore.similarity_search(
+    "hello world",
+    full_text_search=["foo", "bar & baz"]
+)
+```
+
+This adds the following statement to the `WHERE` clause:
+
+```sql
+AND document_vector @@ to_tsquery('foo | bar & baz')
+```
+
+It can also be used with retrievers:
+
+```python
+from langchain_postgres import PGVector
+
+vectorstore = PGVector(
+    collection_name="test_collection",
+    embeddings=FakeEmbedding(),
+    connection=CONNECTION_STRING,
+)
+
+retriever = vectorstore.as_retriever(
+    search_kwargs={
+        "full_text_search": ["foo", "bar & baz"]
+    }
+)
+```
+
+Refer to Postgres [Full Text Search](https://www.postgresql.org/docs/current/textsearch.html) for more information.
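+
+The full text filter composes with the regular metadata `filter`; both are AND-ed
+into the same `WHERE` clause. A sketch (the `topic` metadata key is hypothetical):
+
+```python
+vectorstore.similarity_search(
+    "hello world",
+    k=4,
+    filter={"topic": {"$eq": "greetings"}},  # JSONB metadata filter
+    full_text_search=["foo", "bar & baz"],   # terms are OR-ed into one tsquery
+)
+```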
+
 ### ChatMessageHistory
 
 The chat message history abstraction helps to persist
 chat message history
diff --git a/docker-compose.yml b/docker-compose.yml
index 6eabf2a..b5c6281 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -25,7 +25,7 @@ services:
       - postgres_data:/var/lib/postgresql/data
   pgvector:
     # postgres with the pgvector extension
-    image: ankane/pgvector
+    image: pgvector/pgvector:pg16
     environment:
       POSTGRES_DB: langchain
       POSTGRES_USER: langchain
diff --git a/langchain_postgres/__init__.py b/langchain_postgres/__init__.py
index 241c8b8..fa2dc14 100644
--- a/langchain_postgres/__init__.py
+++ b/langchain_postgres/__init__.py
@@ -2,7 +2,7 @@
 
 from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
 from langchain_postgres.translator import PGVectorTranslator
-from langchain_postgres.vectorstores import PGVector
+from langchain_postgres.vectorstores import EmbeddingIndexType, IterativeScan, PGVector
 
 try:
     __version__ = metadata.version(__package__)
@@ -15,4 +15,6 @@
     "PostgresChatMessageHistory",
     "PGVector",
     "PGVectorTranslator",
+    "EmbeddingIndexType",
+    "IterativeScan",
 ]
diff --git a/langchain_postgres/_utils.py b/langchain_postgres/_utils.py
index 14eae3b..933e3ee 100644
--- a/langchain_postgres/_utils.py
+++ b/langchain_postgres/_utils.py
@@ -78,3 +78,8 @@ def maximal_marginal_relevance(
         idxs.append(idx_to_add)
         selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
     return idxs
+
+
+def chunkerize(lst: list, n: int):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
\ No newline at end of file
diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py
index e1630a1..5afe6f9 100644
--- a/langchain_postgres/vectorstores.py
+++ b/langchain_postgres/vectorstores.py
@@ -29,9 +29,19 @@
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import get_from_dict_or_env
 from langchain_core.vectorstores import VectorStore
-from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select
-from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
+from pgvector.sqlalchemy import BIT, VECTOR  # type: ignore
+from sqlalchemy import (
+    SQLColumnExpression,
+    cast,
+    create_engine,
+    delete,
+    func,
+    select,
+    text,
+)
+from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, TSVECTOR, UUID, insert
 from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.ext.asyncio import (
     AsyncEngine,
     AsyncSession,
@@ -39,14 +49,17 @@
     create_async_engine,
 )
 from sqlalchemy.orm import (
+    Bundle,
     Session,
+    aliased,
     declarative_base,
     relationship,
     scoped_session,
     sessionmaker,
 )
+from sqlalchemy.sql.ddl import CreateIndex, CreateTable
 
-from langchain_postgres._utils import maximal_marginal_relevance
+from langchain_postgres._utils import chunkerize, maximal_marginal_relevance
 
 
 class DistanceStrategy(str, enum.Enum):
@@ -55,6 +68,8 @@ class DistanceStrategy(str, enum.Enum):
     EUCLIDEAN = "l2"
     COSINE = "cosine"
     MAX_INNER_PRODUCT = "inner"
+    HAMMING = "hamming"
+    JACCARD = "jaccard"
 
 
 DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
@@ -98,7 +113,16 @@ class DistanceStrategy(str, enum.Enum):
 )
 
 
-def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
+def _get_embedding_collection_store(
+    vector_dimension: Optional[int] = None,
+    embedding_index: Optional[EmbeddingIndexType] = None,
+    embedding_index_ops: Optional[str] = None,
+    ef_construction: Optional[int] = None,
+    m: Optional[int] = None,
+    binary_quantization: Optional[bool] = None,
+    embedding_length: Optional[int] = None,
+    partition: Optional[bool] = None,
+) -> Any:
     global _classes
     if _classes is not None:
         return _classes
@@ -154,6 +178,7 @@ def get_or_create(
             session: Session,
             name: str,
             cmetadata: Optional[dict] = None,
+            partition: Optional[bool] = None,
         ) -> Tuple["CollectionStore", bool]:
             """Get or create a collection.
             Returns:
             """  # noqa: E501
             created = False
             collection = cls.get_by_name(session, name)
+            if collection:
+                if partition:
+                    cls._ensure_partition_exists(session, collection)
+                return collection, created
             collection = cls(name=name, cmetadata=cmetadata)
             session.add(collection)
+            session.flush()
+            session.refresh(collection)
+
+            if partition:
+                ddl = cls._create_partition_ddl(str(collection.uuid))
+                session.execute(text(ddl))
+                session.commit()
+
             created = True
             return collection, created
@@ -176,6 +213,7 @@ async def aget_or_create(
             session: AsyncSession,
             name: str,
             cmetadata: Optional[dict] = None,
+            partition: Optional[bool] = None,
         ) -> Tuple["CollectionStore", bool]:
             """
             Get or create a collection.
@@ -183,23 +221,100 @@
             """  # noqa: E501
             created = False
             collection = await cls.aget_by_name(session, name)
+            if collection:
+                if partition:
+                    await cls._aensure_partition_exists(session, collection)
+                return collection, created
             collection = cls(name=name, cmetadata=cmetadata)
             session.add(collection)
+            await session.flush()
+            await session.refresh(collection)
+
+            if partition:
+                ddl = cls._create_partition_ddl(str(collection.uuid))
+                await session.execute(text(ddl))
+                await session.commit()
+
             created = True
             return collection, created
 
+        @classmethod
+        def _ensure_partition_exists(cls, session: Session, collection: CollectionStore) -> None:
+            try:
+                ddl = cls._create_partition_ddl(str(collection.uuid))
+                session.execute(text(ddl))
+                session.commit()
+            except ProgrammingError as e:
+                if "already exists" not in str(e):
+                    raise e
+
+        @classmethod
+        async def _aensure_partition_exists(cls, session: AsyncSession, collection: CollectionStore) -> None:
+            try:
+                ddl = cls._create_partition_ddl(str(collection.uuid))
+                await session.execute(text(ddl))
+                await session.commit()
+            except ProgrammingError as e:
+                if "already exists" not in str(e):
+                    raise e
+
+        @classmethod
+        def _create_partition_ddl(cls, uuid: str) -> str:
+            return f"""
+                CREATE TABLE {EmbeddingStore.__tablename__}_{uuid.replace('-', '_')}
+                PARTITION OF {EmbeddingStore.__tablename__} FOR VALUES IN ('{uuid}')
+            """
+
+    def _create_optional_index_params() -> dict:
+        if not (m or ef_construction):
+            return {}
+
+        return {
+            "postgresql_with": {
+                k: v
+                for k, v in {
+                    "m": m,
+                    "ef_construction": ef_construction,
+                }.items()
+                if v is not None
+            }
+        }
+
+    def _create_index(embedding: sqlalchemy.Column) -> Optional[sqlalchemy.Index]:
+        if embedding_index is None:
+            return None
+
+        optional_index_params = _create_optional_index_params()
+
+        if binary_quantization:
+            return sqlalchemy.Index(
+                f"ix_embedding_{embedding_index.value}",
+                cast(func.binary_quantize(embedding), BIT(embedding_length)).label(
+                    "embedding"
+                ),
+                postgresql_using=embedding_index.value,
+                postgresql_ops={"embedding": embedding_index_ops},
+                **optional_index_params,
+            )
+
+        return sqlalchemy.Index(
+            f"ix_embedding_{embedding_index.value}",
+            embedding,
+            postgresql_using=embedding_index.value,
+            postgresql_ops={"embedding": embedding_index_ops},
+            **optional_index_params,
+        )
+
     class EmbeddingStore(Base):
"""Embedding store.""" __tablename__ = "langchain_pg_embedding" - id = sqlalchemy.Column( - sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True - ) + id = sqlalchemy.Column(sqlalchemy.String, nullable=True, primary_key=True) collection_id = sqlalchemy.Column( UUID(as_uuid=True), @@ -207,12 +322,20 @@ class EmbeddingStore(Base): f"{CollectionStore.__tablename__}.uuid", ondelete="CASCADE", ), + primary_key=True if partition else False, ) collection = relationship(CollectionStore, back_populates="embeddings") embedding: Vector = sqlalchemy.Column(Vector(vector_dimension)) document = sqlalchemy.Column(sqlalchemy.String, nullable=True) cmetadata = sqlalchemy.Column(JSONB, nullable=True) + document_vector = sqlalchemy.Column( + TSVECTOR, + sqlalchemy.Computed( + "to_tsvector('english', document)", + persisted=True, + ), + ) __table_args__ = ( sqlalchemy.Index( @@ -221,6 +344,12 @@ class EmbeddingStore(Base): postgresql_using="gin", postgresql_ops={"cmetadata": "jsonb_path_ops"}, ), + sqlalchemy.Index( + "ix_document_vector_gin", + "document_vector", + postgresql_using="gin", + ), + _create_index(embedding), ) _classes = (EmbeddingStore, CollectionStore) @@ -245,6 +374,17 @@ def _create_vector_extension(conn: Connection) -> None: DBConnection = Union[sqlalchemy.engine.Engine, str] +class EmbeddingIndexType(enum.Enum): + hnsw = "hnsw" + ivfflat = "ivfflat" + + +class IterativeScan(enum.Enum): + off = "off" + strict_order = "strict_order" + relaxed_order = "relaxed_order" + + class PGVector(VectorStore): """Postgres vector store integration. @@ -377,6 +517,8 @@ def __init__( *, connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, + embedding_index: Optional[EmbeddingIndexType] = None, + embedding_index_ops: Optional[str] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, @@ -387,6 +529,14 @@ def __init__( use_jsonb: bool = True, create_extension: bool = True, async_mode: bool = False, + iterative_scan: Optional[IterativeScan] = None, + max_scan_tuples: Optional[int] = None, + scan_mem_multiplier: Optional[int] = None, + ef_construction: Optional[int] = None, + m: Optional[int] = None, + binary_quantization: Optional[bool] = None, + binary_limit: Optional[int] = None, + enable_partitioning: Optional[bool] = None, ) -> None: """Initialize the PGVector store. For an async version, use `PGVector.acreate()` instead. @@ -399,6 +549,10 @@ def __init__( NOTE: This is not mandatory. Defining it will prevent vectors of any other size to be added to the embeddings table but, without it, the embeddings can't be indexed. + embedding_index: The type of index to use for the embeddings. + (default: None) + embedding_index_ops: The index operator class to use for the index. + (default: None) collection_name: The name of the collection to use. (default: langchain) NOTE: This is not the name of the table, but the name of the collection. The tables will be created when initializing the store (if not exists) @@ -415,10 +569,23 @@ def __init__( create_extension: If True, will create the vector extension if it doesn't exist. disabling creation is useful when using ReadOnly Databases. + iterative_scan: Enables iterative scan. Required index + max_scan_tuples: Maximum number of tuples to scan. Required HNSW index + scan_mem_multiplier: Max memory to use during scan. 
+                Requires an HNSW index.
+            ef_construction: HNSW index parameter.
+            m: HNSW index parameter.
+            binary_quantization: Enable binary quantization. Requires an HNSW index
+                with bit_hamming_ops.
+            binary_limit: Inner limit for binary quantization. Mandatory when using
+                binary_quantization.
+            enable_partitioning: Enables partitioning of
+                langchain_pg_embedding by collection_id.
+
         """
         self.async_mode = async_mode
         self.embedding_function = embeddings
         self._embedding_length = embedding_length
+        self._embedding_index = embedding_index
+        self._embedding_index_ops = embedding_index_ops
         self.collection_name = collection_name
         self.collection_metadata = collection_metadata
         self._distance_strategy = distance_strategy
@@ -428,6 +595,24 @@ def __init__(
         self._engine: Optional[Engine] = None
         self._async_engine: Optional[AsyncEngine] = None
         self._async_init = False
+        self._iterative_scan = iterative_scan
+        self._max_scan_tuples = max_scan_tuples
+        self._scan_mem_multiplier = scan_mem_multiplier
+        self._ef_construction = ef_construction
+        self._m = m
+        self._binary_quantization = binary_quantization
+        self._binary_limit = binary_limit
+        self._enable_partitioning = enable_partitioning
+
+        if self._embedding_length is None and self._embedding_index is not None:
+            raise ValueError(
+                "embedding_length must be provided when using embedding_index"
+            )
+
+        if self._embedding_index is not None and self._embedding_index_ops is None:
+            raise ValueError(
+                "embedding_index_ops must be provided when using embedding_index"
+            )
 
         if isinstance(connection, str):
             if async_mode:
@@ -459,6 +644,57 @@ def __init__(
         if not use_jsonb:
             # Replace with a deprecation warning.
             raise NotImplementedError("use_jsonb=False is no longer supported.")
+
+        if self._embedding_index is None and self._iterative_scan is not None:
+            raise ValueError("iterative_scan is not supported without embedding_index")
+
+        if (
+            self._max_scan_tuples is not None
+            and self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            raise ValueError(
+                "max_scan_tuples is not supported without embedding_index=hnsw"
+            )
+
+        if (
+            self._scan_mem_multiplier is not None
+            and self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            raise ValueError(
+                "scan_mem_multiplier is not supported without embedding_index=hnsw"
+            )
+
+        if (
+            self._embedding_index != EmbeddingIndexType.hnsw
+            and self._ef_construction is not None
+        ):
+            raise ValueError(
+                "ef_construction is not supported without embedding_index=hnsw"
+            )
+
+        if self._embedding_index != EmbeddingIndexType.hnsw and self._m is not None:
+            raise ValueError("m is not supported without embedding_index=hnsw")
+
+        if (
+            self._binary_quantization is True
+            and self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            raise ValueError(
+                "binary_quantization is not supported without embedding_index=hnsw"
+            )
+
+        if self._binary_quantization is True and self._embedding_index_ops not in [
+            "bit_hamming_ops"
+        ]:
+            raise ValueError(
+                "binary_quantization is only supported with bit_hamming_ops"
+            )
+
+        if self._binary_quantization is True and self._binary_limit is None:
+            raise ValueError(
+                "binary_limit must be provided when using binary_quantization"
+            )
+
         if not self.async_mode:
             self.__post_init__()
 
@@ -470,8 +706,16 @@ def __post_init__(
         self.create_vector_extension()
 
         EmbeddingStore, CollectionStore = _get_embedding_collection_store(
-            self._embedding_length
+            vector_dimension=self._embedding_length,
+            embedding_index=self._embedding_index,
+            embedding_index_ops=self._embedding_index_ops,
+            ef_construction=self._ef_construction,
+            m=self._m,
+            binary_quantization=self._binary_quantization,
+            embedding_length=self._embedding_length,
+            partition=self._enable_partitioning,
         )
+
         self.CollectionStore = CollectionStore
         self.EmbeddingStore = EmbeddingStore
         self.create_tables_if_not_exists()
@@ -486,8 +730,16 @@ async def __apost_init__(
         self._async_init = True
 
         EmbeddingStore, CollectionStore = _get_embedding_collection_store(
-            self._embedding_length
+            vector_dimension=self._embedding_length,
+            embedding_index=self._embedding_index,
+            embedding_index_ops=self._embedding_index_ops,
+            ef_construction=self._ef_construction,
+            m=self._m,
+            binary_quantization=self._binary_quantization,
+            embedding_length=self._embedding_length,
+            partition=self._enable_partitioning,
         )
+
         self.CollectionStore = CollectionStore
         self.EmbeddingStore = EmbeddingStore
         if self.create_extension:
@@ -516,14 +768,112 @@ async def acreate_vector_extension(self) -> None:
 
     def create_tables_if_not_exists(self) -> None:
         with self._make_sync_session() as session:
+            if self._enable_partitioning:
+                self._create_tables_with_partition_if_not_exists(session)
+                return
+
             Base.metadata.create_all(session.get_bind())
             session.commit()
 
     async def acreate_tables_if_not_exists(self) -> None:
         assert self._async_engine, "This method must be called with async_mode"
+
+        if self._enable_partitioning:
+            async with self._make_async_session() as session:
+                await self._acreate_tables_with_partition_if_not_exists(session)
+                return
+
         async with self._async_engine.begin() as conn:
             await conn.run_sync(Base.metadata.create_all)
 
+    def _create_tables_with_partition_if_not_exists(self, session: Session) -> None:
+        if self._check_if_table_exists(session, self.CollectionStore):
+            return
+
+        collection_table_ddl = self._compile_table_ddl(self.CollectionStore)
+        collection_index_ddls = self._compile_index_ddls(self.CollectionStore)
+
+        session.execute(text(collection_table_ddl))
+        for ddl in collection_index_ddls:
+            session.execute(text(ddl))
+
+        compiled_ddl = self._compile_table_ddl(self.EmbeddingStore).strip()
+
+        embedding_table_ddl = f"""
+            {compiled_ddl} PARTITION BY LIST (collection_id)
+        """
+        embedding_index_ddls = self._compile_index_ddls(self.EmbeddingStore)
+
+        session.execute(text(embedding_table_ddl))
+        for ddl in embedding_index_ddls:
+            session.execute(text(ddl))
+
+        session.commit()
+
+    async def _acreate_tables_with_partition_if_not_exists(
+        self, session: AsyncSession
+    ) -> None:
+        if await self._acheck_if_table_exists(session, self.CollectionStore):
+            return
+
+        collection_table_ddl = self._compile_table_ddl(self.CollectionStore)
+        collection_index_ddls = self._compile_index_ddls(self.CollectionStore)
+
+        await session.execute(text(collection_table_ddl))  # type: ignore
+        for ddl in collection_index_ddls:
+            await session.execute(text(ddl))  # type: ignore
+
+        compiled_ddl = self._compile_table_ddl(self.EmbeddingStore).strip()
+
+        embedding_table_ddl = f"""
+            {compiled_ddl} PARTITION BY LIST (collection_id)
+        """
+        embedding_index_ddls = self._compile_index_ddls(self.EmbeddingStore)
+
+        await session.execute(text(embedding_table_ddl))  # type: ignore
+        for ddl in embedding_index_ddls:
+            await session.execute(text(ddl))  # type: ignore
+
+        await session.commit()  # type: ignore
+
+    def _check_if_table_exists(self, session: Session, table: Any) -> bool:
+        inspector = sqlalchemy.inspect(session.get_bind())
+        return inspector.has_table(table.__tablename__)
+
+    async def _acheck_if_table_exists(self, session: AsyncSession, table: Any) -> bool:
+        return await session.run_sync(
+            lambda sync_session: sqlalchemy.inspect(sync_session.bind).has_table(  # type: ignore
+                table.__tablename__
+            )
+        )
+
+    def _compile_table_ddl(self, table: Any) -> str:
+        return str(
+            CreateTable(table.__table__).compile(
+                dialect=sqlalchemy.dialects.postgresql.dialect()
+            )
+        )
+
+    def _compile_index_ddls(self, table: Any) -> List[str]:
+        ddls: List[str] = []
+
+        table_args = getattr(table, "__table_args__", [])
+
+        indexes = [arg for arg in list(table_args) if isinstance(arg, sqlalchemy.Index)]
+
+        if not indexes:
+            return ddls
+
+        for index in indexes:
+            ddl = str(
+                CreateIndex(index).compile(
+                    dialect=sqlalchemy.dialects.postgresql.dialect()
+                )
+            )
+            ddls.append(ddl)
+
+        return ddls
+
     def drop_tables(self) -> None:
         with self._make_sync_session() as session:
             Base.metadata.drop_all(session.get_bind())
@@ -540,7 +890,10 @@ def create_collection(self) -> None:
             self.delete_collection()
         with self._make_sync_session() as session:
             self.CollectionStore.get_or_create(
-                session, self.collection_name, cmetadata=self.collection_metadata
+                session,
+                self.collection_name,
+                cmetadata=self.collection_metadata,
+                partition=self._enable_partitioning,
             )
             session.commit()
 
@@ -550,7 +903,10 @@ async def acreate_collection(self) -> None:
             if self.pre_delete_collection:
                 await self._adelete_collection(session)
             await self.CollectionStore.aget_or_create(
-                session, self.collection_name, cmetadata=self.collection_metadata
+                session,
+                self.collection_name,
+                cmetadata=self.collection_metadata,
+                partition=self._enable_partitioning,
             )
             await session.commit()
 
@@ -561,6 +917,9 @@ def _delete_collection(self, session: Session) -> None:
             return
         session.delete(collection)
 
+        if self._enable_partitioning:
+            self._delete_partition(session, str(collection.uuid))
+
     async def _adelete_collection(self, session: AsyncSession) -> None:
         collection = await self.aget_collection(session)
         if not collection:
@@ -568,6 +927,9 @@ async def _adelete_collection(self, session: AsyncSession) -> None:
             return
         await session.delete(collection)
 
+        if self._enable_partitioning:
+            await self._adelete_partition(session, str(collection.uuid))
+
     def delete_collection(self) -> None:
         with self._make_sync_session() as session:
             collection = self.get_collection(session)
@@ -575,6 +937,10 @@ def delete_collection(self) -> None:
                 self.logger.warning("Collection not found")
                 return
             session.delete(collection)
+
+            if self._enable_partitioning:
+                self._delete_partition(session, str(collection.uuid))
+
             session.commit()
 
     async def adelete_collection(self) -> None:
@@ -585,8 +951,23 @@ async def adelete_collection(self) -> None:
                 self.logger.warning("Collection not found")
                 return
             await session.delete(collection)
+
+            if self._enable_partitioning:
+                await self._adelete_partition(session, str(collection.uuid))
+
             await session.commit()
 
+    def _delete_partition(self, session: Session, uuid: str) -> None:
+        session.execute(text(f"DROP TABLE {self._compute_partition_table_name(uuid)}"))
+
+    async def _adelete_partition(self, session: AsyncSession, uuid: str) -> None:
+        await session.execute(
+            text(f"DROP TABLE {self._compute_partition_table_name(uuid)}")
+        )
+
+    def _compute_partition_table_name(self, uuid: str) -> str:
+        return f"{self.EmbeddingStore.__tablename__}_{uuid.replace('-', '_')}"
+
     def delete(
         self,
         ids: Optional[List[str]] = None,
@@ -745,12 +1126,12 @@ async def __afrom(
         return store
 
     def add_embeddings(
-        self,
-        texts: Sequence[str],
-        embeddings: List[List[float]],
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[List[str]] = None,
-        **kwargs: Any,
+        self,
+        texts: Sequence[str],
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> List[str]:
         """Add embeddings to the vectorstore.
 
@@ -773,23 +1154,32 @@ def add_embeddings(
 
         with self._make_sync_session() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
+
             if not collection:
                 raise ValueError("Collection not found")
-            data = [
-                {
-                    "id": id,
-                    "collection_id": collection.uuid,
-                    "embedding": embedding,
-                    "document": text,
-                    "cmetadata": metadata or {},
-                }
-                for text, metadata, embedding, id in zip(
-                    texts, metadatas, embeddings, ids_
-                )
-            ]
-            stmt = insert(self.EmbeddingStore).values(data)
+
+        data = [
+            {
+                "id": id,
+                "collection_id": collection.uuid,
+                "embedding": embedding,
+                "document": text,
+                "cmetadata": metadata or {},
+            }
+            for text, metadata, embedding, id in zip(
+                texts, metadatas, embeddings, ids_
+            )
+        ]
+
+        for chunk in chunkerize(data, 500):
+            stmt = insert(self.EmbeddingStore).values(chunk)
+
+            index_elements = ["id"]
+            if self._enable_partitioning:
+                index_elements.append("collection_id")
+
             on_conflict_stmt = stmt.on_conflict_do_update(
-                index_elements=["id"],
+                index_elements=index_elements,
                 # Conflict detection based on these columns
                 set_={
                     "embedding": stmt.excluded.embedding,
@@ -797,8 +1187,10 @@ def add_embeddings(
                     "cmetadata": stmt.excluded.cmetadata,
                 },
             )
-            session.execute(on_conflict_stmt)
-            session.commit()
+
+            with self._make_sync_session() as session:
+                session.execute(on_conflict_stmt)
+                session.commit()
 
         return ids_
 
@@ -847,8 +1239,13 @@ async def aadd_embeddings(
                 )
             ]
             stmt = insert(self.EmbeddingStore).values(data)
+
+            index_elements = ["id"]
+            if self._enable_partitioning:
+                index_elements.append("collection_id")
+
             on_conflict_stmt = stmt.on_conflict_do_update(
-                index_elements=["id"],
+                index_elements=index_elements,
                 # Conflict detection based on these columns
                 set_={
                     "embedding": stmt.excluded.embedding,
@@ -926,6 +1323,7 @@ def similarity_search(
         query: str,
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Run similarity search with PGVector with distance.
@@ -944,6 +1342,7 @@ def similarity_search(
             embedding=embedding,
             k=k,
             filter=filter,
+            full_text_search=full_text_search,
         )
 
     async def asimilarity_search(
         self,
         query: str,
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Run similarity search with PGVector with distance.
@@ -969,6 +1369,7 @@ async def asimilarity_search(
             embedding=embedding,
             k=k,
             filter=filter,
+            full_text_search=full_text_search,
         )
 
     def similarity_search_with_score(
         self,
         query: str,
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> List[Tuple[Document, float]]:
         """Return docs most similar to query.
@@ -990,7 +1392,7 @@ def similarity_search_with_score(
         assert not self._async_engine, "This method must be called without async_mode"
         embedding = self.embeddings.embed_query(query)
         docs = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, filter=filter
+            embedding=embedding, k=k, filter=filter, full_text_search=full_text_search
         )
         return docs
 
@@ -999,6 +1401,7 @@ async def asimilarity_search_with_score(
         query: str,
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> List[Tuple[Document, float]]:
         """Return docs most similar to query.
 
@@ -1013,18 +1416,25 @@ async def asimilarity_search_with_score(
         await self.__apost_init__()  # Lazy async init
         embedding = await self.embeddings.aembed_query(query)
         docs = await self.asimilarity_search_with_score_by_vector(
-            embedding=embedding, k=k, filter=filter
+            embedding=embedding, k=k, filter=filter, full_text_search=full_text_search
        )
         return docs
 
     @property
     def distance_strategy(self) -> Any:
+        return self._build_distance_strategy()
+
+    def _build_distance_strategy(
+        self, column: Optional[sqlalchemy.Column] = None
+    ) -> Any:
+        _column = column if column is not None else self.EmbeddingStore.embedding
+
         if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
-            return self.EmbeddingStore.embedding.l2_distance
+            return _column.l2_distance
         elif self._distance_strategy == DistanceStrategy.COSINE:
-            return self.EmbeddingStore.embedding.cosine_distance
+            return _column.cosine_distance
         elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
-            return self.EmbeddingStore.embedding.max_inner_product
+            return _column.max_inner_product
         else:
             raise ValueError(
                 f"Got unexpected value for distance: {self._distance_strategy}. "
@@ -1036,9 +1446,12 @@ def similarity_search_with_score_by_vector(
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> List[Tuple[Document, float]]:
         assert not self._async_engine, "This method must be called without async_mode"
-        results = self.__query_collection(embedding=embedding, k=k, filter=filter)
+        results = self.__query_collection(
+            embedding=embedding, k=k, filter=filter, full_text_search=full_text_search
+        )
 
         return self._results_to_docs_and_scores(results)
 
@@ -1047,11 +1460,16 @@ async def asimilarity_search_with_score_by_vector(
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> List[Tuple[Document, float]]:
         await self.__apost_init__()  # Lazy async init
         async with self._make_async_session() as session:  # type: ignore[arg-type]
             results = await self.__aquery_collection(
-                session=session, embedding=embedding, k=k, filter=filter
+                session=session,
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                full_text_search=full_text_search,
             )
 
             return self._results_to_docs_and_scores(results)
@@ -1396,42 +1814,27 @@ def __query_collection(
         embedding: List[float],
         k: int = 4,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> Sequence[Any]:
         """Query the collection."""
         with self._make_sync_session() as session:  # type: ignore[arg-type]
-            collection = self.get_collection(session)
-            if not collection:
-                raise ValueError("Collection not found")
-
-            filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
-            if filter:
-                if self.use_jsonb:
-                    filter_clauses = self._create_filter_clause(filter)
-                    if filter_clauses is not None:
-                        filter_by.append(filter_clauses)
-                else:
-                    # Old way of doing things
-                    filter_clauses = self._create_filter_clause_json_deprecated(filter)
-                    filter_by.extend(filter_clauses)
+            self._set_iterative_scan(session)
+            self._set_max_scan_tuples(session)
+            self._set_scan_mem_multiplier(session)
 
-            _type = self.EmbeddingStore
+            collection = self.get_collection(session)
 
-            results: List[Any] = (
-                session.query(
-                    self.EmbeddingStore,
-                    self.distance_strategy(embedding).label("distance"),
-                )
-                .filter(*filter_by)
-                .order_by(sqlalchemy.asc("distance"))
-                .join(
-                    self.CollectionStore,
-                    self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
-                )
-                .limit(k)
-                .all()
+            stmt = self._build_query_collection(
+                collection=collection,
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                full_text_search=full_text_search,
             )
 
-            return results
+            results: Sequence[Any] = session.execute(stmt).all()
+
+            return results
 
     async def __aquery_collection(
         self,
@@ -1439,49 +1842,257 @@ async def __aquery_collection(
         embedding: List[float],
         k: int = 4,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
     ) -> Sequence[Any]:
         """Query the collection."""
         async with self._make_async_session() as session:  # type: ignore[arg-type]
+            await self._aset_iterative_scan(session)
+            await self._aset_max_scan_tuples(session)
+            await self._aset_scan_mem_multiplier(session)
+
             collection = await self.aget_collection(session)
-            if not collection:
-                raise ValueError("Collection not found")
 
-            filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
-            if filter:
-                if self.use_jsonb:
-                    filter_clauses = self._create_filter_clause(filter)
-                    if filter_clauses is not None:
-                        filter_by.append(filter_clauses)
-                else:
-                    # Old way of doing things
-                    filter_clauses = self._create_filter_clause_json_deprecated(filter)
-                    filter_by.extend(filter_clauses)
+            stmt = self._build_query_collection(
+                collection=collection,
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                full_text_search=full_text_search,
+            )
 
-            _type = self.EmbeddingStore
-
-            stmt = (
-                select(
-                    self.EmbeddingStore,
-                    self.distance_strategy(embedding).label("distance"),
-                )
-                .filter(*filter_by)
-                .order_by(sqlalchemy.asc("distance"))
-                .join(
-                    self.CollectionStore,
-                    self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
-                )
-                .limit(k)
-            )
 
             results: Sequence[Any] = (await session.execute(stmt)).all()
 
             return results
 
+    def _build_query_collection(
+        self,
+        collection: Any,
+        embedding: List[float],
+        k: int,
+        filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
+    ) -> Any:
+        if not collection:
+            raise ValueError("Collection not found")
+
+        if self._binary_quantization:
+            return self._build_binary_quantization_query(
+                collection=collection,
+                embedding=embedding,
+                k=k,
+                filter=filter,
+                full_text_search=full_text_search,
+            )
+
+        stmt = self._build_base_query(
+            collection=collection,
+            embedding=embedding,
+            k=k,
+            filter=filter,
+            full_text_search=full_text_search,
+        )
+
+        if self._iterative_scan == IterativeScan.relaxed_order:
+            stmt = self._build_iterative_scan_query(stmt)
+
+        return stmt
+
+    def _build_binary_quantization_query(
+        self,
+        collection: Any,
+        embedding: List[float],
+        k: int,
+        filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
+    ) -> Any:
+        filter_by = self._build_filter(collection, filter, full_text_search)
+
+        distance = cast(
+            func.binary_quantize(self.EmbeddingStore.embedding),
+            BIT(self._embedding_length),
+        ).hamming_distance(
+            func.binary_quantize(
+                cast(
+                    embedding,
+                    VECTOR(self._embedding_length),
+                )
+            )
+        )
+
+        sub = (
+            select(self.EmbeddingStore)
+            .filter(*filter_by)
+            .order_by(distance)
+            .limit(self._binary_limit)
+            .subquery(name="binary_result")
+        )
+
+        EmbeddingStoreAlias = aliased(self.EmbeddingStore, sub)
+        embedding_store_bundle: Bundle = Bundle(
+            "EmbeddingStore",
+            *[
+                getattr(EmbeddingStoreAlias, c.key)
+                for c in self.EmbeddingStore.__table__.columns
+            ],
+        )
+
+        return (
+            select(
+                embedding_store_bundle,
+                self._build_distance_strategy(sub.c["embedding"])(embedding).label(  # type: ignore
+                    "distance"
+                ),
+            )
+            .order_by(sqlalchemy.asc("distance"))
+            .limit(k)
+        )
+
+    def _build_base_query(
+        self,
+        collection: Any,
+        embedding: List[float],
+        k: int,
+        filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
+    ) -> Any:
+        filter_by = self._build_filter(collection, filter, full_text_search)
+
+        return (
+            select(
+                self.EmbeddingStore,
+                self.distance_strategy(embedding).label("distance"),
+            )
+            .filter(*filter_by)
+            .order_by(sqlalchemy.asc("distance"))
+            .limit(k)
+        )
+
+    def _build_iterative_scan_query(self, stmt: Any) -> Any:
+        cte = stmt.cte("relaxed_results").prefix_with("MATERIALIZED")
+
+        EmbeddingStoreAlias = aliased(self.EmbeddingStore, cte)
+        embedding_store_bundle: Bundle = Bundle(
+            "EmbeddingStore",
+            *[
+                getattr(EmbeddingStoreAlias, c.key)
+                for c in self.EmbeddingStore.__table__.columns
+            ],
+        )
+
+        return select(embedding_store_bundle, cte.c.distance).order_by(text("distance"))
+
+    def _build_filter(
+        self,
+        collection: Any,
+        filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
+    ) -> List[SQLColumnExpression]:
+        filter_by = [self.EmbeddingStore.collection_id == collection.uuid]
+
+        if filter:
+            if self.use_jsonb:
+                filter_clauses = self._create_filter_clause(filter)
+                if filter_clauses is not None:
+                    filter_by.append(filter_clauses)
+            else:
+                # Old way of doing things
+                filter_clauses = self._create_filter_clause_json_deprecated(filter)
+                filter_by.extend(filter_clauses)
+
+        if full_text_search:
+            filter_by.append(self._build_full_text_search_filter(full_text_search))
+
+        return filter_by
+
+    def _build_full_text_search_filter(
+        self, keywords: List[str]
+    ) -> SQLColumnExpression:
+        return self.EmbeddingStore.document_vector.op("@@")(
+            func.to_tsquery(" | ".join(list(map(lambda x: x.strip(), keywords))))
+        )
+
+    def _set_iterative_scan(self, session: Session) -> None:
+        assert not self._async_engine, "This method must be called without async_mode"
+
+        if self._iterative_scan is None or self._embedding_index is None:
+            return
+
+        index = self._embedding_index.value
+        iterative_scan = self._iterative_scan.value
+
+        stmt = f"SET {index}.iterative_scan = {iterative_scan}"
+
+        session.execute(text(stmt))
+
+    async def _aset_iterative_scan(self, session: AsyncSession) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+
+        if self._iterative_scan is None or self._embedding_index is None:
+            return
+
+        index = self._embedding_index.value
+        iterative_scan = self._iterative_scan.value
+
+        stmt = f"SET {index}.iterative_scan = {iterative_scan}"
+
+        await session.execute(text(stmt))
+
+    def _set_max_scan_tuples(self, session: Session) -> None:
+        assert not self._async_engine, "This method must be called without async_mode"
+
+        if (
+            self._max_scan_tuples is None
+            or self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            return
+
+        session.execute(text(f"SET hnsw.max_scan_tuples = {self._max_scan_tuples}"))
+
+    async def _aset_max_scan_tuples(self, session: AsyncSession) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+
+        if (
+            self._max_scan_tuples is None
+            or self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            return
+
+        await session.execute(
+            text(f"SET hnsw.max_scan_tuples = {self._max_scan_tuples}")
+        )
+
+    def _set_scan_mem_multiplier(self, session: Session) -> None:
+        assert not self._async_engine, "This method must be called without async_mode"
+
+        if (
+            self._scan_mem_multiplier is None
+            or self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            return
+
+        session.execute(
+            text(f"SET hnsw.scan_mem_multiplier = {self._scan_mem_multiplier}")
+        )
+
+    async def _aset_scan_mem_multiplier(self, session: AsyncSession) -> None:
+        assert self._async_engine, "This method must be called with async_mode"
+
+        if (
+            self._scan_mem_multiplier is None
+            or self._embedding_index != EmbeddingIndexType.hnsw
+        ):
+            return
+
+        await session.execute(
+            text(f"SET hnsw.scan_mem_multiplier = {self._scan_mem_multiplier}")
+        )
 
     def similarity_search_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs most similar to embedding vector.
@@ -1490,13 +2101,16 @@ def similarity_search_by_vector(
             embedding: Embedding to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
             filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+            full_text_search: Filter by full text search: the document must contain
+                at least one of the given terms. A term like (string1 & string2)
+                requires the document to contain both string1 and string2.
 
         Returns:
             List of Documents most similar to the query vector.
         """
         assert not self._async_engine, "This method must be called without async_mode"
         docs_and_scores = self.similarity_search_with_score_by_vector(
-            embedding=embedding, k=k, filter=filter
+            embedding=embedding, k=k, filter=filter, full_text_search=full_text_search
         )
         return _results_to_docs(docs_and_scores)
 
@@ -1505,6 +2119,7 @@ async def asimilarity_search_by_vector(
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs most similar to embedding vector.
@@ -1513,6 +2128,9 @@ async def asimilarity_search_by_vector(
             embedding: Embedding to look up documents similar to.
             k: Number of Documents to return. Defaults to 4.
             filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+            full_text_search: Filter by full text search: the document must contain
+                at least one of the given terms. A term like (string1 & string2)
+                requires the document to contain both string1 and string2.
 
         Returns:
             List of Documents most similar to the query vector.
@@ -1520,7 +2138,7 @@ async def asimilarity_search_by_vector(
         assert self._async_engine, "This method must be called with async_mode"
         await self.__apost_init__()  # Lazy async init
         docs_and_scores = await self.asimilarity_search_with_score_by_vector(
-            embedding=embedding, k=k, filter=filter
+            embedding=embedding, k=k, filter=filter, full_text_search=full_text_search
         )
         return _results_to_docs(docs_and_scores)
 
@@ -1876,6 +2494,7 @@ def max_marginal_relevance_search_with_score_by_vector(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         """Return docs selected using the maximal marginal relevance with score
@@ -1900,7 +2519,12 @@ def max_marginal_relevance_search_with_score_by_vector(
             relevance to the query and score for each.
         """
         assert not self._async_engine, "This method must be called without async_mode"
-        results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
+        results = self.__query_collection(
+            embedding=embedding,
+            k=fetch_k,
+            filter=filter,
+            full_text_search=full_text_search,
+        )
 
         embedding_list = [result.EmbeddingStore.embedding for result in results]
 
@@ -1922,6 +2546,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         """Return docs selected using the maximal marginal relevance with score
@@ -1948,7 +2573,11 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         await self.__apost_init__()  # Lazy async init
         async with self._make_async_session() as session:
             results = await self.__aquery_collection(
-                session=session, embedding=embedding, k=fetch_k, filter=filter
+                session=session,
+                embedding=embedding,
+                k=fetch_k,
+                filter=filter,
+                full_text_search=full_text_search,
             )
 
             embedding_list = [result.EmbeddingStore.embedding for result in results]
@@ -2048,6 +2677,7 @@ def max_marginal_relevance_search_with_score(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         """Return docs selected using the maximal marginal relevance with score.
@@ -2077,6 +2707,7 @@ def max_marginal_relevance_search_with_score(
             fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             filter=filter,
+            full_text_search=full_text_search,
             **kwargs,
         )
         return docs
@@ -2088,6 +2719,7 @@ async def amax_marginal_relevance_search_with_score(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[dict] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
         """Return docs selected using the maximal marginal relevance with score.
@@ -2118,6 +2750,7 @@ async def amax_marginal_relevance_search_with_score(
             fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             filter=filter,
+            full_text_search=full_text_search,
             **kwargs,
         )
         return docs
@@ -2129,6 +2762,7 @@ def max_marginal_relevance_search_by_vector(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance
@@ -2157,6 +2791,7 @@ def max_marginal_relevance_search_by_vector(
             fetch_k=fetch_k,
             lambda_mult=lambda_mult,
             filter=filter,
+            full_text_search=full_text_search,
             **kwargs,
         )
 
@@ -2169,6 +2804,7 @@ async def amax_marginal_relevance_search_by_vector(
         fetch_k: int = 20,
         lambda_mult: float = 0.5,
         filter: Optional[Dict[str, str]] = None,
+        full_text_search: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         """Return docs selected using the maximal marginal relevance
@@ -2199,6 +2835,7 @@ async def amax_marginal_relevance_search_by_vector(
                 fetch_k=fetch_k,
                 lambda_mult=lambda_mult,
                 filter=filter,
+                full_text_search=full_text_search,
                 **kwargs,
             )
         )
diff --git a/poetry.lock b/poetry.lock
index b71bde3..9dbfbf6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -624,6 +624,16 @@ files = [
     {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
 ]
 
+[[package]]
+name = "docopt"
+version = "0.6.2"
+description = "Pythonic argument parser, that will make you smile"
+optional = false
+python-versions = "*"
+files = [
+    {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
+]
+
 [[package]]
 name = "entrypoints"
 version = "0.4"
@@ -690,77 +700,84 @@ files = [
 
 [[package]]
 name = "greenlet"
-version = "3.1.0"
+version = "3.1.1"
 description = "Lightweight in-process concurrent programming"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "greenlet-3.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a814dc3100e8a046ff48faeaa909e80cdb358411a3d6dd5293158425c684eda8"},
-    {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a771dc64fa44ebe58d65768d869fcfb9060169d203446c1d446e844b62bdfdca"},
-    {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0e49a65d25d7350cca2da15aac31b6f67a43d867448babf997fe83c7505f57bc"},
-    {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2cd8518eade968bc52262d8c46727cfc0826ff4d552cf0430b8d65aaf50bb91d"},
-    {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76dc19e660baea5c38e949455c1181bc018893f25372d10ffe24b3ed7341fb25"},
-    {file = "greenlet-3.1.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0a5b1c22c82831f56f2f7ad9bbe4948879762fe0d59833a4a71f16e5fa0f682"},
-    {file = "greenlet-3.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:2651dfb006f391bcb240635079a68a261b227a10a08af6349cba834a2141efa1"},
-    {file = "greenlet-3.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3e7e6ef1737a819819b1163116ad4b48d06cfdd40352d813bb14436024fcda99"},
-    {file = "greenlet-3.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:ffb08f2a1e59d38c7b8b9ac8083c9c8b9875f0955b1e9b9b9a965607a51f8e54"},
-    {file = "greenlet-3.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9730929375021ec90f6447bff4f7f5508faef1c02f399a1953870cdb78e0c345"},
{file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713d450cf8e61854de9420fb7eea8ad228df4e27e7d4ed465de98c955d2b3fa6"}, - {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c3446937be153718250fe421da548f973124189f18fe4575a0510b5c928f0cc"}, - {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ddc7bcedeb47187be74208bc652d63d6b20cb24f4e596bd356092d8000da6d6"}, - {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44151d7b81b9391ed759a2f2865bbe623ef00d648fed59363be2bbbd5154656f"}, - {file = "greenlet-3.1.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cea1cca3be76c9483282dc7760ea1cc08a6ecec1f0b6ca0a94ea0d17432da19"}, - {file = "greenlet-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:619935a44f414274a2c08c9e74611965650b730eb4efe4b2270f91df5e4adf9a"}, - {file = "greenlet-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:221169d31cada333a0c7fd087b957c8f431c1dba202c3a58cf5a3583ed973e9b"}, - {file = "greenlet-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:01059afb9b178606b4b6e92c3e710ea1635597c3537e44da69f4531e111dd5e9"}, - {file = "greenlet-3.1.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:24fc216ec7c8be9becba8b64a98a78f9cd057fd2dc75ae952ca94ed8a893bf27"}, - {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d07c28b85b350564bdff9f51c1c5007dfb2f389385d1bc23288de51134ca303"}, - {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:243a223c96a4246f8a30ea470c440fe9db1f5e444941ee3c3cd79df119b8eebf"}, - {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26811df4dc81271033a7836bc20d12cd30938e6bd2e9437f56fa03da81b0f8fc"}, - {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d86401550b09a55410f32ceb5fe7efcd998bd2dad9e82521713cb148a4a15f"}, - {file = "greenlet-3.1.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:26d9c1c4f1748ccac0bae1dbb465fb1a795a75aba8af8ca871503019f4285e2a"}, - {file = "greenlet-3.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cd468ec62257bb4544989402b19d795d2305eccb06cde5da0eb739b63dc04665"}, - {file = "greenlet-3.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a53dfe8f82b715319e9953330fa5c8708b610d48b5c59f1316337302af5c0811"}, - {file = "greenlet-3.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:28fe80a3eb673b2d5cc3b12eea468a5e5f4603c26aa34d88bf61bba82ceb2f9b"}, - {file = "greenlet-3.1.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:76b3e3976d2a452cba7aa9e453498ac72240d43030fdc6d538a72b87eaff52fd"}, - {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:655b21ffd37a96b1e78cc48bf254f5ea4b5b85efaf9e9e2a526b3c9309d660ca"}, - {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6f4c2027689093775fd58ca2388d58789009116844432d920e9147f91acbe64"}, - {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76e5064fd8e94c3f74d9fd69b02d99e3cdb8fc286ed49a1f10b256e59d0d3a0b"}, - {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a4bf607f690f7987ab3291406e012cd8591a4f77aa54f29b890f9c331e84989"}, 
- {file = "greenlet-3.1.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:037d9ac99540ace9424cb9ea89f0accfaff4316f149520b4ae293eebc5bded17"}, - {file = "greenlet-3.1.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:90b5bbf05fe3d3ef697103850c2ce3374558f6fe40fd57c9fac1bf14903f50a5"}, - {file = "greenlet-3.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:726377bd60081172685c0ff46afbc600d064f01053190e4450857483c4d44484"}, - {file = "greenlet-3.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:d46d5069e2eeda111d6f71970e341f4bd9aeeee92074e649ae263b834286ecc0"}, - {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81eeec4403a7d7684b5812a8aaa626fa23b7d0848edb3a28d2eb3220daddcbd0"}, - {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3dae7492d16e85ea6045fd11cb8e782b63eac8c8d520c3a92c02ac4573b0a6"}, - {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b5ea3664eed571779403858d7cd0a9b0ebf50d57d2cdeafc7748e09ef8cd81a"}, - {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a22f4e26400f7f48faef2d69c20dc055a1f3043d330923f9abe08ea0aecc44df"}, - {file = "greenlet-3.1.0-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13ff8c8e54a10472ce3b2a2da007f915175192f18e6495bad50486e87c7f6637"}, - {file = "greenlet-3.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f9671e7282d8c6fcabc32c0fb8d7c0ea8894ae85cee89c9aadc2d7129e1a9954"}, - {file = "greenlet-3.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:184258372ae9e1e9bddce6f187967f2e08ecd16906557c4320e3ba88a93438c3"}, - {file = "greenlet-3.1.0-cp37-cp37m-win32.whl", hash = "sha256:a0409bc18a9f85321399c29baf93545152d74a49d92f2f55302f122007cfda00"}, - {file = "greenlet-3.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:9eb4a1d7399b9f3c7ac68ae6baa6be5f9195d1d08c9ddc45ad559aa6b556bce6"}, - {file = "greenlet-3.1.0-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:a8870983af660798dc1b529e1fd6f1cefd94e45135a32e58bd70edd694540f33"}, - {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfcfb73aed40f550a57ea904629bdaf2e562c68fa1164fa4588e752af6efdc3f"}, - {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9482c2ed414781c0af0b35d9d575226da6b728bd1a720668fa05837184965b7"}, - {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d58ec349e0c2c0bc6669bf2cd4982d2f93bf067860d23a0ea1fe677b0f0b1e09"}, - {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd65695a8df1233309b701dec2539cc4b11e97d4fcc0f4185b4a12ce54db0491"}, - {file = "greenlet-3.1.0-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:665b21e95bc0fce5cab03b2e1d90ba9c66c510f1bb5fdc864f3a377d0f553f6b"}, - {file = "greenlet-3.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d3c59a06c2c28a81a026ff11fbf012081ea34fb9b7052f2ed0366e14896f0a1d"}, - {file = "greenlet-3.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415b9494ff6240b09af06b91a375731febe0090218e2898d2b85f9b92abcda0"}, - {file = "greenlet-3.1.0-cp38-cp38-win32.whl", hash = "sha256:1544b8dd090b494c55e60c4ff46e238be44fdc472d2589e943c241e0169bcea2"}, - {file = "greenlet-3.1.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:7f346d24d74c00b6730440f5eb8ec3fe5774ca8d1c9574e8e57c8671bb51b910"}, - {file = "greenlet-3.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:db1b3ccb93488328c74e97ff888604a8b95ae4f35f4f56677ca57a4fc3a4220b"}, - {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44cd313629ded43bb3b98737bba2f3e2c2c8679b55ea29ed73daea6b755fe8e7"}, - {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fad7a051e07f64e297e6e8399b4d6a3bdcad3d7297409e9a06ef8cbccff4f501"}, - {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3967dcc1cd2ea61b08b0b276659242cbce5caca39e7cbc02408222fb9e6ff39"}, - {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d45b75b0f3fd8d99f62eb7908cfa6d727b7ed190737dec7fe46d993da550b81a"}, - {file = "greenlet-3.1.0-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2d004db911ed7b6218ec5c5bfe4cf70ae8aa2223dffbb5b3c69e342bb253cb28"}, - {file = "greenlet-3.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9505a0c8579899057cbefd4ec34d865ab99852baf1ff33a9481eb3924e2da0b"}, - {file = "greenlet-3.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fd6e94593f6f9714dbad1aaba734b5ec04593374fa6638df61592055868f8b8"}, - {file = "greenlet-3.1.0-cp39-cp39-win32.whl", hash = "sha256:d0dd943282231480aad5f50f89bdf26690c995e8ff555f26d8a5b9887b559bcc"}, - {file = "greenlet-3.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:ac0adfdb3a21dc2a24ed728b61e72440d297d0fd3a577389df566651fcd08f97"}, - {file = "greenlet-3.1.0.tar.gz", hash = "sha256:b395121e9bbe8d02a750886f108d540abe66075e61e22f7353d9acb0b81be0f0"}, + {file = "greenlet-3.1.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:0bbae94a29c9e5c7e4a2b7f0aae5c17e8e90acbfd3bf6270eeba60c39fce3563"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fde093fb93f35ca72a556cf72c92ea3ebfda3d79fc35bb19fbe685853869a83"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36b89d13c49216cadb828db8dfa6ce86bbbc476a82d3a6c397f0efae0525bdd0"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94b6150a85e1b33b40b1464a3f9988dcc5251d6ed06842abff82e42632fac120"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93147c513fac16385d1036b7e5b102c7fbbdb163d556b791f0f11eada7ba65dc"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da7a9bff22ce038e19bf62c4dd1ec8391062878710ded0a845bcf47cc0200617"}, + {file = "greenlet-3.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b2795058c23988728eec1f36a4e5e4ebad22f8320c85f3587b539b9ac84128d7"}, + {file = "greenlet-3.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ed10eac5830befbdd0c32f83e8aa6288361597550ba669b04c48f0f9a2c843c6"}, + {file = "greenlet-3.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:77c386de38a60d1dfb8e55b8c1101d68c79dfdd25c7095d51fec2dd800892b80"}, + {file = "greenlet-3.1.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e4d333e558953648ca09d64f13e6d8f0523fa705f51cae3f03b5983489958c70"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fc016b73c94e98e29af67ab7b9a879c307c6731a2c9da0db5a7d9b7edd1159"}, + {file = 
"greenlet-3.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d5e975ca70269d66d17dd995dafc06f1b06e8cb1ec1e9ed54c1d1e4a7c4cf26e"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2813dc3de8c1ee3f924e4d4227999285fd335d1bcc0d2be6dc3f1f6a318ec1"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e347b3bfcf985a05e8c0b7d462ba6f15b1ee1c909e2dcad795e49e91b152c383"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e8f8c9cb53cdac7ba9793c276acd90168f416b9ce36799b9b885790f8ad6c0a"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62ee94988d6b4722ce0028644418d93a52429e977d742ca2ccbe1c4f4a792511"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1776fd7f989fc6b8d8c8cb8da1f6b82c5814957264d1f6cf818d475ec2bf6395"}, + {file = "greenlet-3.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:48ca08c771c268a768087b408658e216133aecd835c0ded47ce955381105ba39"}, + {file = "greenlet-3.1.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:4afe7ea89de619adc868e087b4d2359282058479d7cfb94970adf4b55284574d"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f406b22b7c9a9b4f8aa9d2ab13d6ae0ac3e85c9a809bd590ad53fed2bf70dc79"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c3a701fe5a9695b238503ce5bbe8218e03c3bcccf7e204e455e7462d770268aa"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2846930c65b47d70b9d178e89c7e1a69c95c1f68ea5aa0a58646b7a96df12441"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99cfaa2110534e2cf3ba31a7abcac9d328d1d9f1b95beede58294a60348fba36"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1443279c19fca463fc33e65ef2a935a5b09bb90f978beab37729e1c3c6c25fe9"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b7cede291382a78f7bb5f04a529cb18e068dd29e0fb27376074b6d0317bf4dd0"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:23f20bb60ae298d7d8656c6ec6db134bca379ecefadb0b19ce6f19d1f232a942"}, + {file = "greenlet-3.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:7124e16b4c55d417577c2077be379514321916d5790fa287c9ed6f23bd2ffd01"}, + {file = "greenlet-3.1.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:05175c27cb459dcfc05d026c4232f9de8913ed006d42713cb8a5137bd49375f1"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:935e943ec47c4afab8965954bf49bfa639c05d4ccf9ef6e924188f762145c0ff"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667a9706c970cb552ede35aee17339a18e8f2a87a51fba2ed39ceeeb1004798a"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8a678974d1f3aa55f6cc34dc480169d58f2e6d8958895d68845fa4ab566509e"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efc0f674aa41b92da8c49e0346318c6075d734994c3c4e4430b1c3f853e498e4"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0153404a4bb921f0ff1abeb5ce8a5131da56b953eda6e14b88dc6bbc04d2049e"}, + {file = 
"greenlet-3.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:275f72decf9932639c1c6dd1013a1bc266438eb32710016a1c742df5da6e60a1"}, + {file = "greenlet-3.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c4aab7f6381f38a4b42f269057aee279ab0fc7bf2e929e3d4abfae97b682a12c"}, + {file = "greenlet-3.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:b42703b1cf69f2aa1df7d1030b9d77d3e584a70755674d60e710f0af570f3761"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1695e76146579f8c06c1509c7ce4dfe0706f49c6831a817ac04eebb2fd02011"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7876452af029456b3f3549b696bb36a06db7c90747740c5302f74a9e9fa14b13"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ead44c85f8ab905852d3de8d86f6f8baf77109f9da589cb4fa142bd3b57b475"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8320f64b777d00dd7ccdade271eaf0cad6636343293a25074cc5566160e4de7b"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6510bf84a6b643dabba74d3049ead221257603a253d0a9873f55f6a59a65f822"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:04b013dc07c96f83134b1e99888e7a79979f1a247e2a9f59697fa14b5862ed01"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47da355d8687fd65240c364c90a31569a133b7b60de111c255ef5b606f2ae291"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98884ecf2ffb7d7fe6bd517e8eb99d31ff7855a840fa6d0d63cd07c037f6a981"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1d4aeb8891338e60d1ab6127af1fe45def5259def8094b9c7e34690c8858803"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db32b5348615a04b82240cc67983cb315309e88d444a288934ee6ceaebcad6cc"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dcc62f31eae24de7f8dce72134c8651c58000d3b1868e01392baea7c32c247de"}, + {file = "greenlet-3.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1d3755bcb2e02de341c55b4fca7a745a24a9e7212ac953f6b3a48d117d7257aa"}, + {file = "greenlet-3.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b8da394b34370874b4572676f36acabac172602abf054cbc4ac910219f3340af"}, + {file = "greenlet-3.1.1-cp37-cp37m-win32.whl", hash = "sha256:a0dfc6c143b519113354e780a50381508139b07d2177cb6ad6a08278ec655798"}, + {file = "greenlet-3.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:54558ea205654b50c438029505def3834e80f0869a70fb15b871c29b4575ddef"}, + {file = "greenlet-3.1.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:346bed03fe47414091be4ad44786d1bd8bef0c3fcad6ed3dee074a032ab408a9"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfc59d69fc48664bc693842bd57acfdd490acafda1ab52c7836e3fc75c90a111"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21e10da6ec19b457b82636209cbe2331ff4306b54d06fa04b7c138ba18c8a81"}, + {file = 
"greenlet-3.1.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37b9de5a96111fc15418819ab4c4432e4f3c2ede61e660b1e33971eba26ef9ba"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef9ea3f137e5711f0dbe5f9263e8c009b7069d8a1acea822bd5e9dae0ae49c8"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1"}, + {file = "greenlet-3.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:95ffcf719966dd7c453f908e208e14cde192e09fde6c7186c8f1896ef778d8cd"}, + {file = "greenlet-3.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:03a088b9de532cbfe2ba2034b2b85e82df37874681e8c470d6fb2f8c04d7e4b7"}, + {file = "greenlet-3.1.1-cp38-cp38-win32.whl", hash = "sha256:8b8b36671f10ba80e159378df9c4f15c14098c4fd73a36b9ad715f057272fbef"}, + {file = "greenlet-3.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:7017b2be767b9d43cc31416aba48aab0d2309ee31b4dbf10a1d38fb7972bdf9d"}, + {file = "greenlet-3.1.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:396979749bd95f018296af156201d6211240e7a23090f50a8d5d18c370084dc3"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9d0ff5ad43e785350894d97e13633a66e2b50000e8a183a50a88d834752d42"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f6ff3b14f2df4c41660a7dec01045a045653998784bf8cfcb5a525bdffffbc8f"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94ebba31df2aa506d7b14866fed00ac141a867e63143fe5bca82a8e503b36437"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73aaad12ac0ff500f62cebed98d8789198ea0e6f233421059fa68a5aa7220145"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63e4844797b975b9af3a3fb8f7866ff08775f5426925e1e0bbcfe7932059a12c"}, + {file = "greenlet-3.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7939aa3ca7d2a1593596e7ac6d59391ff30281ef280d8632fa03d81f7c5f955e"}, + {file = "greenlet-3.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d0028e725ee18175c6e422797c407874da24381ce0690d6b9396c204c7f7276e"}, + {file = "greenlet-3.1.1-cp39-cp39-win32.whl", hash = "sha256:5e06afd14cbaf9e00899fae69b24a32f2196c19de08fcb9f4779dd4f004e5e7c"}, + {file = "greenlet-3.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:3319aa75e0e0639bc15ff54ca327e8dc7a6fe404003496e3c6925cd3142e0e22"}, + {file = "greenlet-3.1.1.tar.gz", hash = "sha256:4ce3ac6cdb6adf7946475d7ef31777c26d94bccc377e070a7986bd2d5c515467"}, ] [package.extras] @@ -1887,12 +1904,13 @@ ptyprocess = ">=0.5" [[package]] name = "pgvector" -version = "0.2.5" +version = "0.3.6" description = "pgvector support for Python" optional = false python-versions = ">=3.8" files = [ - {file = "pgvector-0.2.5-py2.py3-none-any.whl", hash = "sha256:5e5e93ec4d3c45ab1fa388729d56c602f6966296e19deee8878928c6d567e41b"}, + {file = "pgvector-0.3.6-py3-none-any.whl", hash = "sha256:f6c269b3c110ccb7496bac87202148ed18f34b390a0189c783e351062400a75a"}, + {file = "pgvector-0.3.6.tar.gz", hash = "sha256:31d01690e6ea26cea8a633cde5f0f55f5b246d9c8292d68efdef8c22ec994ade"}, ] [package.dependencies] @@ -2294,6 +2312,22 @@ files = [ [package.dependencies] pytest = ">=7.0.0" +[[package]] +name = "pytest-watch" +version = "4.2.0" +description = "Local continuous test runner with pytest and watchdog." 
+optional = false +python-versions = "*" +files = [ + {file = "pytest-watch-4.2.0.tar.gz", hash = "sha256:06136f03d5b361718b8d0d234042f7b2f203910d8568f63df2f866b547b3d4b9"}, +] + +[package.dependencies] +colorama = ">=0.3.3" +docopt = ">=0.4.0" +pytest = ">=2.6.4" +watchdog = ">=0.6.0" + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3086,6 +3120,48 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "watchdog" +version = "6.0.0" +description = "Filesystem events monitoring" +optional = false +python-versions = ">=3.9" +files = [ + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e6f0e77c9417e7cd62af82529b10563db3423625c5fce018430b249bf977f9e8"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:90c8e78f3b94014f7aaae121e6b909674df5b46ec24d6bebc45c44c56729af2a"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7631a77ffb1f7d2eefa4445ebbee491c720a5661ddf6df3498ebecae5ed375c"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a0e56874cfbc4b9b05c60c8a1926fedf56324bb08cfbc188969777940aef3aa"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6439e374fc012255b4ec786ae3c4bc838cd7309a540e5fe0952d03687d8804e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2"}, + {file = "watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a"}, + {file = "watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680"}, + {file = "watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f"}, + {file = "watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282"}, +] + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + [[package]] name = "wcwidth" version = "0.2.13" @@ -3263,4 +3339,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d9cb071efbb2e562b65a9950ddfa0e05bd460f5db51895de2a6f929453d82021" +content-hash = "489b42ae245dba71e1c26cd63a3e1ac780518565b259676b9362aa63d8a898d9" diff --git a/pyproject.toml b/pyproject.toml index 0164257..91e5a2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,11 +16,9 @@ langchain-core = ">=0.2.13,<0.4.0" psycopg = "^3" psycopg-pool = "^3.2.1" sqlalchemy = "^2" -pgvector = "^0.2.5" +pgvector = "^0.3.6" numpy = "^1" -[tool.poetry.group.docs.dependencies] - [tool.poetry.group.dev.dependencies] jupyterlab = "^3.6.1" @@ -35,6 +33,8 @@ pytest-socket = "^0.7.0" pytest-cov = "^5.0.0" pytest-timeout = "^2.3.1" langchain-tests = "0.3.7" +greenlet = "^3.1.1" +pytest-watch = "^4.2.0" [tool.poetry.group.codespell] optional = true diff --git a/tests/unit_tests/query_constructors/test_pgvector.py b/tests/unit_tests/query_constructors/test_pgvector.py index 366c0b2..5359f41 100644 --- a/tests/unit_tests/query_constructors/test_pgvector.py +++ b/tests/unit_tests/query_constructors/test_pgvector.py @@ -1,6 +1,11 @@ +# type: ignore + +import os +import re from typing import Dict, Tuple import pytest as pytest +import sqlalchemy from langchain_core.structured_query import ( Comparator, Comparison, @@ -8,12 +13,82 @@ Operator, StructuredQuery, ) +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import ( + declarative_base, +) -from langchain_postgres import PGVectorTranslator +from langchain_postgres import ( + EmbeddingIndexType, + PGVector, + PGVectorTranslator, + vectorstores, +) +from langchain_postgres.vectorstores import ( + IterativeScan, + _get_embedding_collection_store, +) +from tests.unit_tests.test_vectorstore import FakeEmbeddingsWithAdaDimension +from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING +from tests.utils import async_session, sync_session 
DEFAULT_TRANSLATOR = PGVectorTranslator() +EmbeddingStore, CollectionStore = _get_embedding_collection_store() + + +@pytest.fixture(scope="function") +def drop_tables() -> None: + def drop() -> None: + with sync_session() as session: + session.execute(text("DROP SCHEMA public CASCADE")) + session.execute(text("CREATE SCHEMA public")) + session.execute( + text( + f""" + GRANT ALL + ON SCHEMA public + TO {os.environ.get('POSTGRES_USER', 'langchain')} + """ + ) + ) + session.execute(text("GRANT ALL ON SCHEMA public TO public")) + session.commit() + + vectorstores._classes = None + vectorstores.Base = declarative_base() + + drop() + + yield + + drop() + + +def normalize_sql(query) -> str: + # Remove new lines, tabs, and multiple spaces + query = re.sub(r"\s+", " ", query).strip() + # Normalize space around commas + query = re.sub(r"\s*,\s*", ", ", query) + # Normalize space around parentheses + query = re.sub(r"\(\s*", "(", query) + query = re.sub(r"\s*\)", ")", query) + return query + + +def get_index_definition(index_name) -> sqlalchemy.sql.text: + query = f""" + SELECT indexdef + FROM pg_indexes + WHERE tablename = '{EmbeddingStore.__tablename__}' + AND indexname = '{index_name}' + """ + + return text(query) + + def test_visit_comparison() -> None: comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1) expected = {"foo": {"$lt": 1}} @@ -85,3 +160,931 @@ def test_visit_structured_query() -> None: ) actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) assert expected == actual + + +def test_embedding_index_without_length() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + ) + + +def test_embedding_index_without_ops() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_embedding_index_hnsw() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (embedding vector_cosine_ops) + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_embedding_index_hnsw_async() -> None: + pgvector = PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + connection=create_async_engine(CONNECTION_STRING), + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + result = await session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (embedding vector_cosine_ops) + """ + + assert normalize_sql(result.fetchone()[0]) == 
normalize_sql(expected) + + +@pytest.mark.usefixtures("drop_tables") +def test_embedding_index_ivfflat() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.ivfflat, + embedding_index_ops="vector_cosine_ops", + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_ivfflat")) + + expected = f""" + CREATE INDEX ix_embedding_ivfflat + ON public.{EmbeddingStore.__tablename__} + USING ivfflat (embedding vector_cosine_ops) + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_embedding_index_ivfflat_async() -> None: + pgvector = PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + embedding_length=1536, + embedding_index=EmbeddingIndexType.ivfflat, + embedding_index_ops="vector_cosine_ops", + connection=create_async_engine(CONNECTION_STRING), + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + result = await session.execute(get_index_definition("ix_embedding_ivfflat")) + + expected = f""" + CREATE INDEX ix_embedding_ivfflat + ON public.{EmbeddingStore.__tablename__} + USING ivfflat (embedding vector_cosine_ops) + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +def test_binary_quantization_without_hnsw() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.ivfflat, + embedding_index_ops="bit_hamming_ops", + binary_quantization=True, + binary_limit=200, + ) + + +def test_binary_quantization_without_hamming_ops() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + binary_quantization=True, + embedding_index_ops="vector_cosine_ops", + binary_limit=200, + ) + + +def test_binary_quantization_without_binary_limit() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="bit_hamming_ops", + binary_quantization=True, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_binary_quantization_index() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="bit_hamming_ops", + binary_quantization=True, + binary_limit=200, + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (((binary_quantize(embedding))::bit(1536)) bit_hamming_ops) + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +def test_ef_construction_without_hnsw() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + 
embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.ivfflat, + embedding_index_ops="vector_cosine_ops", + ef_construction=256, + ) + + +def test_m_without_hnsw() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.ivfflat, + embedding_index_ops="vector_cosine_ops", + m=16, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_ef_construction() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + ef_construction=256, + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (embedding vector_cosine_ops) WITH (ef_construction='256') + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +@pytest.mark.usefixtures("drop_tables") +def test_m() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + m=16, + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (embedding vector_cosine_ops) WITH (m='16') + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +@pytest.mark.usefixtures("drop_tables") +def test_ef_construction_and_m() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + ef_construction=256, + m=16, + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (embedding vector_cosine_ops) + WITH (m='16', ef_construction='256') + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +@pytest.mark.usefixtures("drop_tables") +def test_binary_quantization_with_ef_construction_and_m() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="bit_hamming_ops", + binary_quantization=True, + binary_limit=200, + ef_construction=256, + m=16, + ) + + with sync_session() as session: + result = session.execute(get_index_definition("ix_embedding_hnsw")) + + expected = f""" + CREATE INDEX ix_embedding_hnsw + ON public.{EmbeddingStore.__tablename__} + USING hnsw (((binary_quantize(embedding))::bit(1536)) bit_hamming_ops) + WITH (m='16', ef_construction='256') + """ + + assert normalize_sql(result.fetchone()[0]) == normalize_sql(expected) + + +get_partitioned_table = text( + f""" + SELECT + c.relname AS table_name, + p.partstrat AS 
partition_strategy, + pg_attribute.attname AS partition_key + FROM + pg_class c + JOIN + pg_partitioned_table p ON c.oid = p.partrelid + LEFT JOIN + pg_attribute ON pg_attribute.attrelid = p.partrelid + AND pg_attribute.attnum = ANY(p.partattrs) + WHERE + c.relname = '{EmbeddingStore.__tablename__}' +""" +) + + +@pytest.mark.usefixtures("drop_tables") +def test_partitioning() -> None: + PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + enable_partitioning=True, + ) + + with sync_session() as session: + result = session.execute(get_partitioned_table) + + assert result.fetchone() == (EmbeddingStore.__tablename__, "l", "collection_id") + + query = f""" + SELECT indexname + FROM pg_indexes + WHERE tablename = '{EmbeddingStore.__tablename__}' + """ + + result = session.execute(text(query)) + + assert result.scalars().fetchall() == [ + "langchain_pg_embedding_pkey", + "ix_cmetadata_gin", + "ix_document_vector_gin", + "ix_embedding_hnsw", + ] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_partitioning_async() -> None: + pgvector = PGVector( + collection_name="test_collection", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + embedding_length=1536, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + enable_partitioning=True, + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + result = await session.execute(get_partitioned_table) + + assert result.fetchone() == (EmbeddingStore.__tablename__, "l", "collection_id") + + query = f""" + SELECT indexname + FROM pg_indexes + WHERE tablename = '{EmbeddingStore.__tablename__}' + """ + + result = await session.execute(text(query)) + + assert result.scalars().fetchall() == [ + "langchain_pg_embedding_pkey", + "ix_cmetadata_gin", + "ix_document_vector_gin", + "ix_embedding_hnsw", + ] + + +get_partitions = text( + f""" + SELECT c.relname as partitioned_table_name, + pg_get_expr(c.relpartbound, c.oid) as partition_bound + FROM pg_class c + JOIN pg_inherits i ON c.oid = i.inhrelid + JOIN pg_class pc ON i.inhparent = pc.oid + WHERE pc.relname = '{EmbeddingStore.__tablename__}' +""" +) + + +@pytest.mark.usefixtures("drop_tables") +def test_partitions_creation() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + enable_partitioning=True, + ) + + with sync_session() as session: + collection1 = pgvector.get_collection(session) + + pgvector = PGVector( + collection_name="test_collection2", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + enable_partitioning=True, + ) + + with sync_session() as session: + collection2 = pgvector.get_collection(session) + + with sync_session() as session: + result = session.execute(get_partitions) + + collection1_underscored = str(collection1.uuid).replace("-", "_") + collection2_underscored = str(collection2.uuid).replace("-", "_") + + assert result.fetchall() == [ + ( + f"{EmbeddingStore.__tablename__}_{collection1_underscored}", + f"FOR VALUES IN ('{str(collection1.uuid)}')", + ), + ( + f"{EmbeddingStore.__tablename__}_{collection2_underscored}", + f"FOR VALUES IN ('{str(collection2.uuid)}')", + ), + ] + + +@pytest.mark.asyncio 
+@pytest.mark.usefixtures("drop_tables") +async def test_partitions_creation_async() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + enable_partitioning=True, + async_mode=True, + ) + await pgvector.__apost_init__() + + async with async_session() as session: + collection1 = await pgvector.aget_collection(session) + + pgvector = PGVector( + collection_name="test_collection2", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + enable_partitioning=True, + async_mode=True, + ) + await pgvector.__apost_init__() + + async with async_session() as session: + collection2 = await pgvector.aget_collection(session) + + async with async_session() as session: + result = await session.execute(get_partitions) + + collection1_underscored = str(collection1.uuid).replace("-", "_") + collection2_underscored = str(collection2.uuid).replace("-", "_") + + assert result.fetchall() == [ + ( + f"{EmbeddingStore.__tablename__}_{collection1_underscored}", + f"FOR VALUES IN ('{str(collection1.uuid)}')", + ), + ( + f"{EmbeddingStore.__tablename__}_{collection2_underscored}", + f"FOR VALUES IN ('{str(collection2.uuid)}')", + ), + ] + + +@pytest.mark.usefixtures("drop_tables") +def test_partitions_deletion() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + enable_partitioning=True, + ) + + pgvector.delete_collection() + + with sync_session() as session: + result = session.execute(get_partitions) + + assert result.fetchall() == [] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_partitions_deletion_async() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + enable_partitioning=True, + async_mode=True, + ) + await pgvector.__apost_init__() + + await pgvector.adelete_collection() + + async with async_session() as session: + result = await session.execute(get_partitions) + + assert result.fetchall() == [] + + +def test_iterative_scan_without_index() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + iterative_scan=IterativeScan.off, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_iterative_scan() -> None: + with sync_session() as session: + with pytest.raises(sqlalchemy.exc.ProgrammingError): + session.execute(text("SHOW hnsw.iterative_scan")) + + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + iterative_scan=IterativeScan.off, + ) + + with sync_session() as session: + pgvector._set_iterative_scan(session) + + result = session.execute(text("SHOW hnsw.iterative_scan")) + assert result.fetchone()[0] == "off" + + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + iterative_scan=IterativeScan.relaxed_order, + ) + + with sync_session() as session: + 
pgvector._set_iterative_scan(session) + + result = session.execute(text("SHOW hnsw.iterative_scan")) + assert result.fetchone()[0] == "relaxed_order" + + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + iterative_scan=IterativeScan.strict_order, + ) + + with sync_session() as session: + pgvector._set_iterative_scan(session) + + result = session.execute(text("SHOW hnsw.iterative_scan")) + assert result.fetchone()[0] == "strict_order" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_iterative_scan_async() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + iterative_scan=IterativeScan.strict_order, + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + await pgvector._aset_iterative_scan(session) + + result = await session.execute(text("SHOW hnsw.iterative_scan")) + assert result.fetchone()[0] == "strict_order" + + +def test_max_scan_tuples_without_hnsw() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + max_scan_tuples=200, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_max_scan_tuples() -> None: + with sync_session() as session: + with pytest.raises(sqlalchemy.exc.ProgrammingError): + session.execute(text("SHOW hnsw.max_scan_tuples")) + + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + max_scan_tuples=200, + ) + + with sync_session() as session: + pgvector._set_max_scan_tuples(session) + + result = session.execute(text("SHOW hnsw.max_scan_tuples")) + assert result.fetchone()[0] == "200" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_max_scan_tuples_async() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + max_scan_tuples=200, + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + await pgvector._aset_max_scan_tuples(session) + + result = await session.execute(text("SHOW hnsw.max_scan_tuples")) + assert result.fetchone()[0] == "200" + + +def test_scan_mem_multiplier_without_hnsw() -> None: + with pytest.raises(ValueError): + PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + scan_mem_multiplier=200, + ) + + +@pytest.mark.usefixtures("drop_tables") +def test_scan_mem_multiplier() -> None: + with sync_session() as session: + with pytest.raises(sqlalchemy.exc.ProgrammingError): + session.execute(text("SHOW hnsw.scan_mem_multiplier")) + + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + 
connection=CONNECTION_STRING, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + scan_mem_multiplier=2, + ) + + with sync_session() as session: + pgvector._set_scan_mem_multiplier(session) + + result = session.execute(text("SHOW hnsw.scan_mem_multiplier")) + assert result.fetchone()[0] == "2" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("drop_tables") +async def test_scan_mem_multiplier_async() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=create_async_engine(CONNECTION_STRING), + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + embedding_length=1536, + scan_mem_multiplier=2, + async_mode=True, + ) + + await pgvector.__apost_init__() + + async with async_session() as session: + await pgvector._aset_scan_mem_multiplier(session) + + result = await session.execute(text("SHOW hnsw.scan_mem_multiplier")) + assert result.fetchone()[0] == "2" + + +@pytest.mark.usefixtures("drop_tables") +def test_binary_quantization_query() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + embedding_length=3, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="bit_hamming_ops", + binary_quantization=True, + binary_limit=200, + ) + + with sync_session() as session: + collection = pgvector.get_collection(session) + + stmt = pgvector._build_query_collection( + collection=collection, + embedding=[1.0, 0.0, -1.0], + k=20, + filter={"test_key": "test_value"}, + full_text_search=["word1", "word2 & word3"], + ) + + compiled = stmt.compile(dialect=sqlalchemy.dialects.postgresql.dialect()) + + query = str(compiled) + params = compiled.params + + expected_query = """ + SELECT + binary_result.id, + binary_result.collection_id, + binary_result.embedding, + binary_result.document, + binary_result.cmetadata, + binary_result.document_vector, + binary_result.embedding <=> %(embedding_1)s AS distance + FROM + ( + SELECT + langchain_pg_embedding.id AS id, + langchain_pg_embedding.collection_id AS collection_id, + langchain_pg_embedding.embedding AS embedding, + langchain_pg_embedding.document AS document, + langchain_pg_embedding.cmetadata AS cmetadata, + langchain_pg_embedding.document_vector AS document_vector + FROM + langchain_pg_embedding + WHERE + langchain_pg_embedding.collection_id = %(collection_id_1)s::UUID + AND jsonb_path_match( + langchain_pg_embedding.cmetadata, + CAST(%(param_1)s AS JSONPATH + ), CAST(%(param_2)s AS JSONB)) + AND ( + langchain_pg_embedding.document_vector + @@ to_tsquery(%(to_tsquery_1)s) + ) + ORDER BY + CAST( + binary_quantize(langchain_pg_embedding.embedding) AS BIT(3) + ) <~> binary_quantize(CAST(%(param_3)s AS VECTOR(3))) + LIMIT + %(param_4)s + ) AS binary_result + ORDER BY + distance ASC + LIMIT + %(param_5)s + """.replace("\n", "").replace(" ", " ") + + assert normalize_sql(query) == normalize_sql(expected_query) + assert params == { + "embedding_1": [1.0, 0.0, -1.0], + "collection_id_1": collection.uuid, + "param_1": "$.test_key == $value", + "param_2": {"value": "test_value"}, + "to_tsquery_1": "word1 | word2 & word3", + "param_3": [1.0, 0.0, -1.0], + "param_4": 200, + "param_5": 20, + } + + +@pytest.mark.usefixtures("drop_tables") +def test_relaxed_order_query() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + 
connection=CONNECTION_STRING, + embedding_length=3, + embedding_index=EmbeddingIndexType.hnsw, + embedding_index_ops="vector_cosine_ops", + iterative_scan=IterativeScan.relaxed_order, + ) + + with sync_session() as session: + collection = pgvector.get_collection(session) + + stmt = pgvector._build_query_collection( + collection=collection, + embedding=[1.0, 0.0, -1.0], + k=20, + filter={"test_key": "test_value"}, + full_text_search=["word1", "word2 & word3"], + ) + + compiled = stmt.compile(dialect=sqlalchemy.dialects.postgresql.dialect()) + + query = str(compiled) + params = compiled.params + + expected_query = """ + WITH relaxed_results AS MATERIALIZED ( + SELECT + langchain_pg_embedding.id AS id, + langchain_pg_embedding.collection_id AS collection_id, + langchain_pg_embedding.embedding AS embedding, + langchain_pg_embedding.document AS document, + langchain_pg_embedding.cmetadata AS cmetadata, + langchain_pg_embedding.document_vector AS document_vector, + langchain_pg_embedding.embedding <=> %(embedding_1)s AS distance + FROM + langchain_pg_embedding + WHERE + langchain_pg_embedding.collection_id = %(collection_id_1)s::UUID + AND jsonb_path_match( + langchain_pg_embedding.cmetadata, + CAST(%(param_1)s AS JSONPATH), + CAST(%(param_2)s AS JSONB) + ) + AND ( + langchain_pg_embedding.document_vector + @@ to_tsquery(%(to_tsquery_1)s) + ) + ORDER BY + distance ASC + LIMIT + %(param_3)s + ) + SELECT + relaxed_results.id, + relaxed_results.collection_id, + relaxed_results.embedding, + relaxed_results.document, + relaxed_results.cmetadata, + relaxed_results.document_vector, + relaxed_results.distance + FROM + relaxed_results + ORDER BY + distance + """ + + assert normalize_sql(query) == normalize_sql(expected_query) + assert params == { + "embedding_1": [1.0, 0.0, -1.0], + "collection_id_1": collection.uuid, + "param_1": "$.test_key == $value", + "param_2": {"value": "test_value"}, + "to_tsquery_1": "word1 | word2 & word3", + "param_3": 20, + } + + +@pytest.mark.usefixtures("drop_tables") +def test_build_full_text_search_filter() -> None: + pgvector = PGVector( + collection_name="test_collection1", + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + ) + + stmt = pgvector._build_full_text_search_filter(keywords=["word1", "word2 & word3"]) + + compiled = stmt.compile(dialect=sqlalchemy.dialects.postgresql.dialect()) + + query = str(compiled) + params = compiled.params + + expected_query = """ + langchain_pg_embedding.document_vector @@ to_tsquery(%(to_tsquery_1)s) + """ + + assert normalize_sql(query) == normalize_sql(expected_query) + assert params == {"to_tsquery_1": "word1 | word2 & word3"} diff --git a/tests/unit_tests/test_imports.py b/tests/unit_tests/test_imports.py index 4513d18..24c8709 100644 --- a/tests/unit_tests/test_imports.py +++ b/tests/unit_tests/test_imports.py @@ -5,6 +5,8 @@ "PGVector", "PGVectorTranslator", "PostgresChatMessageHistory", + "EmbeddingIndexType", + "IterativeScan", ] diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 2383daf..eb012e3 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1,8 +1,9 @@ """Test PGVector functionality.""" import contextlib -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional import pytest +from _pytest.monkeypatch import MonkeyPatch from langchain_core.documents import Document from langchain_core.embeddings import 
Embeddings from sqlalchemy import select @@ -25,12 +26,19 @@ ADA_TOKEN_COUNT = 1536 -def _compare_documents(left: Sequence[Document], right: Sequence[Document]) -> None: - """Compare lists of documents, irrespective of IDs.""" - assert len(left) == len(right) - for left_doc, right_doc in zip(left, right): - assert left_doc.page_content == right_doc.page_content - assert left_doc.metadata == right_doc.metadata +@pytest.fixture(autouse=True) +def patch_document(monkeypatch: MonkeyPatch) -> None: + def eq(self: Any, other: Any) -> bool: + return ( + self.page_content == other.page_content and self.metadata == other.metadata + ) + + monkeypatch.setattr(Document, "__eq__", eq) + + +class AnyStr(str): + def __eq__(self, other: Any) -> bool: + return isinstance(other, str) class FakeEmbeddingsWithAdaDimension(FakeEmbeddings): @@ -58,7 +66,7 @@ def test_pgvector() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] @pytest.mark.asyncio @@ -73,7 +81,7 @@ async def test_async_pgvector() -> None: pre_delete_collection=True, ) output = await docsearch.asimilarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] def test_pgvector_embeddings() -> None: @@ -89,7 +97,7 @@ def test_pgvector_embeddings() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] @pytest.mark.asyncio @@ -106,7 +114,7 @@ async def test_async_pgvector_embeddings() -> None: pre_delete_collection=True, ) output = await docsearch.asimilarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] def test_pgvector_with_metadatas() -> None: @@ -122,7 +130,7 @@ def test_pgvector_with_metadatas() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo", metadata={"page": "0"})]) + assert output == [Document(page_content="foo", metadata={"page": "0"})] @pytest.mark.asyncio @@ -139,7 +147,7 @@ async def test_async_pgvector_with_metadatas() -> None: pre_delete_collection=True, ) output = await docsearch.asimilarity_search("foo", k=1) - _compare_documents(output, [Document(page_content="foo", metadata={"page": "0"})]) + assert output == [Document(page_content="foo", metadata={"page": "0"})] def test_pgvector_with_metadatas_with_scores() -> None: @@ -155,9 +163,7 @@ def test_pgvector_with_metadatas_with_scores() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})]) - assert scores == (0.0,) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] @pytest.mark.asyncio @@ -174,9 +180,7 @@ async def test_async_pgvector_with_metadatas_with_scores() -> None: pre_delete_collection=True, ) output = await docsearch.asimilarity_search_with_score("foo", k=1) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})]) - assert scores == (0.0,) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] def test_pgvector_with_filter_match() -> None: @@ -192,9 +196,7 @@ def 
test_pgvector_with_filter_match() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})]) - assert scores == (0.0,) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] @pytest.mark.asyncio @@ -213,9 +215,7 @@ async def test_async_pgvector_with_filter_match() -> None: output = await docsearch.asimilarity_search_with_score( "foo", k=1, filter={"page": "0"} ) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})]) - assert scores == (0.0,) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] def test_pgvector_with_filter_distant_match() -> None: @@ -231,9 +231,12 @@ def test_pgvector_with_filter_distant_match() -> None: pre_delete_collection=True, ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="baz", metadata={"page": "2"})]) - assert scores == (0.0013003906671379406,) + assert output == [ + ( + Document(page_content="baz", metadata={"page": "2"}), + 0.0013003906671379406, + ) + ] @pytest.mark.asyncio @@ -252,9 +255,12 @@ async def test_async_pgvector_with_filter_distant_match() -> None: output = await docsearch.asimilarity_search_with_score( "foo", k=1, filter={"page": "2"} ) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="baz", metadata={"page": "2"})]) - assert scores == (0.0013003906671379406,) + assert output == [ + ( + Document(page_content="baz", metadata={"page": "2"}), + 0.0013003906671379406, + ) + ] def test_pgvector_with_filter_no_match() -> None: @@ -292,6 +298,160 @@ async def test_async_pgvector_with_filter_no_match() -> None: assert output == [] +def test_pgvector_with_full_text() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo bar baz", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1, full_text_search=["bar & baz"]) + assert output == [Document(page_content="foo bar baz", metadata={"page": "1"})] + + +def test_pgvector_with_full_text_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1, full_text_search=["bar & baz"]) + assert output == [] + + +@pytest.mark.asyncio +async def test_async_pgvector_with_full_text() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo bar baz", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search( + "foo", k=1, 
full_text_search=["bar & baz"] + ) + assert output == [Document(page_content="foo bar baz", metadata={"page": "1"})] + + +@pytest.mark.asyncio +async def test_async_pgvector_with_full_text_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search( + "foo", k=1, full_text_search=["bar & baz"] + ) + assert output == [] + + +def test_pgvector_with_full_text_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo bar baz", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score( + "foo", k=1, full_text_search=["bar & baz"] + ) + assert output == [ + ( + Document(page_content="foo bar baz", metadata={"page": "1"}), + 0.000325573832493542, + ) + ] + + +def test_pgvector_with_full_text_with_scores_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = PGVector.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score( + "foo", k=1, full_text_search=["bar & baz"] + ) + assert output == [] + + +@pytest.mark.asyncio +async def test_async_pgvector_with_full_text_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "foo bar baz", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, full_text_search=["bar & baz"] + ) + assert output == [ + ( + Document(page_content="foo bar baz", metadata={"page": "1"}), + 0.000325573832493542, + ) + ] + + +@pytest.mark.asyncio +async def test_async_pgvector_with_full_text_with_scores_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, full_text_search=["bar & baz"] + ) + assert output == [] + + def test_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" pgvector = PGVector( @@ -590,16 +750,17 @@ def test_pgvector_relevance_score() -> None: ) output = docsearch.similarity_search_with_relevance_scores("foo", k=3) - docs, scores = zip(*output) - _compare_documents( - docs, - [ - 
Document(page_content="foo", metadata={"page": "0"}), + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + ( Document(page_content="bar", metadata={"page": "1"}), + 0.9996744261675065, + ), + ( Document(page_content="baz", metadata={"page": "2"}), - ], - ) - assert scores == (1.0, 0.9996744261675065, 0.9986996093328621) + 0.9986996093328621, + ), + ] @pytest.mark.asyncio @@ -617,16 +778,17 @@ async def test_async_pgvector_relevance_score() -> None: ) output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3) - docs, scores = zip(*output) - _compare_documents( - docs, - [ - Document(page_content="foo", metadata={"page": "0"}), + assert output == [ + (Document(page_content="foo", metadata={"page": "0"}), 1.0), + ( Document(page_content="bar", metadata={"page": "1"}), + 0.9996744261675065, + ), + ( Document(page_content="baz", metadata={"page": "2"}), - ], - ) - assert scores == (1.0, 0.9996744261675065, 0.9986996093328621) + 0.9986996093328621, + ), + ] def test_pgvector_retriever_search_threshold() -> None: @@ -647,13 +809,10 @@ def test_pgvector_retriever_search_threshold() -> None: search_kwargs={"k": 3, "score_threshold": 0.999}, ) output = retriever.get_relevant_documents("summer") - _compare_documents( - output, - [ - Document(page_content="foo", metadata={"page": "0"}), - Document(page_content="bar", metadata={"page": "1"}), - ], - ) + assert output == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + ] @pytest.mark.asyncio @@ -675,13 +834,10 @@ async def test_async_pgvector_retriever_search_threshold() -> None: search_kwargs={"k": 3, "score_threshold": 0.999}, ) output = await retriever.aget_relevant_documents("summer") - _compare_documents( - output, - [ - Document(page_content="foo", metadata={"page": "0"}), - Document(page_content="bar", metadata={"page": "1"}), - ], - ) + assert output == [ + Document(page_content="foo", metadata={"page": "0"}), + Document(page_content="bar", metadata={"page": "1"}), + ] def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None: @@ -742,7 +898,7 @@ def test_pgvector_max_marginal_relevance_search() -> None: pre_delete_collection=True, ) output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] @pytest.mark.asyncio @@ -757,7 +913,7 @@ async def test_async_pgvector_max_marginal_relevance_search() -> None: pre_delete_collection=True, ) output = await docsearch.amax_marginal_relevance_search("foo", k=1, fetch_k=3) - _compare_documents(output, [Document(page_content="foo")]) + assert output == [Document(page_content="foo")] def test_pgvector_max_marginal_relevance_search_with_score() -> None: @@ -771,9 +927,7 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None: pre_delete_collection=True, ) output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo")]) - assert scores == (0.0,) + assert output == [(Document(page_content="foo"), 0.0)] @pytest.mark.asyncio @@ -790,9 +944,7 @@ async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None output = await docsearch.amax_marginal_relevance_search_with_score( "foo", k=1, fetch_k=3 ) - docs, scores = zip(*output) - _compare_documents(docs, [Document(page_content="foo")]) - assert scores == (0.0,) + assert 
+    assert output == [(Document(page_content="foo"), 0.0)]
 
 
 def test_pgvector_with_custom_connection() -> None:
@@ -806,7 +958,7 @@ def test_pgvector_with_custom_connection() -> None:
         pre_delete_collection=True,
     )
     output = docsearch.similarity_search("foo", k=1)
-    _compare_documents(output, [Document(page_content="foo")])
+    assert output == [Document(page_content="foo")]
 
 
 @pytest.mark.asyncio
@@ -821,7 +973,7 @@ async def test_async_pgvector_with_custom_connection() -> None:
         pre_delete_collection=True,
     )
     output = await docsearch.asimilarity_search("foo", k=1)
-    _compare_documents(output, [Document(page_content="foo")])
+    assert output == [Document(page_content="foo")]
 
 
 def test_pgvector_with_custom_engine_args() -> None:
@@ -844,7 +996,7 @@ def test_pgvector_with_custom_engine_args() -> None:
         engine_args=engine_args,
     )
     output = docsearch.similarity_search("foo", k=1)
-    _compare_documents(output, [Document(page_content="foo")])
+    assert output == [Document(page_content="foo")]
 
 
 # We should reuse this test-case across other integrations
diff --git a/tests/utils.py b/tests/utils.py
index a61d38b..46f60ce 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -3,6 +3,16 @@ from contextlib import asynccontextmanager, contextmanager
 
 import psycopg
+from sqlalchemy import create_engine
+from sqlalchemy.ext.asyncio import (
+    AsyncSession as SqlAlchemyAsyncSession,
+)
+from sqlalchemy.ext.asyncio import (
+    async_sessionmaker,
+    create_async_engine,
+)
+from sqlalchemy.orm import Session as SqlAlchemySession
+from sqlalchemy.orm import sessionmaker
 from typing_extensions import AsyncGenerator, Generator
 
 # Env variables match the default settings in the docker-compose file
@@ -53,3 +63,27 @@ def syncpg_client() -> Generator[psycopg.Connection, None, None]:
     finally:
         # Cleanup: close the connection after the test is done
         conn.close()
+
+
+@asynccontextmanager
+async def async_session() -> AsyncGenerator[SqlAlchemyAsyncSession, None]:
+    engine = create_async_engine(VECTORSTORE_CONNECTION_STRING)
+    AsyncSession = async_sessionmaker(bind=engine)
+
+    session = AsyncSession()
+    try:
+        yield session
+    finally:
+        await session.close()
+
+
+@contextmanager
+def sync_session() -> Generator[SqlAlchemySession, None, None]:
+    engine = create_engine(VECTORSTORE_CONNECTION_STRING)
+    Session = sessionmaker(bind=engine)
+
+    session = Session()
+    try:
+        yield session
+    finally:
+        session.close()
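A usage note for reviewers, since the diff adds `async_session` / `sync_session` to `tests/utils.py` without a caller in this patch: both helpers yield ordinary SQLAlchemy sessions bound to `VECTORSTORE_CONNECTION_STRING`, so tests can inspect the database directly after a vectorstore operation. The sketch below is illustrative only, not part of the patch; it assumes a running pgvector database and uses the `langchain_pg_embedding` table name from the partitioning note in the README, with a hypothetical row-count check.

```python
# Illustrative sketch -- not part of the diff above.
from sqlalchemy import text

from tests.utils import async_session, sync_session


def count_embeddings_sync() -> int:
    # sync_session() opens a Session bound to VECTORSTORE_CONNECTION_STRING
    # and closes it on exit, mirroring the existing syncpg_client fixture.
    with sync_session() as session:
        result = session.execute(text("SELECT count(*) FROM langchain_pg_embedding"))
        return result.scalar_one()


async def count_embeddings_async() -> int:
    # The async variant awaits session.close() in its finally block.
    async with async_session() as session:
        result = await session.execute(
            text("SELECT count(*) FROM langchain_pg_embedding")
        )
        return result.scalar_one()
```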