Skip to content

Commit

Permalink
Auto Hashing ID for VectorDB Classes (#4746) (#4789)
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkall authored Dec 27, 2024
1 parent 960fbf0 commit 38bfb3d
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 24 deletions.
19 changes: 18 additions & 1 deletion autogen/agentchat/contrib/vectordb/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import os
from typing import (
Any,
Callable,
Expand All @@ -16,6 +18,8 @@
Vector = Union[Sequence[float], Sequence[int]]
ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does

HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))


class Document(TypedDict):
"""A Document is a record in the vector database.
Expand All @@ -26,7 +30,7 @@ class Document(TypedDict):
embedding: Vector, Optional | the vector representation of the content.
"""

id: ItemID
id: Optional[ItemID]
content: str
metadata: Optional[Metadata]
embedding: Optional[Vector]
Expand Down Expand Up @@ -108,6 +112,19 @@ def delete_collection(self, collection_name: str) -> Any:
"""
...

@staticmethod
def generate_chunk_ids(chunks: List[str], hash_length: int = HASH_LENGTH) -> List[ItemID]:
    """
    Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Declared as a ``@staticmethod``: the method takes no ``self``, yet callers
    invoke it as ``self.generate_chunk_ids(...)`` / ``db.generate_chunk_ids(...)``.
    Without the decorator the instance would bind to ``chunks`` and the wrong
    object would be hashed.

    Args:
        chunks (list): A list of chunks (strings) to hash.
        hash_length (int): The desired length of the hash. Defaults to the
            module-level HASH_LENGTH (overridable via the HASH_LENGTH env var).

    Returns:
        list: A list of generated chunk IDs; identical content always maps to
        the same ID, so re-uploads of the same chunk are detectable.
    """
    # blake2b is stable across processes/runs (unlike hash()), so the same
    # content reproduces the same truncated hex-digest ID.
    return [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:hash_length] for chunk in chunks]

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
"""
Insert documents into the collection of the vector database.
Expand Down
71 changes: 49 additions & 22 deletions autogen/agentchat/contrib/vectordb/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,17 @@ def _wait_for_document(self, collection: Collection, index_name: str, doc: Docum
if query_result and query_result[0][0]["_id"] == doc["id"]:
return
sleep(_DELAY)

raise TimeoutError(f"Document {self.index_name} is not ready!")
if (
query_result
and float(query_result[0][1]) == 1.0
and query_result[0][0].get("metadata") == doc.get("metadata")
):
# Handles edge case where document is uploaded with a specific user-generated ID, then the identical content is uploaded with a hash generated ID.
logger.warning(
f"""Documents may be ready, the search has found identical content with a different ID and {"identical" if query_result[0][0].get("metadata") == doc.get("metadata") else "different"} metadata. Duplicate ID: {str(query_result[0][0]["_id"])}"""
)
else:
raise TimeoutError(f"Document {self.index_name} is not ready!")

def _get_embedding_size(self):
return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
Expand Down Expand Up @@ -275,33 +284,49 @@ def insert_docs(
For large numbers of Documents, insertion is performed in batches.
Documents are recommended to not have an ID field, as the method will generate Hashed ID's for them.
Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`, which may contain an ID. Documents without ID's will have them generated.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
batch_size: Number of documents to be inserted in each batch
kwargs: Additional keyword arguments. Use `hash_length` to set the length of the hash generated ID's, use `overwrite_ids` to overwrite existing ID's with Hashed Values.
"""
hash_length = kwargs.get("hash_length")
overwrite_ids = kwargs.get("overwrite_ids", False)

if any(doc.get("content") is None for doc in docs):
raise ValueError("The document content is required.")

if not docs:
logger.info("No documents to insert.")
return

docs = deepcopy(docs)
collection = self.get_collection(collection_name)

assert (
len({doc.get("id") is None for doc in docs}) == 1
), "Documents provided must all have ID's or all not have ID's"

if docs[0].get("id") is None or overwrite_ids:
logger.info("No id field in the documents. The documents will be inserted with Hash generated IDs.")
content = [doc["content"] for doc in docs]
ids = (
self.generate_chunk_ids(content, hash_length=hash_length)
if hash_length
else self.generate_chunk_ids(content)
)
docs = [{**doc, "id": id} for doc, id in zip(docs, ids)]

if upsert:
self.update_docs(docs, collection.name, upsert=True)

else:
# Sanity checking the first document
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")

input_ids = set()
result_ids = set()
id_batch = []
text_batch = []
metadata_batch = []
size = 0
i = 0
input_ids, result_ids = set(), set()
id_batch, text_batch, metadata_batch = [], [], []
size, i = 0, 0
for doc in docs:
id = doc["id"]
text = doc["content"]
Expand All @@ -314,9 +339,7 @@ def insert_docs(
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
input_ids.update(id_batch)
id_batch = []
text_batch = []
metadata_batch = []
id_batch, text_batch, metadata_batch = [], [], []
size = 0
i += 1
if text_batch:
Expand Down Expand Up @@ -365,7 +388,8 @@ def _insert_batch(
]
# insert the documents in MongoDB Atlas
insert_result = collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
# TODO Remove this. Replace by log like update_docs
return insert_result.inserted_ids

def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
"""Update documents, including their embeddings, in the Collection.
Expand All @@ -375,11 +399,14 @@ def update_docs(self, docs: List[Document], collection_name: str = None, **kwarg
Uses deepcopy to avoid changing docs.
Args:
docs: List[Document] | A list of documents.
docs: List[Document] | A list of documents, with ID, to ensure the correct document is updated.
collection_name: str | The name of the collection. Default is None.
kwargs: Any | Use upsert=True` to insert documents whose ids are not present in collection.
"""

provided_doc_count = len(docs)
docs = [doc for doc in docs if doc.get("id") is not None]
if len(docs) != provided_doc_count:
logger.info(f"{provided_doc_count - len(docs)} will not be updated, as they did not contain an ID")
n_docs = len(docs)
logger.info(f"Preparing to embed and update {n_docs=}")
# Compute the embeddings
Expand Down
15 changes: 14 additions & 1 deletion autogen/agentchat/contrib/vectordb/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import hashlib
import logging
import os
import uuid
from typing import Callable, List, Optional, Sequence, Tuple, Union

from .base import Document, ItemID, QueryResults, VectorDB
Expand Down Expand Up @@ -155,6 +156,18 @@ def delete_collection(self, collection_name: str) -> None:
"""
return self.client.delete_collection(collection_name)

@staticmethod
def generate_chunk_ids(chunks: List[str]) -> List[ItemID]:
    """
    Generate deterministic chunk IDs to ensure non-duplicate uploads.

    Declared as a ``@staticmethod``: the method takes no ``self``, yet it is
    invoked through the instance (``self.generate_chunk_ids(...)``); without
    the decorator the instance would bind to ``chunks``.

    Args:
        chunks (list): A list of chunks (strings) to hash.

    Returns:
        list: A list of generated chunk IDs, one UUID string per chunk.
    """
    # md5 is used only as a content fingerprint (not for security); its
    # 128-bit digest is exactly the width of a UUID, yielding stable,
    # Qdrant-compatible point IDs for identical content.
    return [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
"""
Insert documents into the collection of the vector database.
Expand Down
55 changes: 55 additions & 0 deletions test/agentchat/contrib/vectordb/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ def example_documents() -> List[Document]:
]


@pytest.fixture
def id_less_example_documents() -> List[Document]:
    """Documents without IDs, used to exercise hash-generated ID insertion."""
    contents_and_metadata = [
        ("Stars are Big.", {"a": 1}),
        ("Atoms are Small.", {"b": 1}),
        ("Clouds are White.", {"c": 1}),
        ("Grass is Green.", {"d": 1, "e": 2}),
    ]
    return [Document(content=text, metadata=meta) for text, meta in contents_and_metadata]


@pytest.fixture
def id_mix_example_documents() -> List[Document]:
    """A mix of documents with and without explicit IDs, for the mixed-ID rejection test."""
    return [
        Document(id="123", content="Stars are Big.", metadata={"a": 1}),
        Document(content="Atoms are Small.", metadata={"b": 1}),
        Document(id="321", content="Clouds are White.", metadata={"c": 1}),
        Document(content="Grass is Green.", metadata={"d": 1, "e": 2}),
    ]


@pytest.fixture
def db_with_indexed_clxn(collection_name):
"""VectorDB with a collection created immediately"""
Expand Down Expand Up @@ -212,6 +234,39 @@ def test_insert_docs(db, collection_name, example_documents):
assert len(found[0]["embedding"]) == 384


def test_insert_docs_no_id(db, collection_name, id_less_example_documents):
    """ID-less documents are inserted with deterministic hash-generated IDs."""
    # Inserting without an active collection must fail loudly.
    with pytest.raises(ValueError) as excinfo:
        db.insert_docs(id_less_example_documents)
    assert "No collection is specified" in str(excinfo.value)

    # Start from a fresh collection.
    db.delete_collection(collection_name)
    clxn = db.create_collection(collection_name)

    # Insert the example documents and read everything back.
    db.insert_docs(id_less_example_documents, collection_name=collection_name)
    inserted = list(clxn.find({}))
    assert len(inserted) == len(id_less_example_documents)

    # Each stored document carries exactly these fields; "id" is folded into "_id".
    expected_fields = {"_id", "content", "metadata", "embedding"}
    for doc in inserted:
        assert set(doc.keys()) == expected_fields

    # The stored _id values are the deterministic content hashes.
    expected_ids = set(db.generate_chunk_ids([d.get("content") for d in id_less_example_documents]))
    assert {doc["_id"] for doc in inserted} == expected_ids

    # Embeddings come from the default 384-dimensional model.
    assert len(inserted[0]["embedding"]) == 384


def test_insert_docs_mix_id(db, collection_name, id_mix_example_documents):
    """A batch mixing explicit and missing IDs must be rejected outright."""
    # No active collection yet: insertion must raise.
    with pytest.raises(ValueError) as excinfo:
        db.insert_docs(id_mix_example_documents)
    assert "No collection is specified" in str(excinfo.value)

    # Even with a collection named, mixed-ID input trips the all-or-none check.
    with pytest.raises(AssertionError, match="Documents provided must all have ID's or all not have ID's"):
        db.insert_docs(id_mix_example_documents, collection_name, upsert=True)


def test_update_docs(db_with_indexed_clxn, example_documents):
db, collection = db_with_indexed_clxn
# Use update_docs to insert new documents
Expand Down

0 comments on commit 38bfb3d

Please sign in to comment.