diff --git a/descarteslabs/core/compute/function.py b/descarteslabs/core/compute/function.py
index c7c9a859..0e8423e2 100644
--- a/descarteslabs/core/compute/function.py
+++ b/descarteslabs/core/compute/function.py
@@ -24,8 +24,10 @@
 import re
 import sys
 import time
+import uuid
 import warnings
 import zipfile
+from collections import UserList
 from datetime import datetime
 from tempfile import NamedTemporaryFile
 from typing import (
@@ -44,7 +46,8 @@
 import pkg_resources
 from strenum import StrEnum
 
-from descarteslabs.exceptions import ConflictError
+import descarteslabs.exceptions as exceptions
+
 from ..client.deprecation import deprecate
 from ..client.services.service import ThirdPartyService
 from ..common.client import (
@@ -59,10 +62,18 @@
 from .job import Job, JobSearch, JobStatus
 from .result import Serializable
 
-
 MAX_FUNCTION_IDS_PER_REQUEST = 128
 
 
+def batched(iterable, n):
+    """Batch an iterable into lists of at most size n."""
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := list(itertools.islice(it, n)):
+        yield batch
+
+
 class FunctionStatus(StrEnum):
     "The status of the Function."
 
@@ -106,6 +117,52 @@ def __new__(cls, memory: Union[str, int, float]) -> None:
         return super().__new__(cls, memory)
 
 
+class JobBulkCreateError:
+    """An error that occurred while submitting a bulk job."""
+
+    def __init__(
+        self,
+        function: "Function",
+        args,
+        kwargs,
+        exception: Exception,
+        reference_id: str,
+    ):
+        self.function = function
+        self.args = args
+        self.kwargs = kwargs
+        self.exception = exception
+        self.reference_id = reference_id
+
+    def __str__(self):
+        return f"{self.reference_id}: {self.exception}"
+
+    def __repr__(self):
+        return (
+            f"JobBulkCreateError("
+            f"function={self.function}, "
+            f"reference_id={self.reference_id}, "
+            f"exception={repr(self.exception)})"
+        )
+
+
+class JobBulkCreateResult(UserList[Job]):
+    """The result of a bulk job submission."""
+
+    def __init__(self):
+        super().__init__()
+        self.errors: List[JobBulkCreateError] = []
+
+    def append_error(self, error: JobBulkCreateError):
+        """Append an error to the result."""
+        self.errors.append(error)
+
+    @property
+    def is_success(self) -> bool:
+        """True if all jobs were submitted successfully."""
+        return len(self.errors) == 0
+
+
 class Function(Document):
     """The serverless cloud function that you can call directly or submit many jobs to."""
 
@@ -878,7 +935,7 @@ def delete(self, delete_results: bool = False):
         for job in self.jobs:
             try:
                 job.delete(delete_result=delete_results)
-            except ConflictError:
+            except exceptions.ConflictError:
                 pass
 
         self._client.session.delete(f"/functions/{self.id}")
@@ -993,7 +1050,8 @@ def map(
         args: Iterable[Iterable[Any]],
         kwargs: Iterable[Mapping[str, Any]] = None,
         tags: List[str] = None,
-    ) -> List[Job]:
+        batch_size: int = 1000,
+    ) -> JobBulkCreateResult:
         """Submits multiple jobs efficiently with positional args to each function call.
 
         Preferred over repeatedly calling the function, such as in a loop, when submitting
@@ -1012,6 +1070,15 @@
         >>> async_func('a', 'b', x=1) # doctest: +SKIP
         >>> async_func('c', 'd', x=2) # doctest: +SKIP
 
+        Notes
+        -----
+        Map is idempotent for the initial call: request errors that occur after
+        submission has started will not cause duplicate jobs. However, calling
+        the method again with the same arguments will submit duplicate jobs.
+
+        You should always check the return value to ensure all jobs were
+        submitted successfully, and handle any errors that may have occurred.
+
         Parameters
         ----------
         args : Iterable[Iterable[Any]]
@@ -1024,10 +1091,32 @@
             be expanded into keyword arguments for the function.
         tags : List[str], optional
             A list of tags to apply to all jobs submitted.
+        batch_size : int, default=1000
+            The number of jobs to submit in each batch. The maximum batch size is 1000.
+
+        Returns
+        -------
+        JobBulkCreateResult
+            An object containing the jobs that were submitted and any errors that
+            occurred. This object is compatible with a list of Job objects for
+            backwards compatibility.
+
+            If the value of `JobBulkCreateResult.is_success` is False, you should
+            check `JobBulkCreateResult.errors` and handle any errors that occurred.
+
+        Raises
+        ------
+        ClientError, ServerError
+            If the request to create the first batch of jobs fails after all retries
+            have been exhausted. Otherwise, any errors will be available in the
+            returned JobBulkCreateResult.
         """
         if self.state != DocumentState.SAVED:
             raise ValueError("Cannot execute a Function that has not been saved")
 
+        if batch_size < 1 or batch_size > 1000:
+            raise ValueError("Batch size must be between 1 and 1000")
+
         args = [list(iterable) for iterable in args]
         if kwargs is not None:
             kwargs = [dict(mapping) for mapping in kwargs]
@@ -1036,17 +1125,50 @@
                     "The number of kwargs must match the number of args. "
                     f"Got {len(args)} args and {len(kwargs)} kwargs."
                 )
 
-        payload = {
-            "function_id": self.id,
-            "bulk_args": args,
-            "bulk_kwargs": kwargs,
-        }
-        if tags:
-            payload["tags"] = tags
+        result = JobBulkCreateResult()
+
+        # Send the jobs in batches of batch_size
+        for index, (positional, named) in enumerate(
+            itertools.zip_longest(
+                batched(args, batch_size), batched(kwargs or [], batch_size)
+            )
+        ):
+            payload = {
+                "function_id": self.id,
+                "bulk_args": positional,
+                "bulk_kwargs": named,
+                "reference_id": str(uuid.uuid4()),
+            }
+
+            if tags:
+                payload["tags"] = tags
+
+            # Each batch carries a unique `reference_id` so that the request is
+            # idempotent and duplicate jobs are not submitted if a request is retried.
+            try:
+                response = self._client.session.post("/jobs/bulk", json=payload)
+                result.extend([Job(**data, saved=True) for data in response.json()])
+            except exceptions.NotFoundError:
+                # The Function no longer exists; no further batches can be submitted.
+                raise
+            except Exception as exc:
+                if index == 0:
+                    # The first batch failed; raise to the caller since the
+                    # remaining batches would most likely fail as well.
+                    raise
+
+                result.append_error(
+                    JobBulkCreateError(
+                        function=self,
+                        args=payload["bulk_args"],
+                        kwargs=payload["bulk_kwargs"],
+                        reference_id=payload["reference_id"],
+                        exception=exc,
+                    )
+                )
 
-        response = self._client.session.post("/jobs/bulk", json=payload)
-        return [Job(**data, saved=True) for data in response.json()]
+        return result
 
     def cancel_jobs(self, query: Optional[JobSearch] = None, job_ids: List[str] = None):
         """Cancels all jobs for the Function matching the given query.
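Reviewer note: a minimal sketch (not part of the diff) of how callers are expected to consume the new map() return value. The function id and argument values are hypothetical; JobBulkCreateResult iterates like the old List[Job], so existing callers keep working.

    from descarteslabs.compute import Function

    fn = Function.get("some-function-id")  # hypothetical id of a saved Function

    result = fn.map([[n, n + 1] for n in range(2500)], batch_size=1000)

    # Backwards compatible: iterate the result like the old List[Job] return value.
    for job in result:
        print(job.id)

    # New: detect a partially failed submission and resubmit only the failed batches.
    if not result.is_success:
        for error in result.errors:
            print(f"batch {error.reference_id} failed: {error.exception}")
            # Calling map() again creates new jobs; the reference_id only
            # de-duplicates retries within the original call.
            retry = fn.map(error.args, kwargs=error.kwargs)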
diff --git a/descarteslabs/core/compute/tests/test_function.py b/descarteslabs/core/compute/tests/test_function.py
index 17c5bf1e..eafcaa15 100644
--- a/descarteslabs/core/compute/tests/test_function.py
+++ b/descarteslabs/core/compute/tests/test_function.py
@@ -1,4 +1,5 @@
 import gzip
+import itertools
 import json
 import os
 import random
@@ -8,6 +9,7 @@
 from datetime import timezone
 
 import responses
+from requests import PreparedRequest
 
 from descarteslabs import exceptions
 from descarteslabs.compute import Function, FunctionStatus, Job, JobStatus
@@ -648,17 +650,99 @@ def test_map(self):
         )
 
         fn = Function(id="some-id", saved=True)
-        fn.map(
+        result = fn.map(
             [[1, 2], [3, 4]],
             kwargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
         )
+        assert result.is_success
+        assert len(result) == 2
+        for job in result:
+            assert isinstance(job, Job)
+
         request = responses.calls[-1].request
-        assert json.loads(request.body) == {
+        request_json: dict = json.loads(request.body)
+        assert request_json.pop("reference_id") is not None
+        assert request_json == {
            "bulk_args": [[1, 2], [3, 4]],
            "bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
            "function_id": "some-id",
        }
 
+    @responses.activate
+    def test_map_batching(self):
+        def request_callback(request: PreparedRequest):
+            payload: dict = json.loads(request.body)
+            jobs = []
+
+            args = payload["bulk_args"] or []
+            kwargs = payload["bulk_kwargs"] or []
+
+            for job_args, job_kwargs in itertools.zip_longest(args, kwargs):
+                jobs.append(self.make_job(args=job_args, kwargs=job_kwargs))
+
+            return (200, {}, json.dumps(jobs))
+
+        responses.add_callback(
+            responses.POST,
+            f"{self.compute_url}/jobs/bulk",
+            callback=request_callback,
+        )
+
+        fn = Function(id="some-id", saved=True)
+        result = fn.map([[n, n + 1] for n in range(3000)])
+        assert result.is_success is True, result.errors
+        assert len(result) == 3000
+        reference_ids = {
+            json.loads(call.request.body)["reference_id"] for call in responses.calls
+        }
+        assert len(reference_ids) == 3
+
+    @responses.activate
+    def test_map_errors(self):
+        global call_count
+        call_count = 0
+
+        def request_callback(request: PreparedRequest):
+            global call_count
+            call_count += 1
+
+            if call_count > 1:
+                return (500, {}, None)
+
+            payload: dict = json.loads(request.body)
+            jobs = []
+
+            args = payload["bulk_args"] or []
+            kwargs = payload["bulk_kwargs"] or []
+
+            for job_args, job_kwargs in itertools.zip_longest(args, kwargs):
+                jobs.append(self.make_job(args=job_args, kwargs=job_kwargs))
+
+            return (200, {}, json.dumps(jobs))
+
+        responses.add_callback(
+            responses.POST,
+            f"{self.compute_url}/jobs/bulk",
+            callback=request_callback,
+        )
+
+        fn = Function(id="some-id", saved=True)
+        result = fn.map(
+            [[1, 2], [3, 4]],
+            kwargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
+            batch_size=1,
+        )
+        assert result.is_success is False
+        assert len(result) == 1
+        assert len(result.errors) == 1
+        assert result.errors[0].args == [[3, 4]]
+        assert result.errors[0].kwargs == [{"first": 1.0, "second": 2.0}]
+        assert len(responses.calls) == 2
+        reference_ids = {
+            json.loads(call.request.body)["reference_id"] for call in responses.calls
+        }
+        assert len(reference_ids) == 2
+
     @responses.activate
     def test_map_deprecated(self):
         self.mock_response(
@@ -673,7 +757,9 @@
             iterargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
         )
         request = responses.calls[-1].request
-        assert json.loads(request.body) == {
+        request_json: dict = json.loads(request.body)
+        assert request_json.pop("reference_id") is not None
+        assert request_json == {
            "bulk_args": [[1, 2], [3, 4]],
            "bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
            "function_id": "some-id",
@@ -706,7 +792,9 @@ def inner(t):
             kwgenerator(),
         )
         request = responses.calls[-1].request
-        assert json.loads(request.body) == {
+        request_json: dict = json.loads(request.body)
+        assert request_json.pop("reference_id") is not None
+        assert request_json == {
            "bulk_args": [[1, 2], [3, 4]],
            "bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
            "function_id": "some-id",
@@ -727,7 +815,9 @@ def test_map_with_tags(self):
             tags=["tag1", "tag2"],
         )
         request = responses.calls[-1].request
-        assert json.loads(request.body) == {
+        request_json: dict = json.loads(request.body)
+        assert request_json.pop("reference_id") is not None
+        assert request_json == {
            "bulk_args": [[1, 2], [3, 4]],
            "bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
            "function_id": "some-id",
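Reviewer note: the contract of the new batched() helper, restated as a standalone snippet (not part of the diff) for quick verification. The helper body is copied from the hunk above; the asserts are illustrative.

    import itertools

    def batched(iterable, n):
        """Batch an iterable into lists of at most size n."""
        if n < 1:
            raise ValueError("n must be at least one")
        it = iter(iterable)
        while batch := list(itertools.islice(it, n)):
            yield batch

    assert list(batched(range(5), 2)) == [[0, 1], [2, 3], [4]]  # last batch may be short
    assert list(batched([], 3)) == []  # empty input yields no batches, so map() posts nothing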