diff --git a/libs/community/langchain_community/vectorstores/tiledb.py b/libs/community/langchain_community/vectorstores/tiledb.py
index 5db48a07..521b5803 100644
--- a/libs/community/langchain_community/vectorstores/tiledb.py
+++ b/libs/community/langchain_community/vectorstores/tiledb.py
@@ -1,5 +1,3 @@
-"""Wrapper around TileDB vector database."""
-
 from __future__ import annotations
 
 import pickle
@@ -15,69 +13,71 @@ from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
-INDEX_METRICS = frozenset(["euclidean"])
+INDEX_METRICS = frozenset(["euclidean", "squared_l2", "cosine"])
 DEFAULT_METRIC = "euclidean"
+
+def _metric_to_enum(metric: str) -> Any:
+    tiledb_vs = guard_import("tiledb.vector_search")
+    vspy = tiledb_vs.vspy
+    return {
+        "euclidean": vspy.DistanceMetric.L2,
+        "squared_l2": vspy.DistanceMetric.SUM_OF_SQUARES,
+        "cosine": vspy.DistanceMetric.COSINE,
+    }[metric]
+
+def _normalize(v: np.ndarray) -> np.ndarray:
+    norm = np.linalg.norm(v, axis=1, keepdims=True)
+    norm[norm == 0] = 1.0
+    return (v / norm).astype(v.dtype)
+
+_SUPPORTED_DTYPES = (np.float32, np.int8, np.uint8)
+try:
+    # bfloat16 only exists if a plugin such as ml_dtypes has registered it.
+    _HALF_DTYPES = (np.float16, np.dtype("bfloat16"))
+except TypeError:
+    _HALF_DTYPES = (np.float16,)
+
+def _resolve_vector_dtype(
+    requested: Optional[np.dtype], sample: np.ndarray
+) -> np.dtype:
+    if requested is not None:
+        if requested not in _SUPPORTED_DTYPES:
+            raise ValueError(
+                f"Unsupported vector dtype: {requested}. "
+                f"Expected one of {_SUPPORTED_DTYPES}"
+            )
+        return requested
+    src = sample.dtype
+    if src in _SUPPORTED_DTYPES:
+        return src
+    if src in _HALF_DTYPES:
+        return np.float32
+    raise ValueError(f"Cannot build a TileDB index from embeddings of dtype {src}")
+
 DOCUMENTS_ARRAY_NAME = "documents"
 VECTOR_INDEX_NAME = "vectors"
 MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
 MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
 MAX_FLOAT = sys.float_info.max
 
-
 def dependable_tiledb_import() -> Any:
-    """Import tiledb-vector-search if available, otherwise raise error."""
     return (
         guard_import("tiledb.vector_search"),
        guard_import("tiledb"),
     )
 
-
 def get_vector_index_uri_from_group(group: Any) -> str:
-    """Get the URI of the vector index."""
     return group[VECTOR_INDEX_NAME].uri
 
-
 def get_documents_array_uri_from_group(group: Any) -> str:
-    """Get the URI of the documents array from group.
-
-    Args:
-        group: TileDB group object.
-
-    Returns:
-        URI of the documents array.
-    """
     return group[DOCUMENTS_ARRAY_NAME].uri
 
-
 def get_vector_index_uri(uri: str) -> str:
-    """Get the URI of the vector index."""
     return f"{uri}/{VECTOR_INDEX_NAME}"
 
-
 def get_documents_array_uri(uri: str) -> str:
-    """Get the URI of the documents array."""
     return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
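The three helpers above are plain functions over numpy inputs (only `_metric_to_enum` touches `tiledb.vector_search`, and lazily), so their contracts are easy to spot-check. A minimal sketch, assuming the module stays importable as `langchain_community.vectorstores.tiledb` and the private names keep the spelling used in this patch:

```python
import numpy as np

# Assumed import path; these are private helpers from this very patch.
from langchain_community.vectorstores.tiledb import _normalize, _resolve_vector_dtype

# _normalize scales each row to unit length; all-zero rows pass through unchanged.
v = np.array([[3.0, 4.0], [0.0, 0.0]], dtype=np.float32)
print(_normalize(v))  # [[0.6, 0.8], [0.0, 0.0]]

# An explicitly requested dtype must be float32/int8/uint8...
assert _resolve_vector_dtype(np.float32, np.zeros(4, np.float16)) == np.float32
# ...otherwise the sample decides: half precision is widened, int8 kept as-is.
assert _resolve_vector_dtype(None, np.zeros(4, np.float16)) == np.float32
assert _resolve_vector_dtype(None, np.zeros(4, np.int8)) == np.int8
```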
- """ if not allow_dangerous_deserialization: - raise ValueError( - "TileDB relies on pickle for serialization and deserialization. " - "This can be dangerous if the data is intercepted and/or modified " - "by malicious actors prior to being de-serialized. " - "If you are sure that the data is safe from modification, you can " - " set allow_dangerous_deserialization=True to proceed. " - "Loading of compromised data using pickle can result in execution of " - "arbitrary code on your machine." - ) + raise ValueError + if metric not in INDEX_METRICS: + raise ValueError self.embedding = embedding self.embedding_function = embedding.embed_query self.index_uri = index_uri self.metric = metric self.config = config - - tiledb_vs, tiledb = ( - guard_import("tiledb.vector_search"), - guard_import("tiledb"), - ) + tiledb_vs, tiledb = dependable_tiledb_import() with tiledb.scope_ctx(ctx_or_config=config): index_group = tiledb.Group(self.index_uri, "r") self.vector_index_uri = ( - vector_index_uri - if vector_index_uri != "" - else get_vector_index_uri_from_group(index_group) + vector_index_uri or get_vector_index_uri_from_group(index_group) ) self.docs_array_uri = ( - docs_array_uri - if docs_array_uri != "" - else get_documents_array_uri_from_group(index_group) + docs_array_uri or get_documents_array_uri_from_group(index_group) ) index_group.close() group = tiledb.Group(self.vector_index_uri, "r") @@ -146,6 +123,11 @@ def __init__( timestamp=self.timestamp, **kwargs, ) + self._index_dtype = getattr( + self.vector_index, + "dtype", + getattr(self.vector_index, "vector_type", np.float32), + ) @property def embeddings(self) -> Optional[Embeddings]: @@ -160,18 +142,6 @@ def process_index_results( filter: Optional[Dict[str, Any]] = None, score_threshold: float = MAX_FLOAT, ) -> List[Tuple[Document, float]]: - """Turns TileDB results into a list of documents and scores. - - Args: - ids: List of indices of the documents in the index. - scores: List of distances of the documents in the index. - k: Number of Documents to return. Defaults to 4. - filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None. - score_threshold: Optional, a floating point value to filter the - resulting set of retrieved docs - Returns: - List of Documents and scores. - """ tiledb = guard_import("tiledb") docs = [] docs_array = tiledb.open( @@ -184,11 +154,11 @@ def process_index_results( continue doc = docs_array[idx] if doc is None or len(doc["text"]) == 0: - raise ValueError(f"Could not find document for id {idx}, got {doc}") + raise ValueError pickled_metadata = doc.get("metadata") result_doc = Document(page_content=str(doc["text"][0])) if pickled_metadata is not None: - metadata = pickle.loads( # ignore[pickle]: explicit-opt-in + metadata = pickle.loads( np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes() ) result_doc.metadata = metadata @@ -208,6 +178,12 @@ def process_index_results( docs = [(doc, score) for doc, score in docs if score <= score_threshold] return docs[:k] + def _prepare_query_vector(self, v: List[float]) -> np.ndarray: + vec = np.asarray(v) + if vec.dtype in _HALF_DTYPES and self._index_dtype == np.float32: + vec = vec.astype(np.float32) + return vec.reshape(1, -1).astype(self._index_dtype, copy=False) + def similarity_search_with_score_by_vector( self, embedding: List[float], @@ -217,29 +193,13 @@ def similarity_search_with_score_by_vector( fetch_k: int = 20, **kwargs: Any, ) -> List[Tuple[Document, float]]: - """Return docs most similar to query. 
     def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
@@ -217,29 +193,13 @@ def similarity_search_with_score_by_vector(
         fetch_k: int = 20,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            embedding: Embedding vector to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
-            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
-                Defaults to 20.
-            **kwargs: kwargs to be passed to similarity search. Can include:
-                nprobe: Optional, number of partitions to check if using IVF_FLAT index
-                score_threshold: Optional, a floating point value to filter the
-                    resulting set of retrieved docs
-
-        Returns:
-            List of documents most similar to the query text and distance
-            in float for each. Lower score represents more similarity.
-        """
         if "score_threshold" in kwargs:
             score_threshold = kwargs.pop("score_threshold")
         else:
             score_threshold = MAX_FLOAT
+        q = self._prepare_query_vector(embedding)
         d, i = self.vector_index.query(
-            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
+            q,
             k=k if filter is None else fetch_k,
             **kwargs,
         )
@@ -256,28 +216,14 @@ def similarity_search_with_score(
         fetch_k: int = 20,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
-            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
-                Defaults to 20.
-
-        Returns:
-            List of documents most similar to the query text with
-            Distance as float. Lower score represents more similarity.
-        """
         embedding = self.embedding_function(query)
-        docs = self.similarity_search_with_score_by_vector(
+        return self.similarity_search_with_score_by_vector(
             embedding,
             k=k,
             filter=filter,
             fetch_k=fetch_k,
             **kwargs,
         )
-        return docs
 
     def similarity_search_by_vector(
         self,
@@ -287,18 +233,6 @@ def similarity_search_by_vector(
         fetch_k: int = 20,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
-            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
-                Defaults to 20.
-
-        Returns:
-            List of Documents most similar to the embedding.
-        """
         docs_and_scores = self.similarity_search_with_score_by_vector(
             embedding,
             k=k,
@@ -316,18 +250,6 @@ def similarity_search(
         fetch_k: int = 20,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
-            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
-                Defaults to 20.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
         docs_and_scores = self.similarity_search_with_score(
             query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
         )
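Since `score_threshold` is popped from `**kwargs` before the index query, callers can bound results by distance. A hedged end-to-end sketch, assuming `tiledb-vector-search` is installed and that extra keyword arguments such as `allow_dangerous_deserialization` pass harmlessly through the ingestion call; `FakeEmbeddings` stands in for a real model:

```python
import os
import tempfile

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import TileDB

texts = ["tiledb is an array database", "vectors can be searched by distance"]
db = TileDB.from_texts(
    texts,
    FakeEmbeddings(size=8),
    index_uri=os.path.join(tempfile.mkdtemp(), "index"),  # must not pre-exist
    allow_dangerous_deserialization=True,
)
# Scores are distances: lower means more similar, and score_threshold
# discards anything farther away than the given bound.
for doc, score in db.similarity_search_with_score(
    "array database", k=2, score_threshold=1e6
):
    print(round(score, 3), doc.page_content)
```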
@@ -343,31 +265,13 @@ def max_marginal_relevance_search_with_score_by_vector(
         filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        """Return docs and their similarity scores selected using the maximal marginal
-        relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch before filtering to
-                pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                of diversity among the results with 0 corresponding
-                to maximum diversity and 1 to minimum diversity.
-                Defaults to 0.5.
-        Returns:
-            List of Documents and similarity scores selected by maximal marginal
-            relevance and score for each.
-        """
         if "score_threshold" in kwargs:
             score_threshold = kwargs.pop("score_threshold")
         else:
             score_threshold = MAX_FLOAT
+        q = self._prepare_query_vector(embedding)
         scores, indices = self.vector_index.query(
-            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
+            q,
             k=fetch_k if filter is None else fetch_k * 2,
             **kwargs,
         )
@@ -381,16 +285,15 @@ def max_marginal_relevance_search_with_score_by_vector(
         embeddings = [
             self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
         ]
+        if self.metric == "cosine" and embeddings:
+            embeddings = _normalize(np.vstack(embeddings))
         mmr_selected = maximal_marginal_relevance(
             np.array([embedding], dtype=np.float32),
             embeddings,
             k=k,
             lambda_mult=lambda_mult,
         )
-        docs_and_scores = []
-        for i in mmr_selected:
-            docs_and_scores.append(results[i])
-        return docs_and_scores
+        return [results[i] for i in mmr_selected]
 
     def max_marginal_relevance_search_by_vector(
         self,
@@ -401,23 +304,6 @@ def max_marginal_relevance_search_by_vector(
         filter: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch before filtering to
-                pass to MMR algorithm.
-            lambda_mult: Number between 0 and 1 that determines the degree
-                of diversity among the results with 0 corresponding
-                to maximum diversity and 1 to minimum diversity.
-                Defaults to 0.5.
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
         docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
             embedding,
             k=k,
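One way to see why the cosine branch normalizes candidate embeddings before MMR: on unit vectors, squared Euclidean distance is an affine function of cosine similarity, so L2-style machinery ranks candidates exactly as cosine would. A quick numeric check of the identity `||a - b||^2 = 2 - 2 * cos(a, b)`:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=4), rng.normal(size=4)
a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)  # project onto the unit sphere

lhs = np.sum((a - b) ** 2)      # squared Euclidean distance
rhs = 2 - 2 * float(a @ b)      # affine in the cosine similarity
assert np.isclose(lhs, rhs)
```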
- """ embedding = self.embedding_function(query) - docs = self.max_marginal_relevance_search_by_vector( + return self.max_marginal_relevance_search_by_vector( embedding, k=k, fetch_k=fetch_k, @@ -463,7 +332,6 @@ def max_marginal_relevance_search( filter=filter, **kwargs, ) - return docs @classmethod def create( @@ -475,11 +343,10 @@ def create( *, metadatas: bool = True, config: Optional[Mapping[str, Any]] = None, + metric: str = DEFAULT_METRIC, ) -> None: - tiledb_vs, tiledb = ( - guard_import("tiledb.vector_search"), - guard_import("tiledb"), - ) + tiledb_vs, tiledb = dependable_tiledb_import() + distance_enum = _metric_to_enum(metric) with tiledb.scope_ctx(ctx_or_config=config): try: tiledb.group_create(index_uri) @@ -488,32 +355,24 @@ def create( group = tiledb.Group(index_uri, "w") vector_index_uri = get_vector_index_uri(group.uri) docs_uri = get_documents_array_uri(group.uri) + create_kwargs = dict( + uri=vector_index_uri, + dimensions=dimensions, + vector_type=vector_type, + config=config, + distance_metric=distance_enum, + ) if index_type == "FLAT": - tiledb_vs.flat_index.create( - uri=vector_index_uri, - dimensions=dimensions, - vector_type=vector_type, - config=config, - ) + tiledb_vs.flat_index.create(**create_kwargs) elif index_type == "IVF_FLAT": - tiledb_vs.ivf_flat_index.create( - uri=vector_index_uri, - dimensions=dimensions, - vector_type=vector_type, - config=config, - ) + tiledb_vs.ivf_flat_index.create(**create_kwargs) group.add(vector_index_uri, name=VECTOR_INDEX_NAME) - - # Create TileDB array to store Documents - # TODO add a Document store API to tiledb-vector-search to allow storing - # different types of objects and metadata in a more generic way. dim = tiledb.Dim( name="id", domain=(0, MAX_UINT64 - 1), dtype=np.dtype(np.uint64), ) dom = tiledb.Domain(dim) - text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True) attrs = [text_attr] if metadatas: @@ -541,40 +400,33 @@ def __from( ids: Optional[List[str]] = None, metric: str = DEFAULT_METRIC, index_type: str = "FLAT", + vector_dtype: Optional[np.dtype] = None, config: Optional[Mapping[str, Any]] = None, index_timestamp: int = 0, **kwargs: Any, - ) -> TileDB: + ) -> "TileDB": if metric not in INDEX_METRICS: - raise ValueError( - ( - f"Unsupported distance metric: {metric}. 
" - f"Expected one of {list(INDEX_METRICS)}" - ) - ) - tiledb_vs, tiledb = ( - guard_import("tiledb.vector_search"), - guard_import("tiledb"), - ) - input_vectors = np.array(embeddings).astype(np.float32) + raise ValueError + vector_dtype = _resolve_vector_dtype(vector_dtype, np.asarray(embeddings[0])) + input_vectors = np.asarray(embeddings, dtype=vector_dtype) cls.create( index_uri=index_uri, index_type=index_type, dimensions=input_vectors.shape[1], - vector_type=input_vectors.dtype, + vector_type=vector_dtype, metadatas=metadatas is not None, config=config, + metric=metric, ) + tiledb_vs, tiledb = dependable_tiledb_import() with tiledb.scope_ctx(ctx_or_config=config): if not embeddings: - raise ValueError("embeddings must be provided to build a TileDB index") - + raise ValueError vector_index_uri = get_vector_index_uri(index_uri) docs_uri = get_documents_array_uri(index_uri) if ids is None: ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts] - external_ids = np.array(ids).astype(np.uint64) - + external_ids = np.asarray(ids, dtype=np.uint64) tiledb_vs.ingestion.ingest( index_type=index_type, index_uri=vector_index_uri, @@ -582,25 +434,18 @@ def __from( external_ids=external_ids, index_timestamp=index_timestamp if index_timestamp != 0 else None, config=config, + distance_metric=_metric_to_enum(metric), **kwargs, ) with tiledb.open(docs_uri, "w") as A: - if external_ids is None: - external_ids = np.zeros(len(texts), dtype=np.uint64) - for i in range(len(texts)): - external_ids[i] = i - data = {} - data["text"] = np.array(texts) + data = {"text": np.asarray(texts)} if metadatas is not None: metadata_attr = np.empty([len(metadatas)], dtype=object) - i = 0 - for metadata in metadatas: + for i, md in enumerate(metadatas): metadata_attr[i] = np.frombuffer( - pickle.dumps(metadata), dtype=np.uint8 + pickle.dumps(md), dtype=np.uint8 ) - i += 1 data["metadata"] = metadata_attr - A[external_ids] = data return cls( embedding=embedding, @@ -610,118 +455,22 @@ def __from( **kwargs, ) - def delete( - self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any - ) -> Optional[bool]: - """Delete by vector ID or other criteria. - - Args: - ids: List of ids to delete. - timestamp: Optional timestamp to delete with. - **kwargs: Other keyword arguments that subclasses might use. - - Returns: - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. - """ - - external_ids = np.array(ids).astype(np.uint64) - self.vector_index.delete_batch( - external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None - ) - return True - - def add_texts( - self, - texts: Iterable[str], - metadatas: Optional[List[dict]] = None, - ids: Optional[List[str]] = None, - timestamp: int = 0, - **kwargs: Any, - ) -> List[str]: - """Run more texts through the embeddings and add to the vectorstore. - - Args: - texts: Iterable of strings to add to the vectorstore. - metadatas: Optional list of metadatas associated with the texts. - ids: Optional ids of each text object. - timestamp: Optional timestamp to write new texts with. - kwargs: vectorstore specific parameters - - Returns: - List of ids from adding the texts into the vectorstore. 
- """ - tiledb = guard_import("tiledb") - embeddings = self.embedding.embed_documents(list(texts)) - if ids is None: - ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts] - - external_ids = np.array(ids).astype(np.uint64) - vectors = np.empty((len(embeddings)), dtype="O") - for i in range(len(embeddings)): - vectors[i] = np.array(embeddings[i], dtype=np.float32) - self.vector_index.update_batch( - vectors=vectors, - external_ids=external_ids, - timestamp=timestamp if timestamp != 0 else None, - ) - - docs = {} - docs["text"] = np.array(texts) - if metadatas is not None: - metadata_attr = np.empty([len(metadatas)], dtype=object) - i = 0 - for metadata in metadatas: - metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8) - i += 1 - docs["metadata"] = metadata_attr - - docs_array = tiledb.open( - self.docs_array_uri, - "w", - timestamp=timestamp if timestamp != 0 else None, - config=self.config, - ) - docs_array[external_ids] = docs - docs_array.close() - return ids - @classmethod def from_texts( cls, texts: List[str], embedding: Embeddings, + *, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, metric: str = DEFAULT_METRIC, index_uri: str = "/tmp/tiledb_array", index_type: str = "FLAT", + vector_dtype: Optional[np.dtype] = None, config: Optional[Mapping[str, Any]] = None, index_timestamp: int = 0, **kwargs: Any, - ) -> TileDB: - """Construct a TileDB index from raw documents. - - Args: - texts: List of documents to index. - embedding: Embedding function to use. - metadatas: List of metadata dictionaries to associate with documents. - ids: Optional ids of each text object. - metric: Metric to use for indexing. Defaults to "euclidean". - index_uri: The URI to write the TileDB arrays - index_type: Optional, Vector index type ("FLAT", IVF_FLAT") - config: Optional, TileDB config - index_timestamp: Optional, timestamp to write new texts with. - - Example: - .. code-block:: python - - from langchain_community import TileDB - from langchain_community.embeddings import OpenAIEmbeddings - embeddings = OpenAIEmbeddings() - index = TileDB.from_texts(texts, embeddings) - """ - embeddings = [] + ) -> "TileDB": embeddings = embedding.embed_documents(texts) return cls.__from( texts=texts, @@ -732,6 +481,7 @@ def from_texts( metric=metric, index_uri=index_uri, index_type=index_type, + vector_dtype=vector_dtype, config=config, index_timestamp=index_timestamp, **kwargs, @@ -748,35 +498,13 @@ def from_embeddings( ids: Optional[List[str]] = None, metric: str = DEFAULT_METRIC, index_type: str = "FLAT", + vector_dtype: Optional[np.dtype] = None, config: Optional[Mapping[str, Any]] = None, index_timestamp: int = 0, **kwargs: Any, - ) -> TileDB: - """Construct TileDB index from embeddings. - - Args: - text_embeddings: List of tuples of (text, embedding) - embedding: Embedding function to use. - index_uri: The URI to write the TileDB arrays - metadatas: List of metadata dictionaries to associate with documents. - metric: Optional, Metric to use for indexing. Defaults to "euclidean". - index_type: Optional, Vector index type ("FLAT", IVF_FLAT") - config: Optional, TileDB config - index_timestamp: Optional, timestamp to write new texts with. - - Example: - .. 
@@ -748,35 +498,13 @@ def from_embeddings(
         ids: Optional[List[str]] = None,
         metric: str = DEFAULT_METRIC,
         index_type: str = "FLAT",
+        vector_dtype: Optional[np.dtype] = None,
         config: Optional[Mapping[str, Any]] = None,
         index_timestamp: int = 0,
         **kwargs: Any,
-    ) -> TileDB:
-        """Construct TileDB index from embeddings.
-
-        Args:
-            text_embeddings: List of tuples of (text, embedding)
-            embedding: Embedding function to use.
-            index_uri: The URI to write the TileDB arrays
-            metadatas: List of metadata dictionaries to associate with documents.
-            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
-            index_type: Optional, Vector index type ("FLAT", IVF_FLAT")
-            config: Optional, TileDB config
-            index_timestamp: Optional, timestamp to write new texts with.
-
-        Example:
-            .. code-block:: python
-
-                from langchain_community import TileDB
-                from langchain_community.embeddings import OpenAIEmbeddings
-                embeddings = OpenAIEmbeddings()
-                text_embeddings = embeddings.embed_documents(texts)
-                text_embedding_pairs = list(zip(texts, text_embeddings))
-                db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
-        """
+    ) -> "TileDB":
         texts = [t[0] for t in text_embeddings]
         embeddings = [t[1] for t in text_embeddings]
-
         return cls.__from(
             texts=texts,
             embeddings=embeddings,
@@ -786,6 +514,7 @@ def from_embeddings(
             metric=metric,
             index_uri=index_uri,
             index_type=index_type,
+            vector_dtype=vector_dtype,
             config=config,
             index_timestamp=index_timestamp,
             **kwargs,
@@ -801,16 +530,7 @@ def load(
         config: Optional[Mapping[str, Any]] = None,
         timestamp: Any = None,
         **kwargs: Any,
-    ) -> TileDB:
-        """Load a TileDB index from a URI.
-
-        Args:
-            index_uri: The URI of the TileDB vector index.
-            embedding: Embeddings to use when generating queries.
-            metric: Optional, Metric to use for indexing. Defaults to "euclidean".
-            config: Optional, TileDB config
-            timestamp: Optional, timestamp to use for opening the arrays.
-        """
+    ) -> "TileDB":
         return cls(
             embedding=embedding,
             index_uri=index_uri,
@@ -820,5 +540,56 @@ def load(
             **kwargs,
         )
 
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        timestamp: int = 0,
+        **kwargs: Any,
+    ) -> List[str]:
+        texts = list(texts)
+        tiledb = guard_import("tiledb")
+        embeddings = self.embedding.embed_documents(texts)
+        embeddings_np = np.asarray(embeddings)
+        target_dtype = self._index_dtype
+        if embeddings_np.dtype in _HALF_DTYPES and target_dtype == np.float32:
+            embeddings_np = embeddings_np.astype(np.float32)
+        vectors = embeddings_np.astype(target_dtype)
+        if ids is None:
+            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
+        external_ids = np.asarray(ids, dtype=np.uint64)
+        vectors_object = np.empty((len(vectors)), dtype="O")
+        vectors_object[:] = list(vectors)
+        self.vector_index.update_batch(
+            vectors=vectors_object,
+            external_ids=external_ids,
+            timestamp=timestamp if timestamp != 0 else None,
+        )
+        docs = {"text": np.asarray(texts)}
+        if metadatas is not None:
+            metadata_attr = np.empty([len(metadatas)], dtype=object)
+            for i, md in enumerate(metadatas):
+                metadata_attr[i] = np.frombuffer(pickle.dumps(md), dtype=np.uint8)
+            docs["metadata"] = metadata_attr
+        docs_array = tiledb.open(
+            self.docs_array_uri,
+            "w",
+            timestamp=timestamp if timestamp != 0 else None,
+            config=self.config,
+        )
+        docs_array[external_ids] = docs
+        docs_array.close()
+        return ids
+
+    def delete(
+        self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
+    ) -> Optional[bool]:
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+        external_ids = np.asarray(ids, dtype=np.uint64)
+        self.vector_index.delete_batch(
+            external_ids=external_ids,
+            timestamp=timestamp if timestamp != 0 else None,
+        )
+        return True
+
     def consolidate_updates(self, **kwargs: Any) -> None:
         self.vector_index = self.vector_index.consolidate_updates(**kwargs)
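The relocated `add_texts` now materializes the `Iterable[str]` once (the old version iterated it three times, which would exhaust a generator), and `delete` fails fast on `ids=None` rather than crashing inside `np.asarray`. A hedged maintenance-flow sketch, under the same assumptions as the construction examples above:

```python
import os
import tempfile

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import TileDB

db = TileDB.from_texts(
    ["seed document"],
    FakeEmbeddings(size=8),
    index_uri=os.path.join(tempfile.mkdtemp(), "index"),
    allow_dangerous_deserialization=True,
)
# Incremental add, delete by id, then compaction of the accumulated updates.
new_ids = db.add_texts(["added later"], metadatas=[{"tag": "incremental"}])
db.delete(new_ids)
db.consolidate_updates()
```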