Skip to content

Commit

Permalink
feat: avoid need data lock in batch queue (#6190)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Sep 3, 2024
1 parent f3e3442 commit 3e0943f
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 109 deletions.
213 changes: 116 additions & 97 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from jina._docarray import docarray_v2
Expand Down Expand Up @@ -63,6 +64,7 @@ def _reset(self) -> None:
self._big_doc = self._request_docarray_cls()

self._flush_task: Optional[Task] = None
self._flush_trigger: Event = Event()

def _cancel_timer_if_pending(self):
if (
Expand Down Expand Up @@ -102,20 +104,19 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
# this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
# before the `flush` task processes it.
self._start_timer()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(len(docs))
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if len(self._big_doc) >= self._preferred_batch_size:
self._flush_trigger.set()
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(len(docs))
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if len(self._big_doc) >= self._preferred_batch_size:
self._flush_trigger.set()

return queue

Expand All @@ -128,6 +129,7 @@ def _get_docs_groups_completed_request_indexes(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
):
"""
This method groups all the `non_assigned_docs` into groups of docs according to the `req_idx` they belong to.
Expand All @@ -136,6 +138,7 @@ def _get_docs_groups_completed_request_indexes(
:param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result
:param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed)
:param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed.
:param requests_lens_in_batch: List of lens of documents for each request in the batch.
:return: list of document groups and a list of request Idx to which each of these groups belong
"""
Expand Down Expand Up @@ -164,7 +167,7 @@ def _get_docs_groups_completed_request_indexes(
if (
req_idx not in completed_req_idx
and num_docs_in_req_idx + sum_from_previous_mini_batch_in_first_req_idx
== self._request_lens[req_idx]
== requests_lens_in_batch[req_idx]
):
completed_req_idx.append(req_idx)
request_bucket = non_assigned_docs[
Expand All @@ -178,6 +181,9 @@ async def _assign_results(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
):
"""
This method aims to assign to the corresponding request objects the resulting documents from the mini batches.
Expand All @@ -187,6 +193,9 @@ async def _assign_results(
:param non_assigned_docs: The documents that have already been processed but have not been assigned to a request result
:param non_assigned_docs_reqs_idx: The request IDX that are not yet completed (not all of its docs have been processed)
:param sum_from_previous_mini_batch_in_first_req_idx: The number of docs from previous iteration that belong to the first non_assigned_req_idx. This is useful to make sure we know when a request is completed.
:param requests_lens_in_batch: List of lens of documents for each request in the batch.
:param requests_in_batch: List of requests in the batch
:param requests_completed_in_batch: List of queues for requests to be completed
:return: amount of assigned documents so that some documents can come back in the next iteration
"""
Expand All @@ -197,12 +206,13 @@ async def _assign_results(
non_assigned_docs,
non_assigned_docs_reqs_idx,
sum_from_previous_mini_batch_in_first_req_idx,
requests_lens_in_batch
)
num_assigned_docs = sum(len(group) for group in docs_grouped)

for docs_group, request_idx in zip(docs_grouped, completed_req_idxs):
request = self._requests[request_idx]
request_completed = self._requests_completed[request_idx]
request = requests_in_batch[request_idx]
request_completed = requests_completed_in_batch[request_idx]
if http is False or self._output_array_type is not None:
request.direct_docs = None # batch queue will work in place, therefore result will need to read from data.
request.data.set_docs_convert_arrays(
Expand All @@ -226,91 +236,100 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
async with self._data_lock:
# At this moment, we have documents concatenated in self._big_doc corresponding to requests in
# self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to
# communicate that the request has been processed properly. At this stage the data_lock is ours and
# therefore no-one can add requests to this list.
self._flush_trigger: Event = Event()
try:
if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
)
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
)
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in self._requests_completed[
involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
)
self._reset()

# At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
# requests_idxs_in_batch with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
)
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
)
finally:
self._reset()
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in requests_completed_in_batch[
involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)

async def close(self):
"""Closes the batch queue by flushing pending requests."""
Expand Down
27 changes: 19 additions & 8 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,14 @@ def test_failure_propagation():
True
],
)
def test_exception_handling_in_dynamic_batch(flush_all):
@pytest.mark.parametrize(
'allow_concurrent',
[
False,
True
],
)
def test_exception_handling_in_dynamic_batch(flush_all, allow_concurrent):
class SlowExecutorWithException(Executor):

@dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
Expand All @@ -646,7 +653,7 @@ def foo(self, docs, **kwargs):
if doc.text == 'fail':
raise Exception('Fail is in the Batch')

depl = Deployment(uses=SlowExecutorWithException)
depl = Deployment(uses=SlowExecutorWithException, allow_concurrent=allow_concurrent)

with depl:
da = DocumentArray([Document(text='good') for _ in range(50)])
Expand All @@ -670,6 +677,7 @@ def foo(self, docs, **kwargs):
else:
assert 1 <= num_failed_requests <= len(da) # 3 requests in the dynamic batch failing


@pytest.mark.asyncio
@pytest.mark.parametrize(
'flush_all',
Expand All @@ -694,11 +702,11 @@ def foo(self, docs, **kwargs):
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
res = []
async for r in cl.post(
on='/foo',
inputs=da,
request_size=7,
continue_on_error=True,
results_in_order=True,
on='/foo',
inputs=da,
request_size=7,
continue_on_error=True,
results_in_order=True,
):
res.extend(r)
assert len(res) == 50 # 1 request per input
Expand All @@ -707,8 +715,11 @@ def foo(self, docs, **kwargs):
assert int(d.text) <= 5
else:
larger_than_5 = 0
smaller_than_5 = 0
for d in res:
if int(d.text) > 5:
larger_than_5 += 1
assert int(d.text) >= 5
if int(d.text) < 5:
smaller_than_5 += 1
assert smaller_than_5 == 1
assert larger_than_5 > 0
10 changes: 6 additions & 4 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ async def process_request(req):
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
batches_lengths_computed = []
lock = asyncio.Lock()

async def foo(docs, **kwargs):
await asyncio.sleep(4)
batches_lengths_computed.append(len(docs))
return DocumentArray([Document(text='Done') for _ in docs])
async with lock:
await asyncio.sleep(4)
batches_lengths_computed.append(len(docs))
return DocumentArray([Document(text='Done') for _ in docs])

bq: BatchQueue = BatchQueue(
foo,
Expand Down Expand Up @@ -109,7 +111,7 @@ async def process_request(req, sleep=0):
assert time_spent >= 8000
assert time_spent <= 8500
if flush_all is False:
assert batches_lengths_computed == [5, 1, 2]
assert batches_lengths_computed == [5, 2, 1]
else:
assert batches_lengths_computed == [6, 2]

Expand Down

0 comments on commit 3e0943f

Please sign in to comment.