-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[ENH] Embed query strings in search api #5599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: jai/schema-js-impl
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,6 @@ | |
) | ||
from chromadb.api.collection_configuration import UpdateCollectionConfiguration | ||
from chromadb.execution.expression.plan import Search | ||
from typing import cast, List | ||
propel-code-bot[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import logging | ||
|
||
|
@@ -362,9 +361,14 @@ def search( | |
if searches_list is None: | ||
searches_list = [] | ||
|
||
# Embed any string queries in Knn objects | ||
embedded_searches = [ | ||
self._embed_search_string_queries(search) for search in searches_list | ||
Comment on lines
+365
to
+366
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Same error handling concern applies here. If Context for Agents
|
||
] | ||
|
||
return self._client._search( | ||
collection_id=self.id, | ||
searches=cast(List[Search], searches_list), | ||
searches=cast(List[Search], embedded_searches), | ||
tenant=self.tenant, | ||
database=self.database, | ||
) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,6 +17,7 @@ | |||||||||||||||||||||||||||
from uuid import UUID | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
from chromadb.api.types import ( | ||||||||||||||||||||||||||||
EMBEDDING_KEY, | ||||||||||||||||||||||||||||
URI, | ||||||||||||||||||||||||||||
Schema, | ||||||||||||||||||||||||||||
SparseVectorIndexConfig, | ||||||||||||||||||||||||||||
|
@@ -741,3 +742,232 @@ def _sparse_embed( | |||||||||||||||||||||||||||
if is_query: | ||||||||||||||||||||||||||||
return sparse_embedding_function.embed_query(input=input) | ||||||||||||||||||||||||||||
return sparse_embedding_function(input=input) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _embed_knn_string_queries(self, knn: Any) -> Any: | ||||||||||||||||||||||||||||
"""Embed string queries in Knn objects using the appropriate embedding function. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||
knn: A Knn object that may have a string query | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||
A Knn object with the string query replaced by an embedding | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Raises: | ||||||||||||||||||||||||||||
ValueError: If the query is a string but no embedding function is available | ||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
from chromadb.execution.expression.operator import Knn | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if not isinstance(knn, Knn): | ||||||||||||||||||||||||||||
return knn | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# If query is not a string, nothing to do | ||||||||||||||||||||||||||||
if not isinstance(knn.query, str): | ||||||||||||||||||||||||||||
return knn | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
query_text = knn.query | ||||||||||||||||||||||||||||
key = knn.key | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Handle main embedding field | ||||||||||||||||||||||||||||
if key == EMBEDDING_KEY: | ||||||||||||||||||||||||||||
# Use the collection's main embedding function | ||||||||||||||||||||||||||||
embedding = self._embed(input=[query_text], is_query=True) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Error handling gap: Multiple embedding function calls ( Consider wrapping embedding calls: try:
embedding = self._embed(input=[query_text], is_query=True)
except Exception as e:
raise ValueError(f"Failed to embed query '{query_text}': {e}") from e Context for Agents
|
||||||||||||||||||||||||||||
if not embedding or len(embedding) != 1: | ||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||
"Embedding function returned unexpected number of embeddings" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# Return a new Knn with the embedded query | ||||||||||||||||||||||||||||
return Knn( | ||||||||||||||||||||||||||||
query=embedding[0], | ||||||||||||||||||||||||||||
key=knn.key, | ||||||||||||||||||||||||||||
limit=knn.limit, | ||||||||||||||||||||||||||||
default=knn.default, | ||||||||||||||||||||||||||||
return_rank=knn.return_rank, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Handle metadata field with potential sparse embedding | ||||||||||||||||||||||||||||
schema = self.schema | ||||||||||||||||||||||||||||
if schema is None or key not in schema.key_overrides: | ||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||
f"Cannot embed string query for key '{key}': " | ||||||||||||||||||||||||||||
f"key not found in schema. Please provide an embedded vector or " | ||||||||||||||||||||||||||||
f"configure an embedding function for this key in the schema." | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
value_type = schema.key_overrides[key] | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Check for sparse vector with embedding function | ||||||||||||||||||||||||||||
if value_type.sparse_vector is not None: | ||||||||||||||||||||||||||||
sparse_index = value_type.sparse_vector.sparse_vector_index | ||||||||||||||||||||||||||||
if sparse_index is not None and sparse_index.enabled: | ||||||||||||||||||||||||||||
config = sparse_index.config | ||||||||||||||||||||||||||||
if config.embedding_function is not None: | ||||||||||||||||||||||||||||
embedding_func = config.embedding_function | ||||||||||||||||||||||||||||
if not isinstance(embedding_func, SparseEmbeddingFunction): | ||||||||||||||||||||||||||||
embedding_func = cast( | ||||||||||||||||||||||||||||
propel-code-bot[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
SparseEmbeddingFunction[Any], embedding_func | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
validate_sparse_embedding_function(embedding_func) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Embed the query | ||||||||||||||||||||||||||||
sparse_embedding = self._sparse_embed( | ||||||||||||||||||||||||||||
Comment on lines
+811
to
+812
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Error handling gap: Similar to the main embedding case, if Consider adding context: Suggested Change
Suggested change
⚡ Committable suggestion Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Context for Agents
|
||||||||||||||||||||||||||||
input=[query_text], | ||||||||||||||||||||||||||||
sparse_embedding_function=embedding_func, | ||||||||||||||||||||||||||||
is_query=True, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if not sparse_embedding or len(sparse_embedding) != 1: | ||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||
"Sparse embedding function returned unexpected number of embeddings" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Return a new Knn with the sparse embedding | ||||||||||||||||||||||||||||
return Knn( | ||||||||||||||||||||||||||||
query=sparse_embedding[0], | ||||||||||||||||||||||||||||
key=knn.key, | ||||||||||||||||||||||||||||
limit=knn.limit, | ||||||||||||||||||||||||||||
default=knn.default, | ||||||||||||||||||||||||||||
return_rank=knn.return_rank, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Check for dense vector with embedding function (float_list) | ||||||||||||||||||||||||||||
if value_type.float_list is not None: | ||||||||||||||||||||||||||||
vector_index = value_type.float_list.vector_index | ||||||||||||||||||||||||||||
if vector_index is not None and vector_index.enabled: | ||||||||||||||||||||||||||||
config = vector_index.config | ||||||||||||||||||||||||||||
if config.embedding_function is not None: | ||||||||||||||||||||||||||||
embedding_func = config.embedding_function | ||||||||||||||||||||||||||||
validate_embedding_function(embedding_func) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Embed the query using the schema's embedding function | ||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||
embeddings = embedding_func.embed_query(input=[query_text]) | ||||||||||||||||||||||||||||
Comment on lines
+842
to
+843
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Error handling issue: The try:
embeddings = embedding_func.embed_query(input=[query_text])
except AttributeError:
# Fallback if embed_query doesn't exist
embeddings = embedding_func([query_text])
except Exception as e:
raise ValueError(
f"Failed to embed string query '{query_text}' using embedding function: {e}"
) from e Context for Agents
|
||||||||||||||||||||||||||||
except AttributeError: | ||||||||||||||||||||||||||||
# Fallback if embed_query doesn't exist | ||||||||||||||||||||||||||||
embeddings = embedding_func([query_text]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if not embeddings or len(embeddings) != 1: | ||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||
"Embedding function returned unexpected number of embeddings" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Return a new Knn with the dense embedding | ||||||||||||||||||||||||||||
return Knn( | ||||||||||||||||||||||||||||
query=embeddings[0], | ||||||||||||||||||||||||||||
key=knn.key, | ||||||||||||||||||||||||||||
limit=knn.limit, | ||||||||||||||||||||||||||||
default=knn.default, | ||||||||||||||||||||||||||||
return_rank=knn.return_rank, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||
f"Cannot embed string query for key '{key}': " | ||||||||||||||||||||||||||||
f"no embedding function configured for this key in the schema. " | ||||||||||||||||||||||||||||
f"Please provide an embedded vector or configure an embedding function." | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _embed_rank_string_queries(self, rank: Any) -> Any: | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [BestPractice] Potential infinite recursion: The method Consider adding a depth limit: def _embed_rank_string_queries(self, rank: Any, depth: int = 0, max_depth: int = 100) -> Any:
if depth > max_depth:
raise ValueError(f"Maximum recursion depth ({max_depth}) exceeded in rank expression")
# ... existing logic with depth + 1 passed to recursive calls Context for Agents
|
||||||||||||||||||||||||||||
"""Recursively embed string queries in Rank expressions. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||
rank: A Rank expression that may contain Knn objects with string queries | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||
A Rank expression with all string queries embedded | ||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
# Import here to avoid circular dependency | ||||||||||||||||||||||||||||
from chromadb.execution.expression.operator import ( | ||||||||||||||||||||||||||||
Knn, | ||||||||||||||||||||||||||||
Abs, | ||||||||||||||||||||||||||||
Div, | ||||||||||||||||||||||||||||
Exp, | ||||||||||||||||||||||||||||
Log, | ||||||||||||||||||||||||||||
Max, | ||||||||||||||||||||||||||||
Min, | ||||||||||||||||||||||||||||
Mul, | ||||||||||||||||||||||||||||
Sub, | ||||||||||||||||||||||||||||
Sum, | ||||||||||||||||||||||||||||
Val, | ||||||||||||||||||||||||||||
Rrf, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if rank is None: | ||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Base case: Knn - embed if it has a string query | ||||||||||||||||||||||||||||
if isinstance(rank, Knn): | ||||||||||||||||||||||||||||
return self._embed_knn_string_queries(rank) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Base case: Val - no embedding needed | ||||||||||||||||||||||||||||
if isinstance(rank, Val): | ||||||||||||||||||||||||||||
return rank | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Recursive cases: walk through child ranks | ||||||||||||||||||||||||||||
if isinstance(rank, Abs): | ||||||||||||||||||||||||||||
return Abs(self._embed_rank_string_queries(rank.rank)) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Div): | ||||||||||||||||||||||||||||
return Div( | ||||||||||||||||||||||||||||
self._embed_rank_string_queries(rank.left), | ||||||||||||||||||||||||||||
self._embed_rank_string_queries(rank.right), | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Exp): | ||||||||||||||||||||||||||||
return Exp(self._embed_rank_string_queries(rank.rank)) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Log): | ||||||||||||||||||||||||||||
return Log(self._embed_rank_string_queries(rank.rank)) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Max): | ||||||||||||||||||||||||||||
return Max([self._embed_rank_string_queries(r) for r in rank.ranks]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Min): | ||||||||||||||||||||||||||||
return Min([self._embed_rank_string_queries(r) for r in rank.ranks]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Mul): | ||||||||||||||||||||||||||||
return Mul([self._embed_rank_string_queries(r) for r in rank.ranks]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Sub): | ||||||||||||||||||||||||||||
return Sub( | ||||||||||||||||||||||||||||
self._embed_rank_string_queries(rank.left), | ||||||||||||||||||||||||||||
self._embed_rank_string_queries(rank.right), | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Sum): | ||||||||||||||||||||||||||||
return Sum([self._embed_rank_string_queries(r) for r in rank.ranks]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(rank, Rrf): | ||||||||||||||||||||||||||||
return Rrf( | ||||||||||||||||||||||||||||
ranks=[self._embed_rank_string_queries(r) for r in rank.ranks], | ||||||||||||||||||||||||||||
k=rank.k, | ||||||||||||||||||||||||||||
weights=rank.weights, | ||||||||||||||||||||||||||||
normalize=rank.normalize, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Unknown rank type - return as is | ||||||||||||||||||||||||||||
return rank | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _embed_search_string_queries(self, search: Any) -> Any: | ||||||||||||||||||||||||||||
"""Embed string queries in a Search object. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||
search: A Search object that may contain Knn objects with string queries | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||
A Search object with all string queries embedded | ||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||
# Import here to avoid circular dependency | ||||||||||||||||||||||||||||
from chromadb.execution.expression.plan import Search | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if not isinstance(search, Search): | ||||||||||||||||||||||||||||
return search | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Embed the rank expression if it exists | ||||||||||||||||||||||||||||
embedded_rank = self._embed_rank_string_queries(search._rank) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Create a new Search with the embedded rank | ||||||||||||||||||||||||||||
return Search( | ||||||||||||||||||||||||||||
where=search._where, | ||||||||||||||||||||||||||||
rank=embedded_rank, | ||||||||||||||||||||||||||||
limit=search._limit, | ||||||||||||||||||||||||||||
select=search._select, | ||||||||||||||||||||||||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[BestPractice]
Potential error handling gap: If
_embed_search_string_queries()
raises an exception for any search in the list comprehension, the entire operation will fail and no searches will be processed. Consider adding error handling to gracefully handle individual search embedding failures:Context for Agents