Skip to content

Commit

Permalink
update to chromadb 0.5.x
Browse files Browse the repository at this point in the history
  • Loading branch information
lpinheiroms committed Oct 23, 2024
1 parent 63095a2 commit 22cb336
Show file tree
Hide file tree
Showing 8 changed files with 1,230 additions and 307 deletions.
2 changes: 1 addition & 1 deletion python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
docker = ["docker~=7.0"]
openai = ["openai>=1.3"]
chromadb = ["chromadb~=0.4.15"]
chromadb = ["chromadb~=0.5.15", "sentence-transformers"]

[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._chromadb import ChromaVectorDB
from ._factory import VectorDBFactory

__all__ = ["ChromaVectorDB", "VectorDBFactory"]
240 changes: 210 additions & 30 deletions python/packages/autogen-ext/src/autogen_ext/storage/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pydantic import BaseModel
from typing import (
Any,
Callable,
Expand All @@ -12,19 +11,23 @@
runtime_checkable,
)

from pydantic import BaseModel

Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
ItemID = Union[str, int]



class Document(BaseModel):
"""Define Document according to autogen 0.4 specifications."""

id: ItemID
content: str
metadata: Optional[Metadata]
embedding: Optional[Vector]
content: Optional[str] = None
metadata: Optional[Metadata] = None
embedding: Optional[Vector] = None

model_config = {"arbitrary_types_allowed": True}


"""QueryResults is the response from the vector database for a query/queries.
A query is a list containing one string while queries is a list containing multiple strings.
Expand All @@ -33,10 +36,178 @@ class Document(BaseModel):
QueryResults = List[List[Tuple[Document, float]]]


@runtime_checkable
class AsyncVectorDB(Protocol):
"""
Abstract class for async vector database. A vector database is responsible for storing and retrieving documents.
Attributes:
active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None.
type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
Methods:
create_collection: Callable[[str, bool, bool], Awaitable[Any]] | Create a collection in the vector database.
get_collection: Callable[[str], Awaitable[Any]] | Get the collection from the vector database.
delete_collection: Callable[[str], Awaitable[Any]] | Delete the collection from the vector database.
insert_docs: Callable[[List[Document], str, bool], Awaitable[None]] | Insert documents into the collection of the vector database.
update_docs: Callable[[List[Document], str], Awaitable[None]] | Update documents in the collection of the vector database.
delete_docs: Callable[[List[ItemID], str], Awaitable[None]] | Delete documents from the collection of the vector database.
retrieve_docs: Callable[[List[str], str, int, float], Awaitable[QueryResults]] | Retrieve documents from the collection of the vector database based on the queries.
get_docs_by_ids: Callable[[List[ItemID], str], Awaitable[List[Document]]] | Retrieve documents from the collection of the vector database based on the ids.
"""

active_collection: Any = None
type: str = ""
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
None # embeddings = embedding_function(sentences)
)

async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
"""
Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
otherwise it raise a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Any | The collection object.
"""
...

async def get_collection(self, collection_name: Optional[str] = None) -> Any:
"""
Get the collection from the vector database.
Args:
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Any | The collection object.
"""
...

async def delete_collection(self, collection_name: str) -> Any:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any
"""
...

async def insert_docs(
self,
docs: Sequence[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...

async def update_docs(self, docs: Sequence[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...

async def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...

async def retrieve_docs(
self,
queries: Sequence[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is -1.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
...

async def get_docs_by_ids(
self,
ids: Optional[Sequence[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include. Default is None.
If None, will include ["metadatas", "documents"], ids will always be included. This may differ
depending on the implementation.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
...


@runtime_checkable
class VectorDB(Protocol):
"""
Abstract class for vector database. A vector database is responsible for storing and retrieving documents.
Abstract class for synchronous vector database. A vector database is responsible for storing and retrieving documents.
For async support, use AsyncVectorDB instead.
Attributes:
active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None.
Expand Down Expand Up @@ -77,13 +248,13 @@ def create_collection(self, collection_name: str, overwrite: bool = False, get_o
"""
...

def get_collection(self, collection_name: str = None) -> Any:
def get_collection(self, collection_name: Optional[str] = None) -> Any:
"""
Get the collection from the vector database.
Args:
collection_name: str | The name of the collection. Default is None. If None, return the
current active collection.
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Any | The collection object.
Expand All @@ -102,43 +273,49 @@ def delete_collection(self, collection_name: str) -> Any:
"""
...

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
def insert_docs(
self,
docs: Sequence[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
collection_name: str | The name of the collection. Default is None.
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict | Additional keyword arguments.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...

def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None:
def update_docs(self, docs: Sequence[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.
kwargs: Dict | Additional keyword arguments.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...

def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: str | The name of the collection. Default is None.
kwargs: Dict | Additional keyword arguments.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
Expand All @@ -147,22 +324,22 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs)

def retrieve_docs(
self,
queries: List[str],
collection_name: str = None,
queries: Sequence[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: str | The name of the collection. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is -1.
kwargs: Dict | Additional keyword arguments.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
Expand All @@ -171,21 +348,24 @@ def retrieve_docs(
...

def get_docs_by_ids(
self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
self,
ids: Optional[Sequence[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[List[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include. Default is None.
If None, will include ["metadatas", "documents"], ids will always be included. This may differ
depending on the implementation.
kwargs: dict | Additional keyword arguments.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
...

Loading

0 comments on commit 22cb336

Please sign in to comment.