Skip to content

Commit

Permalink
refactor: slight change in dyn batch queue (#6193)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Sep 5, 2024
1 parent 5397e43 commit abc4ca2
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 115 deletions.
193 changes: 100 additions & 93 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from jina._docarray import docarray_v2

import contextlib
if not docarray_v2:
from docarray import DocumentArray
else:
Expand All @@ -24,11 +24,16 @@ def __init__(
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
) -> None:
self._data_lock = asyncio.Lock()
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
if allow_concurrent and flush_all:
self._data_lock = contextlib.AsyncExitStack()
else:
self._data_lock = asyncio.Lock()
self.func = func
if params is None:
params = dict()
Expand Down Expand Up @@ -104,19 +109,20 @@ async def push(self, request: DataRequest, http = False) -> asyncio.Queue:
# this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
# before the `flush` task processes it.
self._start_timer()
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(len(docs))
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if len(self._big_doc) >= self._preferred_batch_size:
self._flush_trigger.set()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(len(docs))
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if len(self._big_doc) >= self._preferred_batch_size:
self._flush_trigger.set()

return queue

Expand Down Expand Up @@ -236,74 +242,94 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)
async with self._data_lock:
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

self._reset()
self._reset()

# At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
# requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to
# communicate that the request has been processed properly.
# At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
# requests_idxs_in_batch with its lengths stored in requests_lens_in_batch. For each requests, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()
if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
)
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
)
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in requests_completed_in_batch[
involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)
except Exception as exc:
# All the requests containing docs in this Exception should be raising it
for request_full in requests_completed_in_batch[
involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
Expand All @@ -312,25 +338,6 @@ def batch(iterable_1, iterable_2, n:Optional[int] = 1):
requests_completed_in_batch,
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)

async def close(self):
"""Closes the batch queue by flushing pending requests."""
if not self._is_closed:
Expand Down
1 change: 1 addition & 0 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ async def handle(
].response_schema,
output_array_type=self.args.output_array_type,
params=params,
allow_concurrent=self.args.allow_concurrent,
**self._batchqueue_config[exec_endpoint],
)
# This is necessary because push might need to await for the queue to be emptied
Expand Down
27 changes: 20 additions & 7 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def call_api_with_params(req: RequestStructParams):
],
)
@pytest.mark.parametrize('use_stream', [False, True])
def test_timeout(add_parameters, use_stream):
@pytest.mark.parametrize('allow_concurrent', [False, True])
def test_timeout(add_parameters, use_stream, allow_concurrent):
add_parameters['allow_concurrent'] = allow_concurrent
f = Flow().add(**add_parameters)
with f:
start_time = time.time()
Expand Down Expand Up @@ -265,7 +267,9 @@ def test_timeout(add_parameters, use_stream):
],
)
@pytest.mark.parametrize('use_stream', [False, True])
def test_preferred_batch_size(add_parameters, use_stream):
@pytest.mark.parametrize('allow_concurrent', [False, True])
def test_preferred_batch_size(add_parameters, use_stream, allow_concurrent):
add_parameters['allow_concurrent'] = allow_concurrent
f = Flow().add(**add_parameters)
with f:
with mp.Pool(2) as p:
Expand Down Expand Up @@ -315,8 +319,9 @@ def test_preferred_batch_size(add_parameters, use_stream):

@pytest.mark.repeat(10)
@pytest.mark.parametrize('use_stream', [False, True])
def test_correctness(use_stream):
f = Flow().add(uses=PlaceholderExecutor)
@pytest.mark.parametrize('allow_concurrent', [False, True])
def test_correctness(use_stream, allow_concurrent):
f = Flow().add(uses=PlaceholderExecutor, allow_concurrent=allow_concurrent)
with f:
with mp.Pool(2) as p:
results = list(
Expand Down Expand Up @@ -686,7 +691,14 @@ def foo(self, docs, **kwargs):
True
],
)
async def test_num_docs_processed_in_exec(flush_all):
@pytest.mark.parametrize(
'allow_concurrent',
[
False,
True
],
)
async def test_num_docs_processed_in_exec(flush_all, allow_concurrent):
class DynBatchProcessor(Executor):

@dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all)
Expand All @@ -695,7 +707,7 @@ def foo(self, docs, **kwargs):
for doc in docs:
doc.text = f"{len(docs)}"

depl = Deployment(uses=DynBatchProcessor, protocol='http')
depl = Deployment(uses=DynBatchProcessor, protocol='http', allow_concurrent=allow_concurrent)

with depl:
da = DocumentArray([Document(text='good') for _ in range(50)])
Expand All @@ -721,5 +733,6 @@ def foo(self, docs, **kwargs):
larger_than_5 += 1
if int(d.text) < 5:
smaller_than_5 += 1
assert smaller_than_5 == 1

assert smaller_than_5 == (1 if allow_concurrent else 0)
assert larger_than_5 > 0
Loading

0 comments on commit abc4ca2

Please sign in to comment.