From 72aa9071dfd088a4606ce81446635542c589dd3d Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Fri, 29 Jan 2021 01:05:31 +0100 Subject: [PATCH] feat(client): input_fn can now be async generator (#1543) (#1808) * refactor(client): prepare for #1543 * feat(client): client now supports async input_fn --- .github/CODEOWNERS | 1 + README.md | 26 +++- extra-requirements.txt | 3 +- jina/clients/__init__.py | 3 +- jina/clients/asyncio.py | 3 +- jina/clients/base.py | 41 ++++-- jina/clients/request.py | 126 ------------------- jina/clients/request/__init__.py | 44 +++++++ jina/clients/request/asyncio.py | 34 +++++ jina/clients/request/helper.py | 50 ++++++++ jina/clients/{websockets.py => websocket.py} | 7 +- jina/peapods/runtimes/asyncio/rest/app.py | 10 +- tests/unit/clients/python/test_request.py | 37 +++--- tests/unit/flow/test_asyncflow.py | 35 +++++- tests/unit/types/message/test_compression.py | 6 +- tests/unit/types/message/test_message.py | 22 ++-- 16 files changed, 262 insertions(+), 186 deletions(-) delete mode 100644 jina/clients/request.py create mode 100644 jina/clients/request/__init__.py create mode 100644 jina/clients/request/asyncio.py create mode 100644 jina/clients/request/helper.py rename jina/clients/{websockets.py => websocket.py} (96%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 633587a75cd19..1c3eac3f87ee0 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -7,3 +7,4 @@ # Han Xiao owns CICD and README.md .github @hanxiao setup.py @hanxiao +extra-requirements.txt @hanxiao \ No newline at end of file diff --git a/README.md b/README.md index ace5638bc8fc6..30084940cd847 100644 --- a/README.md +++ b/README.md @@ -113,21 +113,24 @@ To visualize the Flow, simply chain it with `.plot('my-flow.svg')`. If you are u #### Feed Data -Let's create some random data and index it: +To use a Flow, open it via `with` context manager, like you would open a file in Python. 
Now let's create some empty documents and index them: ```python -import numpy from jina import Document with Flow().add() as f: - f.index((Document() for _ in range(10))) # index raw Jina Documents + f.index((Document() for _ in range(10))) +``` + +Flow supports CRUD operations: `index`, `search`, `update`, `delete`. Besides, it also provides syntactic sugar for common data types such as files, text, and `ndarray`. + +```python +with f: f.index_ndarray(numpy.random.random([4,2]), on_done=print) # index ndarray data, document sliced on first dimension f.index_lines(['hello world!', 'goodbye world!']) # index textual data, each element is a document f.index_files(['/tmp/*.mp4', '/tmp/*.pdf']) # index files and wildcard globs, each file is a document ``` -To use a Flow, open it using the `with` context manager, like you would a file in Python. You can call `index` and `search` with nearly all types of data. The whole data stream is asynchronous and efficient. - #### Fetch Result @@ -345,6 +348,19 @@ with AsyncFlow().add() as f: await f.index_ndarray(numpy.random.random([5, 4]), on_done=print) ``` +#### Asynchronous Input + +`AsyncFlow`'s CRUD operations support an async generator as the input function. This is particularly useful when your data sources involve other asynchronous libraries (e.g. motor for MongoDB): + +```python +async def input_fn(): + for _ in range(10): + yield Document() + await asyncio.sleep(0.1) + +with AsyncFlow().add() as f: + await f.index(input_fn) +``` That's all you need to know for understanding the magic behind `hello-world`. Now let's dive into it! 
diff --git a/extra-requirements.txt b/extra-requirements.txt index e2d9292016128..9052cbfe9ba55 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -71,4 +71,5 @@ pydantic: http, devel, test, daemon python-multipart: http, devel, test, daemon aiofiles: devel, cicd, http, test, daemon pytest-custom_exit_code: cicd, test -bs4: test \ No newline at end of file +bs4: test +aiostream: devel, cicd \ No newline at end of file diff --git a/jina/clients/__init__.py b/jina/clients/__init__.py index 2f09c8c7c7d8a..6168254211e23 100644 --- a/jina/clients/__init__.py +++ b/jina/clients/__init__.py @@ -5,7 +5,7 @@ from .base import BaseClient, CallbackFnType, InputFnType from .helper import callback_exec from .request import GeneratorSourceType -from .websockets import WebSocketClientMixin +from .websocket import WebSocketClientMixin from ..enums import RequestType from ..helper import run_async, deprecated_alias @@ -49,6 +49,7 @@ def search(self, input_fn: InputFnType = None, :return: """ self.mode = RequestType.SEARCH + self.add_default_kwargs(kwargs) return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs) @deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1)) diff --git a/jina/clients/asyncio.py b/jina/clients/asyncio.py index 7477ca9aa1b62..c5831a44704e7 100644 --- a/jina/clients/asyncio.py +++ b/jina/clients/asyncio.py @@ -1,5 +1,5 @@ from .base import InputFnType, BaseClient, CallbackFnType -from .websockets import WebSocketClientMixin +from .websocket import WebSocketClientMixin from ..enums import RequestType from ..helper import deprecated_alias @@ -79,6 +79,7 @@ async def search(self, input_fn: InputFnType = None, :return: """ self.mode = RequestType.SEARCH + self.add_default_kwargs(kwargs) return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs) @deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1)) diff --git 
a/jina/clients/base.py b/jina/clients/base.py index 796a53a85892c..afe63c8a0c0d4 100644 --- a/jina/clients/base.py +++ b/jina/clients/base.py @@ -3,11 +3,10 @@ import argparse import os -from typing import Callable, Union, Optional, Iterator, List +from typing import Callable, Union, Optional, Iterator, List, Dict, AsyncIterator import grpc - -from . import request +import inspect from .helper import callback_exec from .request import GeneratorSourceType from ..enums import RequestType @@ -70,8 +69,12 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None: kwargs['data'] = input_fn + if inspect.isasyncgenfunction(input_fn) or inspect.isasyncgen(input_fn): + raise NotImplementedError('checking the validity of an async generator is not implemented yet') + try: - r = next(getattr(request, 'index')(**kwargs)) + from .request import request_generator + r = next(request_generator(**kwargs)) if isinstance(r, Request): default_logger.success(f'input_fn is valid') else: @@ -80,18 +83,25 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None: default_logger.error(f'input_fn is not valid!') raise BadClientInput from ex - def _get_requests(self, **kwargs) -> Iterator['Request']: + def _get_requests(self, **kwargs) -> Union[Iterator['Request'], AsyncIterator['Request']]: """Get request in generator""" _kwargs = vars(self.args) _kwargs['data'] = self.input_fn # override by the caller-specific kwargs _kwargs.update(kwargs) + if inspect.isasyncgen(self.input_fn): + from .request.asyncio import request_generator + return request_generator(**_kwargs) + else: + from .request import request_generator + return request_generator(**_kwargs) + + def _get_task_name(self, kwargs: Dict) -> str: tname = str(self.mode).lower() if 'mode' in kwargs: tname = str(kwargs['mode']).lower() - - return getattr(request, tname)(**_kwargs), tname + return tname @property def input_fn(self) -> InputFnType: @@ -118,7 +128,8 @@ async def _get_results(self, result = 
[] # type: List['Response'] try: self.input_fn = input_fn - req_iter, tname = self._get_requests(**kwargs) + tname = self._get_task_name(kwargs) + req_iter = self._get_requests(**kwargs) async with grpc.aio.insecure_channel(f'{self.args.host}:{self.args.port_expose}', options=[('grpc.max_send_message_length', -1), ('grpc.max_receive_message_length', -1)]) as channel: @@ -167,3 +178,17 @@ def search(self): def train(self): raise NotImplementedError + + @staticmethod + def add_default_kwargs(kwargs: Dict): + # TODO: refactor it into load from config file + if ('top_k' in kwargs) and (kwargs['top_k'] is not None): + # associate all VectorSearchDriver and SliceQL driver to use top_k + from jina import QueryLang + topk_ql = [QueryLang({'name': 'SliceQL', 'priority': 1, 'parameters': {'end': kwargs['top_k']}}), + QueryLang( + {'name': 'VectorSearchDriver', 'priority': 1, 'parameters': {'top_k': kwargs['top_k']}})] + if 'queryset' not in kwargs: + kwargs['queryset'] = topk_ql + else: + kwargs['queryset'].extend(topk_ql) diff --git a/jina/clients/request.py b/jina/clients/request.py deleted file mode 100644 index 442babdc19499..0000000000000 --- a/jina/clients/request.py +++ /dev/null @@ -1,126 +0,0 @@ -__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved." -__license__ = "Apache-2.0" - -from typing import Iterator, Union, Tuple, Sequence - -from .. 
import Request -from ..enums import RequestType, DataInputType -from ..excepts import BadDocType -from ..helper import batch_iterator -from ..logging import default_logger -from ..types.document import Document, DocumentSourceType, DocumentContentType -from ..types.querylang import QueryLang -from ..types.sets.querylang import AcceptQueryLangType - -GeneratorSourceType = Iterator[Union[DocumentContentType, - DocumentSourceType, - Tuple[DocumentContentType, DocumentContentType], - Tuple[DocumentSourceType, DocumentSourceType]]] - - -def _build_doc(data, data_type: DataInputType, **kwargs) -> Tuple['Document', 'DataInputType']: - def _build_doc_from_content(): - with Document(**kwargs) as d: - d.content = data - return d, DataInputType.CONTENT - - if data_type == DataInputType.AUTO or data_type == DataInputType.DOCUMENT: - if isinstance(data, Document): - # if incoming is already primitive type Document, then all good, best practice! - return data, DataInputType.DOCUMENT - try: - d = Document(data, **kwargs) - return d, DataInputType.DOCUMENT - except BadDocType: - # AUTO has a fallback, now reconsider it as content - if data_type == DataInputType.AUTO: - return _build_doc_from_content() - else: - raise - elif data_type == DataInputType.CONTENT: - return _build_doc_from_content() - - -def _generate(data: GeneratorSourceType, - request_size: int = 0, - mode: RequestType = RequestType.INDEX, - mime_type: str = None, - queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None, - data_type: DataInputType = DataInputType.AUTO, - **kwargs # do not remove this, add on purpose to suppress unknown kwargs - ) -> Iterator['Request']: - """ - :param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`; - or an interator over possible Document content (set to text, blob and buffer). 
- :return: - """ - - _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0) - - try: - for batch in batch_iterator(data, request_size): - req = Request() - req.request_type = str(mode) - for content in batch: - if isinstance(content, tuple) and len(content) == 2: - # content comes in pair, will take the first as the input and the second as the groundtruth - - # note how data_type is cached - d, data_type = _build_doc(content[0], data_type, **_kwargs) - gt, _ = _build_doc(content[1], data_type, **_kwargs) - req.docs.append(d) - req.groundtruths.append(gt) - else: - d, data_type = _build_doc(content, data_type, **_kwargs) - req.docs.append(d) - - if isinstance(queryset, Sequence): - req.queryset.extend(queryset) - elif queryset is not None: - req.queryset.append(queryset) - - yield req - except Exception as ex: - # must be handled here, as grpc channel wont handle Python exception - default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True) - - -def index(*args, **kwargs): - """Generate a indexing request""" - yield from _generate(*args, **kwargs) - - -def update(*args, **kwargs): - """Generate a update request""" - yield from _generate(*args, **kwargs) - - -def delete(*args, **kwargs): - """Generate a delete request""" - yield from _generate(*args, **kwargs) - - -def train(*args, **kwargs): - """Generate a training request """ - yield from _generate(*args, **kwargs) - req = Request() - req.train.flush = True - yield req - - -def search(*args, **kwargs): - """Generate a searching request """ - if ('top_k' in kwargs) and (kwargs['top_k'] is not None): - # associate all VectorSearchDriver and SliceQL driver to use top_k - topk_ql = [QueryLang({'name': 'SliceQL', 'priority': 1, 'parameters': {'end': kwargs['top_k']}}), - QueryLang({'name': 'VectorSearchDriver', 'priority': 1, 'parameters': {'top_k': kwargs['top_k']}})] - if 'queryset' not in kwargs: - kwargs['queryset'] = topk_ql - else: - kwargs['queryset'].extend(topk_ql) - yield from 
_generate(*args, **kwargs) - - -def evaluate(*args, **kwargs): - """Generate an evaluation request """ - yield from _generate(*args, **kwargs) diff --git a/jina/clients/request/__init__.py b/jina/clients/request/__init__.py new file mode 100644 index 0000000000000..204e05f6282fd --- /dev/null +++ b/jina/clients/request/__init__.py @@ -0,0 +1,44 @@ +__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved." +__license__ = "Apache-2.0" + +from typing import Iterator, Union, Tuple, AsyncIterator + +from .helper import _new_request_from_batch +from ... import Request +from ...enums import RequestType, DataInputType +from ...helper import batch_iterator +from ...logging import default_logger +from ...types.document import DocumentSourceType, DocumentContentType +from ...types.sets.querylang import AcceptQueryLangType + +SingletonDataType = Union[DocumentContentType, + DocumentSourceType, + Tuple[DocumentContentType, DocumentContentType], + Tuple[DocumentSourceType, DocumentSourceType]] + +GeneratorSourceType = Union[Iterator[SingletonDataType], AsyncIterator[SingletonDataType]] + + +def request_generator(data: GeneratorSourceType, + request_size: int = 0, + mode: RequestType = RequestType.INDEX, + mime_type: str = None, + queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None, + data_type: DataInputType = DataInputType.AUTO, + **kwargs # do not remove this, add on purpose to suppress unknown kwargs + ) -> Iterator['Request']: + """ + :param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`; + or an interator over possible Document content (set to text, blob and buffer). 
+ :return: + """ + + _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0) + + try: + for batch in batch_iterator(data, request_size): + yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset) + + except Exception as ex: + # must be handled here, as grpc channel wont handle Python exception + default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True) diff --git a/jina/clients/request/asyncio.py b/jina/clients/request/asyncio.py new file mode 100644 index 0000000000000..bba77bfab61f3 --- /dev/null +++ b/jina/clients/request/asyncio.py @@ -0,0 +1,34 @@ +__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved." +__license__ = "Apache-2.0" + +from typing import Iterator, Union, AsyncIterator + +from .helper import _new_request_from_batch +from .. import GeneratorSourceType +from ... import Request +from ...enums import RequestType, DataInputType +from ...importer import ImportExtensions +from ...logging import default_logger +from ...types.sets.querylang import AcceptQueryLangType + + +async def request_generator(data: GeneratorSourceType, + request_size: int = 0, + mode: RequestType = RequestType.INDEX, + mime_type: str = None, + queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None, + data_type: DataInputType = DataInputType.AUTO, + **kwargs # do not remove this, add on purpose to suppress unknown kwargs + ) -> AsyncIterator['Request']: + + _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0) + + try: + with ImportExtensions(required=True): + import aiostream + + async for batch in aiostream.stream.chunks(data, request_size): + yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset) + except Exception as ex: + # must be handled here, as grpc channel wont handle Python exception + default_logger.critical(f'input_fn is not valid! 
{ex!r}', exc_info=True) diff --git a/jina/clients/request/helper.py b/jina/clients/request/helper.py new file mode 100644 index 0000000000000..b48d47c486950 --- /dev/null +++ b/jina/clients/request/helper.py @@ -0,0 +1,50 @@ +from typing import Tuple, Sequence + +from ... import Document, Request +from ...enums import DataInputType +from ...excepts import BadDocType + + +def _new_doc_from_data(data, data_type: DataInputType, **kwargs) -> Tuple['Document', 'DataInputType']: + def _build_doc_from_content(): + with Document(**kwargs) as d: + d.content = data + return d, DataInputType.CONTENT + + if data_type == DataInputType.AUTO or data_type == DataInputType.DOCUMENT: + if isinstance(data, Document): + # if incoming is already primitive type Document, then all good, best practice! + return data, DataInputType.DOCUMENT + try: + d = Document(data, **kwargs) + return d, DataInputType.DOCUMENT + except BadDocType: + # AUTO has a fallback, now reconsider it as content + if data_type == DataInputType.AUTO: + return _build_doc_from_content() + else: + raise + elif data_type == DataInputType.CONTENT: + return _build_doc_from_content() + + +def _new_request_from_batch(_kwargs, batch, data_type, mode, queryset): + req = Request() + req.request_type = str(mode) + for content in batch: + if isinstance(content, tuple) and len(content) == 2: + # content comes in pair, will take the first as the input and the second as the groundtruth + + # note how data_type is cached + d, data_type = _new_doc_from_data(content[0], data_type, **_kwargs) + gt, _ = _new_doc_from_data(content[1], data_type, **_kwargs) + req.docs.append(d) + req.groundtruths.append(gt) + else: + d, data_type = _new_doc_from_data(content, data_type, **_kwargs) + req.docs.append(d) + if isinstance(queryset, Sequence): + req.queryset.extend(queryset) + elif queryset is not None: + req.queryset.append(queryset) + return req diff --git a/jina/clients/websockets.py b/jina/clients/websocket.py similarity index 96% rename 
from jina/clients/websockets.py rename to jina/clients/websocket.py index 1162859075854..3ee309710a4a8 100644 --- a/jina/clients/websockets.py +++ b/jina/clients/websocket.py @@ -1,4 +1,5 @@ import asyncio +from abc import ABC from typing import Callable, List from .base import BaseClient @@ -8,7 +9,7 @@ from ..types.request import Request, Response -class WebSocketClientMixin(BaseClient): +class WebSocketClientMixin(BaseClient, ABC): async def _get_results(self, input_fn: Callable, on_done: Callable, @@ -33,7 +34,9 @@ async def _get_results(self, result = [] # type: List['Response'] self.input_fn = input_fn - req_iter, tname = self._get_requests(**kwargs) + + tname = self._get_task_name(kwargs) + req_iter = self._get_requests(**kwargs) try: client_info = f'{self.args.host}:{self.args.port_expose}' # setting `max_size` as None to avoid connection closure due to size of message diff --git a/jina/peapods/runtimes/asyncio/rest/app.py b/jina/peapods/runtimes/asyncio/rest/app.py index d5fa9ddfdde6d..265732684661e 100644 --- a/jina/peapods/runtimes/asyncio/rest/app.py +++ b/jina/peapods/runtimes/asyncio/rest/app.py @@ -6,14 +6,13 @@ from ..grpc.async_call import AsyncPrefetchCall from ....zmq import AsyncZmqlet -from ..... 
import clients +from .....clients.request import request_generator from .....enums import RequestType from .....importer import ImportExtensions from .....logging import JinaLogger from .....types.message import Message from .....types.request import Request - def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'): with ImportExtensions(required=True): from fastapi import FastAPI, WebSocket, Body @@ -21,8 +20,7 @@ def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'): from fastapi.middleware.cors import CORSMiddleware from starlette.endpoints import WebSocketEndpoint from starlette import status - if False: - from starlette.types import Receive, Scope, Send + from starlette.types import Receive, Scope, Send app = FastAPI(title='RESTRuntime') app.add_middleware( @@ -50,7 +48,9 @@ async def api(mode: str, body: Any = Body(...)): return error('"data" field is empty', 406) body['mode'] = RequestType.from_string(mode) - req_iter = getattr(clients.request, mode)(**body) + from .....clients import BaseClient + BaseClient.add_default_kwargs(body) + req_iter = request_generator(**body) results = await get_result_in_json(req_iter=req_iter) return JSONResponse(content=results[0], status_code=200) diff --git a/tests/unit/clients/python/test_request.py b/tests/unit/clients/python/test_request.py index b0e8ea19dd9c7..0c294e86f1913 100644 --- a/tests/unit/clients/python/test_request.py +++ b/tests/unit/clients/python/test_request.py @@ -5,7 +5,8 @@ from google.protobuf.json_format import MessageToJson, MessageToDict from jina import Document, Flow -from jina.clients.request import _generate, _build_doc +from jina.clients.request import request_generator +from jina.clients.request.helper import _new_doc_from_data from jina.enums import DataInputType from jina.excepts import BadDocType from jina.proto import jina_pb2 @@ -30,7 +31,7 @@ def test_on_bad_iterator(): def test_data_type_builder_doc(builder): a = DocumentProto() a.id = 'a236cbb0eda62d58' - 
d, t = _build_doc(builder(a), DataInputType.DOCUMENT) + d, t = _new_doc_from_data(builder(a), DataInputType.DOCUMENT) assert d.id == a.id assert t == DataInputType.DOCUMENT @@ -39,13 +40,13 @@ def test_data_type_builder_doc_bad(): a = DocumentProto() a.id = 'a236cbb0eda62d58' with pytest.raises(BadDocType): - _build_doc(b'BREAKIT!' + a.SerializeToString(), DataInputType.DOCUMENT) + _new_doc_from_data(b'BREAKIT!' + a.SerializeToString(), DataInputType.DOCUMENT) with pytest.raises(BadDocType): - _build_doc(MessageToJson(a) + '🍔', DataInputType.DOCUMENT) + _new_doc_from_data(MessageToJson(a) + '🍔', DataInputType.DOCUMENT) with pytest.raises(BadDocType): - _build_doc({'🍔': '🍔'}, DataInputType.DOCUMENT) + _new_doc_from_data({'🍔': '🍔'}, DataInputType.DOCUMENT) @pytest.mark.parametrize('input_type', [DataInputType.AUTO, DataInputType.CONTENT]) @@ -54,20 +55,20 @@ def test_data_type_builder_auto(input_type): print(f'quant is on: {os.environ["JINA_ARRAY_QUANT"]}') del os.environ['JINA_ARRAY_QUANT'] - d, t = _build_doc('123', input_type) + d, t = _new_doc_from_data('123', input_type) assert d.text == '123' assert t == DataInputType.CONTENT - d, t = _build_doc(b'45678', input_type) + d, t = _new_doc_from_data(b'45678', input_type) assert t == DataInputType.CONTENT assert d.buffer == b'45678' - d, t = _build_doc(b'123', input_type) + d, t = _new_doc_from_data(b'123', input_type) assert t == DataInputType.CONTENT assert d.buffer == b'123' c = np.random.random([10, 10]) - d, t = _build_doc(c, input_type) + d, t = _new_doc_from_data(c, input_type) np.testing.assert_equal(d.blob, c) assert t == DataInputType.CONTENT @@ -77,7 +78,7 @@ def random_lines(num_lines): for j in range(1, num_lines + 1): yield f'i\'m dummy doc {j}' - req = _generate(data=random_lines(100), request_size=100) + req = request_generator(data=random_lines(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -91,7 +92,7 @@ def test_request_generate_lines_from_list(): def 
random_lines(num_lines): return [f'i\'m dummy doc {j}' for j in range(1, num_lines + 1)] - req = _generate(data=random_lines(100), request_size=100) + req = request_generator(data=random_lines(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -106,7 +107,7 @@ def random_lines(num_lines): for j in range(1, num_lines + 1): yield f'https://github.com i\'m dummy doc {j}' - req = _generate(data=random_lines(100), request_size=100) + req = request_generator(data=random_lines(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -121,7 +122,7 @@ def random_lines(num_lines): for j in range(1, num_lines + 1): yield f'i\'m dummy doc {j}' - req = _generate(data=random_lines(100), request_size=100) + req = request_generator(data=random_lines(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -141,7 +142,7 @@ def random_docs(num_docs): doc.mime_type = 'mime_type' yield doc - req = _generate(data=random_docs(100), request_size=100) + req = request_generator(data=random_docs(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -174,7 +175,7 @@ def random_docs(num_docs): } yield doc - req = _generate(data=random_docs(100), request_size=100) + req = request_generator(data=random_docs(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -213,7 +214,7 @@ def random_docs(num_docs): } yield json.dumps(doc) - req = _generate(data=random_docs(100), request_size=100) + req = request_generator(data=random_docs(100), request_size=100) request = next(req) assert len(request.index.docs) == 100 @@ -231,7 +232,7 @@ def random_docs(num_docs): def test_request_generate_numpy_arrays(): input_array = np.random.random([10, 10]) - req = _generate(data=input_array, request_size=5) + req = request_generator(data=input_array, request_size=5) request = next(req) assert len(request.index.docs) == 5 @@ -253,7 +254,7 @@ def generator(): for array 
in input_array: yield array - req = _generate(data=generator(), request_size=5) + req = request_generator(data=generator(), request_size=5) request = next(req) assert len(request.index.docs) == 5 diff --git a/tests/unit/flow/test_asyncflow.py b/tests/unit/flow/test_asyncflow.py index 14eb140fd5221..35f55485fc624 100644 --- a/tests/unit/flow/test_asyncflow.py +++ b/tests/unit/flow/test_asyncflow.py @@ -5,12 +5,14 @@ from jina import Document from jina.flow.asyncio import AsyncFlow -from jina.types.request import Response from jina.logging.profile import TimeContext +from jina.types.request import Response + +num_docs = 5 def validate(req): - assert len(req.docs) == 5 + assert len(req.docs) == num_docs assert req.docs[0].blob.ndim == 1 @@ -32,15 +34,38 @@ def documents(start_index, end_index): @pytest.mark.asyncio @pytest.mark.parametrize('restful', [False]) -async def test_run_async_flow(restful): +async def test_run_async_flow(restful, mocker): + r_val = mocker.Mock(wrap=validate) + with AsyncFlow(restful=restful).add() as f: + await f.index_ndarray(np.random.random([num_docs, 4]), on_done=r_val) + r_val.assert_called() + + +async def ainput_fn(): + for _ in range(num_docs): + yield np.random.random([4]) + await asyncio.sleep(0.1) + +async def ainput_fn2(): + for _ in range(num_docs): + yield Document(content=np.random.random([4])) + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('restful', [False]) +@pytest.mark.parametrize('input_fn', [ainput_fn, ainput_fn(), ainput_fn2(), ainput_fn2]) +async def test_run_async_flow_async_input(restful, input_fn, mocker): + r_val = mocker.Mock(wrap=validate) with AsyncFlow(restful=restful).add() as f: - await f.index_ndarray(np.random.random([5, 4]), on_done=validate) + await f.index(input_fn, on_done=r_val) + r_val.assert_called() async def run_async_flow_5s(restful): # WaitDriver pause 5s makes total roundtrip ~5s with AsyncFlow(restful=restful).add(uses='- !WaitDriver {}') as f: - await 
f.index_ndarray(np.random.random([5, 4]), on_done=validate) + await f.index_ndarray(np.random.random([num_docs, 4]), on_done=validate) async def sleep_print(): diff --git a/tests/unit/types/message/test_compression.py b/tests/unit/types/message/test_compression.py index 6243982f5b57b..2776ef617a74a 100644 --- a/tests/unit/types/message/test_compression.py +++ b/tests/unit/types/message/test_compression.py @@ -1,7 +1,7 @@ import pytest from jina import Message -from jina.clients.request import _generate +from jina.clients.request import request_generator from jina.enums import CompressAlgo from jina.logging.profile import TimeContext from tests import random_docs @@ -19,7 +19,7 @@ def test_compression(compress_algo, low_bytes, high_ratio): compress_min_ratio=10 if high_ratio else 1) with TimeContext(f'no compress'): - for r in _generate(docs): + for r in request_generator(docs): m = Message(None, r, compress=CompressAlgo.NONE, **kwargs) m.dump() no_comp_sizes.append(m.size) @@ -28,7 +28,7 @@ def test_compression(compress_algo, low_bytes, high_ratio): compress_min_bytes=2 * sum(no_comp_sizes) if low_bytes else 0, compress_min_ratio=10 if high_ratio else 1) with TimeContext(f'compressing with {str(compress_algo)}') as tc: - for r in _generate(docs): + for r in request_generator(docs): m = Message(None, r, compress=compress_algo, **kwargs) m.dump() sizes.append(m.size) diff --git a/tests/unit/types/message/test_message.py b/tests/unit/types/message/test_message.py index fdf6c5bada40c..8f7b62c52322c 100644 --- a/tests/unit/types/message/test_message.py +++ b/tests/unit/types/message/test_message.py @@ -4,7 +4,7 @@ import pytest from jina import Request, QueryLang, Document -from jina.clients.request import _generate +from jina.clients.request import request_generator from jina.proto import jina_pb2 from jina.proto.jina_pb2 import EnvelopeProto from jina.types.message import Message @@ -14,7 +14,7 @@ @pytest.mark.parametrize('field', 
_trigger_fields.difference({'command', 'args', 'flush'})) def test_lazy_access(field): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used @@ -26,7 +26,7 @@ def test_lazy_access(field): def test_multiple_access(): - reqs = [Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))] + reqs = [Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))] for r in reqs: assert not r.is_used assert r @@ -39,7 +39,7 @@ def test_multiple_access(): def test_lazy_nest_access(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used # write access r.train @@ -50,7 +50,7 @@ def test_lazy_nest_access(): def test_lazy_change_message_type(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used # write access r.train @@ -61,7 +61,7 @@ def test_lazy_change_message_type(): def test_lazy_append_access(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used # write access r.train @@ -71,7 +71,7 @@ def test_lazy_append_access(): def test_lazy_clear_access(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used # write access r.train @@ -81,7 
+81,7 @@ def test_lazy_clear_access(): def test_lazy_nested_clear_access(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert not r.is_used # write access r.train @@ -92,7 +92,7 @@ def test_lazy_nested_clear_access(): def test_lazy_msg_access(): reqs = [Message(None, r.SerializeToString(), 'test', '123', - request_id='123', request_type='IndexRequest') for r in _generate(random_docs(10))] + request_id='123', request_type='IndexRequest') for r in request_generator(random_docs(10))] for r in reqs: assert not r.request.is_used assert r.envelope @@ -113,7 +113,7 @@ def test_lazy_msg_access(): def test_message_size(): - reqs = [Message(None, r, 'test', '123') for r in _generate(random_docs(10))] + reqs = [Message(None, r, 'test', '123') for r in request_generator(random_docs(10))] for r in reqs: assert r.size == 0 assert sys.getsizeof(r.envelope.SerializeToString()) @@ -124,7 +124,7 @@ def test_message_size(): def test_lazy_request_fields(): - reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))) + reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))) for r in reqs: assert list(r.DESCRIPTOR.fields_by_name.keys())