
feat: add custom_metric for dynamic batching (#6189)
Co-authored-by: Jina Dev Bot <[email protected]>
JoanFM and jina-bot authored Sep 18, 2024
1 parent d4fb94d commit d17b620
Showing 5 changed files with 172 additions and 58 deletions.
15 changes: 14 additions & 1 deletion jina/serve/executors/__init__.py
@@ -655,9 +655,22 @@ def _validate_sagemaker(self):
return

def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
        import collections.abc

def deep_update(source, overrides):
for key, value in overrides.items():
                if isinstance(value, collections.abc.Mapping) and value:
returned = deep_update(source.get(key, {}), value)
source[key] = returned
else:
source[key] = overrides[key]
return source

if _dynamic_batching:
self.dynamic_batching = getattr(self, 'dynamic_batching', {})
            self.dynamic_batching = deep_update(
                self.dynamic_batching, _dynamic_batching
            )

def _add_metas(self, _metas: Optional[Dict]):
from jina.serve.executors.metas import get_default_metas
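
To see what the switch from `dict.update` to a recursive merge buys, here is a standalone sketch (the endpoint name `foo` and the override values are illustrative, not from this commit):

```python
import collections.abc

def deep_update(source, overrides):
    # Same strategy as above: nested mappings are merged key by key,
    # everything else is overwritten.
    for key, value in overrides.items():
        if isinstance(value, collections.abc.Mapping) and value:
            source[key] = deep_update(source.get(key, {}), value)
        else:
            source[key] = overrides[key]
    return source

# Defaults contributed by the @dynamic_batching decorator for endpoint 'foo'
existing = {'foo': {'preferred_batch_size': 4, 'timeout': 10_000}}
# Runtime override, e.g. passed via `uses_dynamic_batching`
override = {'foo': {'timeout': 2_000, 'use_custom_metric': True}}

merged = deep_update(existing, override)
assert merged == {
    'foo': {'preferred_batch_size': 4, 'timeout': 2_000, 'use_custom_metric': True}
}
# A plain dict.update would have replaced the whole 'foo' entry,
# silently dropping 'preferred_batch_size'.
```
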
8 changes: 7 additions & 1 deletion jina/serve/executors/decorators.py
@@ -416,7 +416,9 @@ def dynamic_batching(
*,
preferred_batch_size: Optional[int] = None,
timeout: Optional[float] = 10_000,
    flush_all: bool = False,
custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
use_custom_metric: bool = False,
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
@@ -434,6 +436,8 @@ def dynamic_batching(
Default is 10_000ms (10 seconds).
    :param flush_all: Determines whether, once a batch is triggered by timeout or `preferred_batch_size`, the function receives everything the batcher has accumulated.
        If this is true, `preferred_batch_size` acts only as a trigger mechanism.
    :param custom_metric: Optional callable (e.g. a lambda) that measures the "weight" of each document; the accumulated weight, rather than the document count, is compared against `preferred_batch_size`.
    :param use_custom_metric: Determines whether `custom_metric` is used to measure progress toward `preferred_batch_size`.
:return: decorated function
"""

@@ -480,6 +484,8 @@ def _inject_owner_attrs(self, owner, name):
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
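
As a usage illustration, a minimal sketch of an Executor opting into the new parameters (the `TextEncoder` class, its endpoint, and the chosen values are hypothetical, not part of this commit):

```python
from jina import Executor, dynamic_batching, requests


class TextEncoder(Executor):
    @dynamic_batching(
        preferred_batch_size=32,                  # weight budget per batch
        timeout=5_000,                            # flush after 5s regardless of weight
        custom_metric=lambda doc: len(doc.text),  # per-doc weight = text length
        use_custom_metric=True,                   # opt in to weight-based batching
    )
    @requests(on='/encode')
    def encode(self, docs, **kwargs):
        for doc in docs:
            doc.text = doc.text.upper()
```

With this configuration the queue flushes once the summed text length reaches 32, rather than after 32 documents.
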
131 changes: 81 additions & 50 deletions jina/serve/runtimes/worker/batch_queue.py
@@ -1,9 +1,10 @@
import asyncio
import copy
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from jina._docarray import docarray_v2
import contextlib

if not docarray_v2:
from docarray import DocumentArray
else:
@@ -18,16 +19,18 @@ class BatchQueue:
"""A batch queue that holds the data request and the callable to batch requests to."""

def __init__(
self,
func: Callable,
request_docarray_cls,
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
) -> None:
        # To keep the old user behavior, we use the data lock when flush_all is True and allow_concurrent is not set
if allow_concurrent and flush_all:
@@ -44,6 +47,8 @@ def __init__(
self._response_docarray_cls = response_docarray_cls
self._flush_all = flush_all
self._preferred_batch_size: int = preferred_batch_size
self._custom_metric = None if not use_custom_metric else custom_metric
self._metric_value = 0
self._timeout: int = timeout
self._reset()
self._flush_trigger: Event = Event()
@@ -62,20 +67,22 @@ def _reset(self) -> None:
# a list of every request ID
self._request_idxs: List[int] = []
self._request_lens: List[int] = []
self._docs_metrics: List[int] = []
self._requests_completed: List[asyncio.Queue] = []
if not docarray_v2:
self._big_doc: DocumentArray = DocumentArray.empty()
else:
self._big_doc = self._request_docarray_cls()
self._metric_value = 0

self._flush_task: Optional[Task] = None
self._flush_trigger: Event = Event()

def _cancel_timer_if_pending(self):
if (
self._timer_task
and not self._timer_task.done()
and not self._timer_task.cancelled()
):
self._timer_finished = False
self._timer_task.cancel()
@@ -91,7 +98,7 @@ async def _sleep_then_set(self):
self._flush_trigger.set()
self._timer_finished = True

    async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
"""Append request to the the list of requests to be processed.
This method creates an asyncio Queue for that request and keeps track of it. It returns
@@ -116,12 +123,18 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
                metric_value = sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
            self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
            if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()

return queue
@@ -132,10 +145,10 @@ async def _await_then_flush(self, http=False) -> None:
"""

def _get_docs_groups_completed_request_indexes(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
):
"""
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
Expand All @@ -160,9 +173,9 @@ def _get_docs_groups_completed_request_indexes(
)
if req_idx > min_involved_req_idx:
request_bucket = non_assigned_docs[
num_distributed_docs: num_distributed_docs
+ num_docs_in_req_idx
]
num_distributed_docs += num_docs_in_req_idx
completed_req_idx.append(min_involved_req_idx)
min_involved_req_idx = req_idx
@@ -171,25 +184,25 @@ def _get_docs_groups_completed_request_indexes(
num_docs_in_req_idx += 1

if (
req_idx not in completed_req_idx
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
== requests_lens_in_batch[req_idx]
):
completed_req_idx.append(req_idx)
request_bucket = non_assigned_docs[
num_distributed_docs: num_distributed_docs + num_docs_in_req_idx
]
distributed_requests.append(request_bucket)

return distributed_requests, completed_req_idx

async def _assign_results(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
):
"""
            This method assigns the resulting documents from the mini-batches back to their corresponding request objects.
@@ -220,7 +233,7 @@ async def _assign_results(
request = requests_in_batch[request_idx]
request_completed = requests_completed_in_batch[request_idx]
if http is False or self._output_array_type is not None:
                    request.direct_docs = None  # batch queue works in place, so the result will need to be read from request.data
request.data.set_docs_convert_arrays(
docs_group, ndarray_type=self._output_array_type
)
@@ -230,22 +243,39 @@ async def _assign_results(

return num_assigned_docs

        def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Optional[List[Union[int, float]]] = None):
if n is None:
yield iterable_1, iterable_2
return
elif iterable_metrics is None:
items = len(iterable_1)
for ndx in range(0, items, n):
yield iterable_1[ndx: min(ndx + n, items)], iterable_2[
ndx: min(ndx + n, items)
]
else:
batch_idx = 0
batch_weight = 0

for i, (item, weight) in enumerate(zip(iterable_1, iterable_metrics)):
batch_weight += weight

if batch_weight >= n:
yield iterable_1[batch_idx: i + 1], iterable_2[batch_idx: i + 1]
batch_idx = i + 1
batch_weight = 0

# Yield any remaining items
if batch_weight > 0:
yield iterable_1[batch_idx: len(iterable_1)], iterable_2[batch_idx: len(iterable_1)]

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
async with self._data_lock:
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

@@ -263,7 +293,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
@@ -278,8 +309,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
@@ -301,8 +332,8 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in requests_completed_in_batch[
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
@@ -320,11 +351,11 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
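
To make the weighted chunking concrete, a standalone sketch of the generator's behavior (simplified from the nested `batch` above, with assumed inputs; not an exact excerpt):

```python
from typing import List, Optional, Union

def weighted_batch(items: list, weights: Optional[List[Union[int, float]]], n: Optional[int]):
    if n is None:                        # mimics flush_all: one batch with everything
        yield items
        return
    if weights is None:                  # plain count-based chunking
        for ndx in range(0, len(items), n):
            yield items[ndx: ndx + n]
        return
    batch_idx, batch_weight = 0, 0
    for i, weight in enumerate(weights):
        batch_weight += weight
        if batch_weight >= n:            # close the chunk once the weight budget is hit
            yield items[batch_idx: i + 1]
            batch_idx, batch_weight = i + 1, 0
    if batch_weight > 0:                 # yield the remainder
        yield items[batch_idx:]

docs = ['short', 'tiny', 'a much longer text', 'mid-size']
weights = [len(d) for d in docs]         # same idea as custom_metric=lambda doc: len(doc.text)
assert list(weighted_batch(docs, weights, n=10)) == [
    ['short', 'tiny', 'a much longer text'],
    ['mid-size'],
]
```

Note that a chunk can overshoot the budget: a batch is closed only after the document that crosses `n`, which is why a single heavy document still forms a valid batch.
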
64 changes: 64 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -736,3 +736,67 @@ def foo(self, docs, **kwargs):

assert smaller_than_5 == (1 if allow_concurrent else 0)
assert larger_than_5 > 0


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [False, True])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
class DynCustomBatchProcessor(Executor):

@dynamic_batching(preferred_batch_size=10, custom_metric=lambda x: len(x.text))
@requests(on='/foo')
def foo(self, docs, **kwargs):
time.sleep(0.5)
total_len = sum([len(doc.text) for doc in docs])
for doc in docs:
doc.text = f"{total_len}"

    depl = Deployment(
        uses=DynCustomBatchProcessor,
        uses_dynamic_batching={
            'foo': {
                "preferred_batch_size": 10,
                "timeout": 2000,
                "use_custom_metric": use_custom_metric,
                "flush_all": flush_all,
            }
        },
    )
da = DocumentArray([Document(text='aaaaa') for i in range(50)])
with depl:
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
res = []
async for r in cl.post(
on='/foo',
inputs=da,
request_size=1,
continue_on_error=True,
results_in_order=True,
):
res.extend(r)
assert len(res) == 50 # 1 request per input

    # Custom metric, no flush_all: flush at accumulated text length 10 (2 docs of len 5)
if use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "10"

    # No custom metric, no flush_all: flush at 10 docs -> total_len 50
    elif not use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "50"

elif use_custom_metric and flush_all:
        # Two docs get "10" (the first weight-triggered batch); the remaining 48 get "240"
num_10 = 0
num_240 = 0
for doc in res:
if doc.text == "10":
num_10 += 1
elif doc.text == "240":
num_240 += 1

assert num_10 == 2
assert num_240 == 48
elif not use_custom_metric and flush_all:
        # Ten docs get "50" (the first count-triggered batch); the remaining 40 get "200"
num_50 = 0
num_200 = 0
for doc in res:
if doc.text == "50":
num_50 += 1
elif doc.text == "200":
num_200 += 1

assert num_50 == 10
assert num_200 == 40
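
For reference, the expected texts follow from simple arithmetic over the 50 single-doc requests, each carrying the text 'aaaaa' (custom-metric weight 5); a quick sanity-check sketch, not part of the test suite:

```python
doc_len = 5        # len('aaaaa'), the per-doc custom metric value
total_docs = 50
preferred = 10

# use_custom_metric=True: flush once accumulated weight >= 10, i.e. every 2 docs
docs_per_batch = preferred // doc_len                    # 2
assert docs_per_batch * doc_len == 10                    # -> doc.text == "10"

# use_custom_metric=False: flush once 10 docs have accumulated
assert preferred * doc_len == 50                         # -> doc.text == "50"

# flush_all=True: the first trigger takes its batch, then everything that
# piled up during the 0.5s sleep is processed in a single call
assert (total_docs - docs_per_batch) * doc_len == 240    # custom metric case
assert (total_docs - preferred) * doc_len == 200         # doc-count case
```
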
