Skip to content

Commit a1843a1

Browse files
committed
Add REST API endpoints for task management
1 parent 6abeaa6 commit a1843a1

File tree

14 files changed

+1297
-69
lines changed

14 files changed

+1297
-69
lines changed

chromadb/api/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,53 @@ def _delete(
770770
database: str = DEFAULT_DATABASE,
771771
) -> None:
772772
pass
773+
774+
@abstractmethod
def create_task(
    self,
    task_name: str,
    operator_id: str,
    input_collection_name: str,
    output_collection_name: str,
    params: Optional[str] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
    """Register a recurring task that runs against a collection.

    This is an abstract hook: it carries no behavior of its own and each
    concrete API implementation provides the actual logic.

    Args:
        task_name: Unique name identifying this task instance.
        operator_id: Identifier of the built-in operator to run.
        input_collection_name: Name of the source collection that
            triggers the task.
        output_collection_name: Name of the target collection that
            receives the task output.
        params: Optional JSON string holding operator-specific
            parameters.
        tenant: The tenant name.
        database: The database name.

    Returns:
        tuple: ``(success, task_id)`` where ``success`` is a bool and
        ``task_id`` is a str.
    """
    ...
800+
801+
@abstractmethod
def remove_task(
    self,
    task_name: str,
    input_collection_name: str,
    delete_output: bool = False,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> bool:
    """Remove a task so that it no longer runs.

    This is an abstract hook: it carries no behavior of its own and each
    concrete API implementation provides the actual logic.

    Args:
        task_name: Name of the task to remove.
        input_collection_name: Name of the input collection the task is
            registered on.
        delete_output: When True, the output collection is deleted as
            well.
        tenant: The tenant name.
        database: The database name.

    Returns:
        bool: True if the removal succeeded.
    """
    ...

chromadb/api/fastapi.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,64 @@ def get_max_batch_size(self) -> int:
695695
pre_flight_checks = self.get_pre_flight_checks()
696696
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
697697
return max_batch_size
698+
699+
@trace_method("FastAPI.create_task", OpenTelemetryGranularity.ALL)
@override
def create_task(
    self,
    task_name: str,
    operator_id: str,
    input_collection_name: str,
    output_collection_name: str,
    params: Optional[str] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
    """Register a recurring task on a collection through the HTTP API.

    Args:
        task_name: Unique name for this task instance.
        operator_id: Built-in operator identifier.
        input_collection_name: Name of the collection the task is
            registered on; it is looked up to obtain its ID.
        output_collection_name: Name of the collection that receives the
            task output.
        params: Optional JSON string with operator-specific parameters;
            forwarded as-is (including ``None``).
        tenant: The tenant name.
        database: The database name.

    Returns:
        tuple: ``(success, task_id)`` read from the ``"success"`` and
        ``"task_id"`` fields of the response JSON.
    """
    # The endpoint is addressed by collection ID, so resolve the name first.
    input_collection = self.get_collection(
        input_collection_name, tenant=tenant, database=database
    )

    endpoint = (
        f"/tenants/{tenant}/databases/{database}"
        f"/collections/{input_collection.id}/tasks/create"
    )
    payload = {
        "tenant_id": tenant,
        "database_name": database,
        "task_name": task_name,
        "operator_id": operator_id,
        "input_collection_name": input_collection_name,
        "output_collection_name": output_collection_name,
        "params": params,
    }
    response = self._make_request("post", endpoint, json=payload)

    succeeded = cast(bool, response["success"])
    task_id = cast(str, response["task_id"])
    return succeeded, task_id
731+
732+
@trace_method("FastAPI.remove_task", OpenTelemetryGranularity.ALL)
@override
def remove_task(
    self,
    task_name: str,
    input_collection_name: str,
    delete_output: bool = False,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> bool:
    """Delete a task through the HTTP API so it stops running.

    Args:
        task_name: Name of the task to remove.
        input_collection_name: Name of the collection the task is
            registered on; it is looked up to obtain its ID.
        delete_output: Whether the output collection should be deleted
            too; forwarded to the server unchanged.
        tenant: The tenant name.
        database: The database name.

    Returns:
        bool: The ``"success"`` field of the response JSON.
    """
    # The endpoint is addressed by collection ID, so resolve the name first.
    input_collection = self.get_collection(
        input_collection_name, tenant=tenant, database=database
    )

    endpoint = (
        f"/tenants/{tenant}/databases/{database}"
        f"/collections/{input_collection.id}/tasks/delete"
    )
    payload = {
        "tenant_id": tenant,
        "database_name": database,
        "task_name": task_name,
        "delete_output": delete_output,
    }
    response = self._make_request("post", endpoint, json=payload)
    return cast(bool, response["success"])

chromadb/api/models/Collection.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,29 +327,29 @@ def search(
327327
from chromadb.execution.expression import (
328328
Search, Key, K, Knn, Val
329329
)
330-
330+
331331
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
332332
search = (Search()
333333
.where((K("category") == "science") & (K("score") > 0.5))
334334
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
335335
.limit(10, offset=0)
336336
.select(K.DOCUMENT, K.SCORE, "title"))
337-
337+
338338
# Direct construction
339339
from chromadb.execution.expression import (
340340
Search, Eq, And, Gt, Knn, Limit, Select, Key
341341
)
342-
342+
343343
search = Search(
344344
where=And([Eq("category", "science"), Gt("score", 0.5)]),
345345
rank=Knn(query=[0.1, 0.2, 0.3]),
346346
limit=Limit(offset=0, limit=10),
347347
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
348348
)
349-
349+
350350
# Single search
351351
result = collection.search(search)
352-
352+
353353
# Multiple searches at once
354354
searches = [
355355
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
@@ -490,3 +490,64 @@ def delete(
490490
tenant=self.tenant,
491491
database=self.database,
492492
)
493+
494+
def create_task(
    self,
    name: str,
    operator_id: str,
    output_collection: str,
    params: Optional[str] = None,
) -> tuple[bool, str]:
    """Set up a recurring task that takes this collection as its input.

    The call is delegated to the underlying client, with this
    collection's own name, tenant and database filled in.

    Args:
        name: Unique name for this task instance.
        operator_id: Built-in operator identifier
            (e.g., "record_counter").
        output_collection: Name of the collection where the task output
            will be stored.
        params: Optional JSON string with operator-specific parameters.

    Returns:
        tuple: (success: bool, task_id: str)

    Example:
        >>> success, task_id = collection.create_task(
        ...     name="count_docs",
        ...     operator_id="record_counter",
        ...     output_collection="doc_counts",
        ...     params=None
        ... )
    """
    # This collection is always the input side of the task.
    task_request = {
        "task_name": name,
        "operator_id": operator_id,
        "input_collection_name": self.name,
        "output_collection_name": output_collection,
        "params": params,
        "tenant": self.tenant,
        "database": self.database,
    }
    return self._client.create_task(**task_request)
529+
530+
def remove_task(
    self,
    name: str,
    delete_output: bool = False,
) -> bool:
    """Remove a task registered on this collection so it stops running.

    The call is delegated to the underlying client, with this
    collection's own name, tenant and database filled in.

    Args:
        name: Name of the task to remove.
        delete_output: Whether to also delete the output collection.
            Defaults to False.

    Returns:
        bool: True if successful

    Example:
        >>> success = collection.remove_task("count_docs", delete_output=True)
    """
    # Tasks are looked up by the input collection they are registered on.
    removal_request = {
        "task_name": name,
        "input_collection_name": self.name,
        "delete_output": delete_output,
        "tenant": self.tenant,
        "database": self.database,
    }
    return self._client.remove_task(**removal_request)

chromadb/api/rust.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,7 @@ def _search(
320320
tenant: str = DEFAULT_TENANT,
321321
database: str = DEFAULT_DATABASE,
322322
) -> SearchResult:
323-
raise NotImplementedError(
324-
"Search is not implemented for Local Chroma"
325-
)
323+
raise NotImplementedError("Search is not implemented for Local Chroma")
326324

327325
@override
328326
def _count(
@@ -583,6 +581,38 @@ def get_settings(self) -> Settings:
583581
def get_max_batch_size(self) -> int:
584582
return self.bindings.get_max_batch_size()
585583

584+
@override
def create_task(
    self,
    task_name: str,
    operator_id: str,
    input_collection_name: str,
    output_collection_name: str,
    params: Optional[str] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
    """Reject task creation: the Rust bindings (embedded mode) lack tasks.

    All arguments are accepted only to satisfy the shared interface and
    are otherwise ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = (
        "Tasks are only supported when connecting to a Chroma server via HttpClient. "
        "The Rust bindings (embedded mode) do not support task operations."
    )
    raise NotImplementedError(unsupported_message)
600+
601+
@override
def remove_task(
    self,
    task_name: str,
    input_collection_name: str,
    delete_output: bool = False,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> bool:
    """Reject task removal: the Rust bindings (embedded mode) lack tasks.

    All arguments are accepted only to satisfy the shared interface and
    are otherwise ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = (
        "Tasks are only supported when connecting to a Chroma server via HttpClient. "
        "The Rust bindings (embedded mode) do not support task operations."
    )
    raise NotImplementedError(unsupported_message)
615+
586616
# TODO: Remove this if it's not planned to be used
587617
@override
588618
def get_user_identity(self) -> UserIdentity:

chromadb/api/segment.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,7 @@ def _search(
427427
tenant: str = DEFAULT_TENANT,
428428
database: str = DEFAULT_DATABASE,
429429
) -> SearchResult:
430-
raise NotImplementedError(
431-
"Seach is not implemented for SegmentAPI"
432-
)
430+
raise NotImplementedError("Seach is not implemented for SegmentAPI")
433431

434432
@trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
435433
@override
@@ -902,6 +900,38 @@ def get_settings(self) -> Settings:
902900
def get_max_batch_size(self) -> int:
903901
return self._producer.max_batch_size
904902

903+
@override
def create_task(
    self,
    task_name: str,
    operator_id: str,
    input_collection_name: str,
    output_collection_name: str,
    params: Optional[str] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> tuple[bool, str]:
    """Reject task creation: the Segment API (embedded mode) lacks tasks.

    All arguments are accepted only to satisfy the shared interface and
    are otherwise ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = (
        "Tasks are only supported when connecting to a Chroma server via HttpClient. "
        "The Segment API (embedded mode) does not support task operations."
    )
    raise NotImplementedError(unsupported_message)
919+
920+
@override
def remove_task(
    self,
    task_name: str,
    input_collection_name: str,
    delete_output: bool = False,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> bool:
    """Reject task removal: the Segment API (embedded mode) lacks tasks.

    All arguments are accepted only to satisfy the shared interface and
    are otherwise ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_message = (
        "Tasks are only supported when connecting to a Chroma server via HttpClient. "
        "The Segment API (embedded mode) does not support task operations."
    )
    raise NotImplementedError(unsupported_message)
934+
905935
# TODO: This could potentially cause race conditions in a distributed version of the
906936
# system, since the cache is only local.
907937
# TODO: promote collection -> topic to a base class method so that it can be

0 commit comments

Comments
 (0)