From 33b0851c1f818713bc1cebb66e2b6c18cab4fd9c Mon Sep 17 00:00:00 2001 From: Shamsuddin Ahmed Date: Tue, 15 Oct 2024 16:13:11 +0600 Subject: [PATCH 1/2] feat(PGVector): enable deletion by metadata filter --- langchain_postgres/vectorstores.py | 110 ++++++++++++++++------------- 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index e1630a1..daf52fc 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -588,74 +588,86 @@ async def adelete_collection(self) -> None: await session.commit() def delete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + *, + filter: Optional[Dict[str, Any]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: - """Delete vectors by ids or uuids. + """Delete vectors by ids or metadata filter. Args: - ids: List of ids to delete. - collection_only: Only delete ids in the collection. + ids: Optional list of ids to delete. + filter: Optional metadata filter dictionary. + collection_only: If True, delete only from the current collection. + **kwargs: Additional arguments. """ - with self._make_sync_session() as session: - if ids is not None: - self.logger.debug( - "Trying to delete vectors by ids (represented by the model " - "using the custom ids field)" - ) + if ids is None and filter is None: + self.logger.warning("No ids or filter provided for deletion.") + return - stmt = delete(self.EmbeddingStore) + with self._make_sync_session() as session: + stmt = delete(self.EmbeddingStore) + if collection_only: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found.") + return + stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid) - if collection_only: - collection = self.get_collection(session) - if not collection: - self.logger.warning("Collection not found") - return + if ids is not None: + self.logger.debug("Deleting vectors by ids.") + stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) - stmt = stmt.where( - self.EmbeddingStore.collection_id == collection.uuid - ) + if filter is not None: + self.logger.debug("Deleting vectors by metadata filter.") + filter_clause = self._create_filter_clause(filter) + stmt = stmt.where(filter_clause) - stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) - session.execute(stmt) + session.execute(stmt) session.commit() async def adelete( - self, - ids: Optional[List[str]] = None, - collection_only: bool = False, - **kwargs: Any, + self, + ids: Optional[List[str]] = None, + *, + filter: Optional[Dict[str, Any]] = None, + collection_only: bool = False, + **kwargs: Any, ) -> None: - """Async delete vectors by ids or uuids. + """Asynchronously delete vectors by ids or metadata filter. Args: - ids: List of ids to delete. - collection_only: Only delete ids in the collection. + ids: Optional list of ids to delete. + filter: Optional metadata filter dictionary. + collection_only: If True, delete only from the current collection. + **kwargs: Additional arguments. """ - await self.__apost_init__() # Lazy async init - async with self._make_async_session() as session: - if ids is not None: - self.logger.debug( - "Trying to delete vectors by ids (represented by the model " - "using the custom ids field)" - ) + if ids is None and filter is None: + self.logger.warning("No ids or filter provided for deletion.") + return - stmt = delete(self.EmbeddingStore) + await self.__apost_init__() + async with self._make_async_session() as session: + stmt = delete(self.EmbeddingStore) + if collection_only: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found.") + return + stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid) - if collection_only: - collection = await self.aget_collection(session) - if not collection: - self.logger.warning("Collection not found") - return + if ids is not None: + self.logger.debug("Deleting vectors by ids.") + stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) - stmt = stmt.where( - self.EmbeddingStore.collection_id == collection.uuid - ) + if filter is not None: + self.logger.debug("Deleting vectors by metadata filter.") + filter_clause = self._create_filter_clause(filter) + stmt = stmt.where(filter_clause) - stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) - await session.execute(stmt) + await session.execute(stmt) await session.commit() def get_collection(self, session: Session) -> Any: From eb03ec184d8b9e95933502203532182e680e3569 Mon Sep 17 00:00:00 2001 From: Shamsuddin Ahmed Date: Tue, 15 Oct 2024 16:24:02 +0600 Subject: [PATCH 2/2] test(PGVector): add tests for deletion by metadata filter --- tests/unit_tests/test_vectorstore.py | 49 ++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 9945e51..02b6806 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -410,6 +410,55 @@ async def test_async_pgvector_delete_docs() -> None: assert sorted(record.id for record in records) == [] # type: ignore +def test_pgvector_delete_by_metadata() -> None: + """Test deleting documents by metadata.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"category": "news"}, {"category": "sports"}, {"category": "news"}] + vectorstore = PGVector.from_texts( + texts=texts, + collection_name="test_delete_by_metadata", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + # Delete documents where category is 'news' + vectorstore.delete(filter={"category": {"$eq": "news"}}) + with vectorstore.session_maker() as session: + records = list(session.query(vectorstore.EmbeddingStore).all()) + # Should only have the document with category 'sports' remaining + assert len(records) == 1 + assert records[0].id == "2" + assert records[0].cmetadata["category"] == "sports" + + +@pytest.mark.asyncio +async def test_async_pgvector_delete_by_metadata() -> None: + """Test deleting documents by metadata asynchronously.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"category": "news"}, {"category": "sports"}, {"category": "news"}] + vectorstore = await PGVector.afrom_texts( + texts=texts, + collection_name="test_delete_by_metadata", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + ids=["1", "2", "3"], + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + # Delete documents where category is 'news' + await vectorstore.adelete(filter={"category": {"$eq": "news"}}) + async with vectorstore.session_maker() as session: + records = ( + (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all() + ) + # Should only have the document with category 'sports' remaining + assert len(records) == 1 + assert records[0].id == "2" + assert records[0].cmetadata["category"] == "sports" + + def test_pgvector_index_documents() -> None: """Test adding duplicate documents results in overwrites.""" documents = [