Compute: limit the number of jobs you can create with map (#12341)
GitOrigin-RevId: 7296af7e54b799605291cdb1e5cbfc87c1530dcf
tkrause authored and Descartes Labs Build committed Dec 4, 2023
1 parent f7d0f2e commit 3a9db9d
Showing 2 changed files with 230 additions and 18 deletions.
148 changes: 135 additions & 13 deletions descarteslabs/core/compute/function.py
@@ -24,8 +24,10 @@
import re
import sys
import time
import uuid
import warnings
import zipfile
from collections import UserList
from datetime import datetime
from tempfile import NamedTemporaryFile
from typing import (
@@ -44,7 +46,8 @@
import pkg_resources
from strenum import StrEnum

from descarteslabs.exceptions import ConflictError
import descarteslabs.exceptions as exceptions

from ..client.deprecation import deprecate
from ..client.services.service import ThirdPartyService
from ..common.client import (
@@ -59,10 +62,18 @@
from .job import Job, JobSearch, JobStatus
from .result import Serializable


MAX_FUNCTION_IDS_PER_REQUEST = 128


def batched(iterable, n):
"""Batch an iterable into lists of size n"""
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := list(itertools.islice(it, n)):
yield batch
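
A quick illustration of what this helper yields (a minimal sketch; this is the module-local batched defined above, not itertools.batched, which only exists in Python 3.12+ and yields tuples rather than lists):

>>> list(batched(range(7), 3))
[[0, 1, 2], [3, 4, 5], [6]]
>>> list(batched([], 3))
[]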


class FunctionStatus(StrEnum):
"The status of the Function."

@@ -106,6 +117,52 @@ def __new__(cls, memory: Union[str, int, float]) -> None:
return super().__new__(cls, memory)


class JobBulkCreateError:
"""An error that occurred while submitting a bulk job."""

def __init__(
self,
function: "Function",
args,
kwargs,
exception: Exception,
reference_id: str,
):
self.function = function
self.args = args
self.kwargs = kwargs
self.exception = exception
self.reference_id = reference_id

def __str__(self):
return f"{self.reference_id}: {self.exception}"

def __repr__(self):
return (
f"JobBulkCreateError("
f"function={self.function}, "
f"reference_id={self.reference_id}, "
f"exception={repr(self.exception)})"
)


class JobBulkCreateResult(UserList[Job]):
"""The result of a bulk job submission."""

def __init__(self):
super().__init__()
self.errors: List[JobBulkCreateError] = []

def append_error(self, error: JobBulkCreateError):
"""Append an error to the result."""
self.errors.append(error)

@property
def is_success(self) -> bool:
"""Returns true if all jobs were successfully submitted."""
return len(self.errors) == 0
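
Because JobBulkCreateResult subclasses UserList[Job], existing callers that treat the return value of map as a plain list of Job objects keep working unchanged, while new callers can detect partial failures. A minimal consumption sketch (hypothetical saved Function fn):

result = fn.map([[1], [2], [3]])
if not result.is_success:
    for error in result.errors:
        # Each error carries the exact batch payload that failed to submit.
        print(f"batch {error.reference_id} failed: {error.exception!r}")
for job in result:  # iterates like a plain list of Job objects
    print(job.id)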


class Function(Document):
"""The serverless cloud function that you can call directly or submit many jobs to."""

@@ -878,7 +935,7 @@ def delete(self, delete_results: bool = False):
for job in self.jobs:
try:
job.delete(delete_result=delete_results)
except ConflictError:
except exceptions.ConflictError:
pass

self._client.session.delete(f"/functions/{self.id}")
@@ -993,7 +1050,8 @@ def map(
args: Iterable[Iterable[Any]],
kwargs: Iterable[Mapping[str, Any]] = None,
tags: List[str] = None,
) -> List[Job]:
batch_size: int = 1000,
) -> JobBulkCreateResult:
"""Submits multiple jobs efficiently with positional args to each function call.
Preferred over repeatedly calling the function, such as in a loop, when submitting
@@ -1012,6 +1070,15 @@ def map(
>>> async_func('a', 'b', x=1) # doctest: +SKIP
>>> async_func('c', 'd', x=2) # doctest: +SKIP
Notes
-----
The initial call to map is idempotent: request errors that occur after submission has
started will not cause duplicate jobs to be submitted. However, calling the method
again with the same arguments will submit duplicate jobs.
You should always check the return value to ensure all jobs were submitted successfully
and handle any errors that may have occurred.
Parameters
----------
args : Iterable[Iterable[Any]]
@@ -1024,10 +1091,32 @@
be expanded into keyword arguments for the function.
tags : List[str], optional
A list of tags to apply to all jobs submitted.
batch_size : int, default=1000
The number of jobs to submit in each batch. The maximum batch size is 1000.
Returns
-------
JobBulkCreateResult
An object containing the jobs that were submitted and any errors that occurred.
For backwards compatibility, this object behaves like a plain list of Job objects.
If the value of `JobBulkCreateResult.is_success` is False, you should check
`JobBulkCreateResult.errors` and handle any errors that occurred.
Raises
------
ClientError, ServerError
If the request to create the first batch of jobs fails after all retries have
been exhausted.
Otherwise, any errors will be available in the returned JobBulkCreateResult.
"""
if self.state != DocumentState.SAVED:
raise ValueError("Cannot execute a Function that has not been saved")

if batch_size < 1 or batch_size > 1000:
raise ValueError("Batch size must between 1 and 1000")

args = [list(iterable) for iterable in args]
if kwargs is not None:
kwargs = [dict(mapping) for mapping in kwargs]
@@ -1036,17 +1125,50 @@
"The number of kwargs must match the number of args. "
f"Got {len(args)} args and {len(kwargs)} kwargs."
)
payload = {
"function_id": self.id,
"bulk_args": args,
"bulk_kwargs": kwargs,
}

if tags:
payload["tags"] = tags
result = JobBulkCreateResult()

# Send the jobs in batches of batch_size
for index, (positional, named) in enumerate(
itertools.zip_longest(
batched(args, batch_size), batched(kwargs or [], batch_size)
)
):
payload = {
"function_id": self.id,
"bulk_args": positional,
"bulk_kwargs": named,
"reference_id": str(uuid.uuid4()),
}

if tags:
payload["tags"] = tags

# This implementation uses a `reference_id` to ensure that the request is idempotent
# and duplicate jobs are not submitted in a retry scenario.
try:
response = self._client.session.post("/jobs/bulk", json=payload)
result.extend([Job(**data, saved=True) for data in response.json()])
except exceptions.NotFoundError:
# If one of these errors occurs, we cannot continue submitting any jobs at all
raise
except Exception as exc:
if index == 0:
# The first batch failed, let the user deal with the exception as all
# the batches would likely fail.
raise

result.append_error(
JobBulkCreateError(
function=self,
args=payload["bulk_args"],
kwargs=payload["bulk_kwargs"],
reference_id=payload["reference_id"],
exception=exc,
)
)

response = self._client.session.post("/jobs/bulk", json=payload)
return [Job(**data, saved=True) for data in response.json()]
return result
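
An end-to-end sketch of the batched submission path above (hypothetical async_func and inputs; note that resubmitting a failed batch generates a fresh reference_id, so only resubmit when you are confident the failed request created no jobs):

args = [[n, n + 1] for n in range(2500)]
result = async_func.map(args, batch_size=1000)  # issues 3 POSTs to /jobs/bulk
if not result.is_success:
    for error in result.errors:
        # error.args and error.kwargs hold exactly the payload of the failed
        # batch, so it can be logged, inspected, or resubmitted.
        print(f"batch {error.reference_id} failed: {error.exception!r}")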

def cancel_jobs(self, query: Optional[JobSearch] = None, job_ids: List[str] = None):
"""Cancels all jobs for the Function matching the given query.
100 changes: 95 additions & 5 deletions descarteslabs/core/compute/tests/test_function.py
@@ -1,4 +1,5 @@
import gzip
import itertools
import json
import os
import random
@@ -8,6 +9,7 @@
from datetime import timezone

import responses
from requests import PreparedRequest

from descarteslabs import exceptions
from descarteslabs.compute import Function, FunctionStatus, Job, JobStatus
@@ -648,17 +650,99 @@ def test_map(self):
)

fn = Function(id="some-id", saved=True)
fn.map(
result = fn.map(
[[1, 2], [3, 4]],
kwargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
)
assert result.is_success
assert len(result) == 2
for job in result:
assert isinstance(job, Job)

request = responses.calls[-1].request
assert json.loads(request.body) == {
request_json: dict = json.loads(request.body)
assert request_json.pop("reference_id") is not None
assert request_json == {
"bulk_args": [[1, 2], [3, 4]],
"bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
"function_id": "some-id",
}

@responses.activate
def test_map_batching(self):
def request_callback(request: PreparedRequest):
payload: dict = json.loads(request.body)
jobs = []

args = payload["bulk_args"] or []
kwargs = payload["bulk_kwargs"] or []

for args, kwargs in itertools.zip_longest(args, kwargs):
jobs.append(self.make_job(args=args, kwargs=kwargs))

return (200, {}, json.dumps(jobs))

responses.add_callback(
responses.POST,
f"{self.compute_url}/jobs/bulk",
callback=request_callback,
)

fn = Function(id="some-id", saved=True)
result = fn.map([[n, n + 1] for n in range(3000)])
assert result.is_success is True, result.errors
assert len(result) == 3000
reference_ids = {
json.loads(call.request.body)["reference_id"] for call in responses.calls
}
assert len(reference_ids) == 3

@responses.activate
def test_map_errors(self):
global call_count
call_count = 0

def request_callback(request: PreparedRequest):
global call_count
call_count += 1

if call_count > 1:
return (500, {}, None)

payload: dict = json.loads(request.body)
jobs = []

args = payload["bulk_args"] or []
kwargs = payload["bulk_kwargs"] or []

for args, kwargs in itertools.zip_longest(args, kwargs):
jobs.append(self.make_job(args=args, kwargs=kwargs))

return (200, {}, json.dumps(jobs))

responses.add_callback(
responses.POST,
f"{self.compute_url}/jobs/bulk",
callback=request_callback,
)

fn = Function(id="some-id", saved=True)
result = fn.map(
[[1, 2], [3, 4]],
kwargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
batch_size=1,
)
assert result.is_success is False
assert len(result) == 1
assert len(result.errors) == 1
assert result.errors[0].args == [[3, 4]]
assert result.errors[0].kwargs == [{"first": 1.0, "second": 2.0}]
assert len(responses.calls) == 2
reference_ids = {
json.loads(call.request.body)["reference_id"] for call in responses.calls
}
assert len(reference_ids) == 2
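
As a side note, the module-level global used for call counting above works, but the same stub can close over a mutable object instead, keeping the state local to the test; a trimmed sketch of that alternative (response bodies elided):

call_count = {"n": 0}

def request_callback(request):
    # Succeed on the first request, return a server error on every later one.
    call_count["n"] += 1
    if call_count["n"] > 1:
        return (500, {}, None)
    return (200, {}, json.dumps([]))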

@responses.activate
def test_map_deprecated(self):
self.mock_response(
@@ -673,7 +757,9 @@ def test_map_deprecated(self):
iterargs=[{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
)
request = responses.calls[-1].request
assert json.loads(request.body) == {
request_json: dict = json.loads(request.body)
assert request_json.pop("reference_id") is not None
assert request_json == {
"bulk_args": [[1, 2], [3, 4]],
"bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
"function_id": "some-id",
@@ -706,7 +792,9 @@ def inner(t):
kwgenerator(),
)
request = responses.calls[-1].request
assert json.loads(request.body) == {
request_json: dict = json.loads(request.body)
assert request_json.pop("reference_id") is not None
assert request_json == {
"bulk_args": [[1, 2], [3, 4]],
"bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
"function_id": "some-id",
@@ -727,7 +815,9 @@ def test_map_with_tags(self):
tags=["tag1", "tag2"],
)
request = responses.calls[-1].request
assert json.loads(request.body) == {
request_json: dict = json.loads(request.body)
assert request_json.pop("reference_id") is not None
assert request_json == {
"bulk_args": [[1, 2], [3, 4]],
"bulk_kwargs": [{"first": 1, "second": 2}, {"first": 1.0, "second": 2.0}],
"function_id": "some-id",
