diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 633587a75cd19..1c3eac3f87ee0 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -7,3 +7,4 @@
# Han Xiao owns CICD and README.md
.github @hanxiao
setup.py @hanxiao
+extra-requirements.txt @hanxiao
\ No newline at end of file
diff --git a/README.md b/README.md
index ace5638bc8fc6..30084940cd847 100644
--- a/README.md
+++ b/README.md
@@ -113,21 +113,24 @@ To visualize the Flow, simply chain it with `.plot('my-flow.svg')`. If you are u
#### Feed Data
-Let's create some random data and index it:
+To use a Flow, open it via the `with` context manager, like you would open a file in Python. Now let's create some empty Documents and index them:
```python
-import numpy
from jina import Document
with Flow().add() as f:
- f.index((Document() for _ in range(10))) # index raw Jina Documents
+ f.index((Document() for _ in range(10)))
+```
+
+Flow supports CRUD operations: `index`, `search`, `update`, `delete`. Besides, it also provides sugary syntax for common data types such as files, text, and `ndarray`.
+
+```python
+with f:
f.index_ndarray(numpy.random.random([4,2]), on_done=print) # index ndarray data, document sliced on first dimension
f.index_lines(['hello world!', 'goodbye world!']) # index textual data, each element is a document
f.index_files(['/tmp/*.mp4', '/tmp/*.pdf']) # index files and wildcard globs, each file is a document
```
-To use a Flow, open it using the `with` context manager, like you would a file in Python. You can call `index` and `search` with nearly all types of data. The whole data stream is asynchronous and efficient.
-
#### Fetch Result
@@ -345,6 +348,19 @@ with AsyncFlow().add() as f:
await f.index_ndarray(numpy.random.random([5, 4]), on_done=print)
```
+#### Asynchronous Input
+
+`AsyncFlow`'s CRUD operations support an async generator as the input function. This is particularly useful when your data source involves other asynchronous libraries (e.g. Motor for MongoDB):
+
+```python
+async def input_fn():
+ for _ in range(10):
+ yield Document()
+ await asyncio.sleep(0.1)
+
+with AsyncFlow().add() as f:
+ await f.index(input_fn)
+```
That's all you need to know for understanding the magic behind `hello-world`. Now let's dive into it!
diff --git a/extra-requirements.txt b/extra-requirements.txt
index e2d9292016128..9052cbfe9ba55 100644
--- a/extra-requirements.txt
+++ b/extra-requirements.txt
@@ -71,4 +71,5 @@ pydantic: http, devel, test, daemon
python-multipart: http, devel, test, daemon
aiofiles: devel, cicd, http, test, daemon
pytest-custom_exit_code: cicd, test
-bs4: test
\ No newline at end of file
+bs4: test
+aiostream: devel, cicd
\ No newline at end of file
diff --git a/jina/clients/__init__.py b/jina/clients/__init__.py
index 2f09c8c7c7d8a..6168254211e23 100644
--- a/jina/clients/__init__.py
+++ b/jina/clients/__init__.py
@@ -5,7 +5,7 @@
from .base import BaseClient, CallbackFnType, InputFnType
from .helper import callback_exec
from .request import GeneratorSourceType
-from .websockets import WebSocketClientMixin
+from .websocket import WebSocketClientMixin
from ..enums import RequestType
from ..helper import run_async, deprecated_alias
@@ -49,6 +49,7 @@ def search(self, input_fn: InputFnType = None,
:return:
"""
self.mode = RequestType.SEARCH
+ self.add_default_kwargs(kwargs)
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)
@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
diff --git a/jina/clients/asyncio.py b/jina/clients/asyncio.py
index 7477ca9aa1b62..c5831a44704e7 100644
--- a/jina/clients/asyncio.py
+++ b/jina/clients/asyncio.py
@@ -1,5 +1,5 @@
from .base import InputFnType, BaseClient, CallbackFnType
-from .websockets import WebSocketClientMixin
+from .websocket import WebSocketClientMixin
from ..enums import RequestType
from ..helper import deprecated_alias
@@ -79,6 +79,7 @@ async def search(self, input_fn: InputFnType = None,
:return:
"""
self.mode = RequestType.SEARCH
+ self.add_default_kwargs(kwargs)
return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs)
@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
diff --git a/jina/clients/base.py b/jina/clients/base.py
index 796a53a85892c..afe63c8a0c0d4 100644
--- a/jina/clients/base.py
+++ b/jina/clients/base.py
@@ -3,11 +3,10 @@
import argparse
import os
-from typing import Callable, Union, Optional, Iterator, List
+from typing import Callable, Union, Optional, Iterator, List, Dict, AsyncIterator
import grpc
-
-from . import request
+import inspect
from .helper import callback_exec
from .request import GeneratorSourceType
from ..enums import RequestType
@@ -70,8 +69,12 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None:
kwargs['data'] = input_fn
+ if inspect.isasyncgenfunction(input_fn) or inspect.isasyncgen(input_fn):
+ raise NotImplementedError('checking the validity of an async generator is not implemented yet')
+
try:
- r = next(getattr(request, 'index')(**kwargs))
+ from .request import request_generator
+ r = next(request_generator(**kwargs))
if isinstance(r, Request):
default_logger.success(f'input_fn is valid')
else:
@@ -80,18 +83,25 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None:
default_logger.error(f'input_fn is not valid!')
raise BadClientInput from ex
- def _get_requests(self, **kwargs) -> Iterator['Request']:
+ def _get_requests(self, **kwargs) -> Union[Iterator['Request'], AsyncIterator['Request']]:
"""Get request in generator"""
_kwargs = vars(self.args)
_kwargs['data'] = self.input_fn
# override by the caller-specific kwargs
_kwargs.update(kwargs)
+ if inspect.isasyncgen(self.input_fn):
+ from .request.asyncio import request_generator
+ return request_generator(**_kwargs)
+ else:
+ from .request import request_generator
+ return request_generator(**_kwargs)
+
+ def _get_task_name(self, kwargs: Dict) -> str:
tname = str(self.mode).lower()
if 'mode' in kwargs:
tname = str(kwargs['mode']).lower()
-
- return getattr(request, tname)(**_kwargs), tname
+ return tname
@property
def input_fn(self) -> InputFnType:
@@ -118,7 +128,8 @@ async def _get_results(self,
result = [] # type: List['Response']
try:
self.input_fn = input_fn
- req_iter, tname = self._get_requests(**kwargs)
+ tname = self._get_task_name(kwargs)
+ req_iter = self._get_requests(**kwargs)
async with grpc.aio.insecure_channel(f'{self.args.host}:{self.args.port_expose}',
options=[('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1)]) as channel:
@@ -167,3 +178,17 @@ def search(self):
def train(self):
raise NotImplementedError
+
+ @staticmethod
+ def add_default_kwargs(kwargs: Dict):
+ # TODO: refactor it into load from config file
+ if ('top_k' in kwargs) and (kwargs['top_k'] is not None):
+ # associate all VectorSearchDriver and SliceQL driver to use top_k
+ from jina import QueryLang
+ topk_ql = [QueryLang({'name': 'SliceQL', 'priority': 1, 'parameters': {'end': kwargs['top_k']}}),
+ QueryLang(
+ {'name': 'VectorSearchDriver', 'priority': 1, 'parameters': {'top_k': kwargs['top_k']}})]
+ if 'queryset' not in kwargs:
+ kwargs['queryset'] = topk_ql
+ else:
+ kwargs['queryset'].extend(topk_ql)
diff --git a/jina/clients/request.py b/jina/clients/request.py
deleted file mode 100644
index 442babdc19499..0000000000000
--- a/jina/clients/request.py
+++ /dev/null
@@ -1,126 +0,0 @@
-__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
-__license__ = "Apache-2.0"
-
-from typing import Iterator, Union, Tuple, Sequence
-
-from .. import Request
-from ..enums import RequestType, DataInputType
-from ..excepts import BadDocType
-from ..helper import batch_iterator
-from ..logging import default_logger
-from ..types.document import Document, DocumentSourceType, DocumentContentType
-from ..types.querylang import QueryLang
-from ..types.sets.querylang import AcceptQueryLangType
-
-GeneratorSourceType = Iterator[Union[DocumentContentType,
- DocumentSourceType,
- Tuple[DocumentContentType, DocumentContentType],
- Tuple[DocumentSourceType, DocumentSourceType]]]
-
-
-def _build_doc(data, data_type: DataInputType, **kwargs) -> Tuple['Document', 'DataInputType']:
- def _build_doc_from_content():
- with Document(**kwargs) as d:
- d.content = data
- return d, DataInputType.CONTENT
-
- if data_type == DataInputType.AUTO or data_type == DataInputType.DOCUMENT:
- if isinstance(data, Document):
- # if incoming is already primitive type Document, then all good, best practice!
- return data, DataInputType.DOCUMENT
- try:
- d = Document(data, **kwargs)
- return d, DataInputType.DOCUMENT
- except BadDocType:
- # AUTO has a fallback, now reconsider it as content
- if data_type == DataInputType.AUTO:
- return _build_doc_from_content()
- else:
- raise
- elif data_type == DataInputType.CONTENT:
- return _build_doc_from_content()
-
-
-def _generate(data: GeneratorSourceType,
- request_size: int = 0,
- mode: RequestType = RequestType.INDEX,
- mime_type: str = None,
- queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
- data_type: DataInputType = DataInputType.AUTO,
- **kwargs # do not remove this, add on purpose to suppress unknown kwargs
- ) -> Iterator['Request']:
- """
- :param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`;
- or an interator over possible Document content (set to text, blob and buffer).
- :return:
- """
-
- _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)
-
- try:
- for batch in batch_iterator(data, request_size):
- req = Request()
- req.request_type = str(mode)
- for content in batch:
- if isinstance(content, tuple) and len(content) == 2:
- # content comes in pair, will take the first as the input and the second as the groundtruth
-
- # note how data_type is cached
- d, data_type = _build_doc(content[0], data_type, **_kwargs)
- gt, _ = _build_doc(content[1], data_type, **_kwargs)
- req.docs.append(d)
- req.groundtruths.append(gt)
- else:
- d, data_type = _build_doc(content, data_type, **_kwargs)
- req.docs.append(d)
-
- if isinstance(queryset, Sequence):
- req.queryset.extend(queryset)
- elif queryset is not None:
- req.queryset.append(queryset)
-
- yield req
- except Exception as ex:
- # must be handled here, as grpc channel wont handle Python exception
- default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True)
-
-
-def index(*args, **kwargs):
- """Generate a indexing request"""
- yield from _generate(*args, **kwargs)
-
-
-def update(*args, **kwargs):
- """Generate a update request"""
- yield from _generate(*args, **kwargs)
-
-
-def delete(*args, **kwargs):
- """Generate a delete request"""
- yield from _generate(*args, **kwargs)
-
-
-def train(*args, **kwargs):
- """Generate a training request """
- yield from _generate(*args, **kwargs)
- req = Request()
- req.train.flush = True
- yield req
-
-
-def search(*args, **kwargs):
- """Generate a searching request """
- if ('top_k' in kwargs) and (kwargs['top_k'] is not None):
- # associate all VectorSearchDriver and SliceQL driver to use top_k
- topk_ql = [QueryLang({'name': 'SliceQL', 'priority': 1, 'parameters': {'end': kwargs['top_k']}}),
- QueryLang({'name': 'VectorSearchDriver', 'priority': 1, 'parameters': {'top_k': kwargs['top_k']}})]
- if 'queryset' not in kwargs:
- kwargs['queryset'] = topk_ql
- else:
- kwargs['queryset'].extend(topk_ql)
- yield from _generate(*args, **kwargs)
-
-
-def evaluate(*args, **kwargs):
- """Generate an evaluation request """
- yield from _generate(*args, **kwargs)
diff --git a/jina/clients/request/__init__.py b/jina/clients/request/__init__.py
new file mode 100644
index 0000000000000..204e05f6282fd
--- /dev/null
+++ b/jina/clients/request/__init__.py
@@ -0,0 +1,44 @@
+__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
+__license__ = "Apache-2.0"
+
+from typing import Iterator, Union, Tuple, AsyncIterator
+
+from .helper import _new_request_from_batch
+from ... import Request
+from ...enums import RequestType, DataInputType
+from ...helper import batch_iterator
+from ...logging import default_logger
+from ...types.document import DocumentSourceType, DocumentContentType
+from ...types.sets.querylang import AcceptQueryLangType
+
+SingletonDataType = Union[DocumentContentType,
+ DocumentSourceType,
+ Tuple[DocumentContentType, DocumentContentType],
+ Tuple[DocumentSourceType, DocumentSourceType]]
+
+GeneratorSourceType = Union[Iterator[SingletonDataType], AsyncIterator[SingletonDataType]]
+
+
+def request_generator(data: GeneratorSourceType,
+ request_size: int = 0,
+ mode: RequestType = RequestType.INDEX,
+ mime_type: str = None,
+ queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
+ data_type: DataInputType = DataInputType.AUTO,
+ **kwargs # do not remove this, add on purpose to suppress unknown kwargs
+ ) -> Iterator['Request']:
+ """
+ :param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`;
+            or an iterator over possible Document content (set to text, blob and buffer).
+ :return:
+ """
+
+ _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)
+
+ try:
+ for batch in batch_iterator(data, request_size):
+ yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset)
+
+ except Exception as ex:
+ # must be handled here, as grpc channel wont handle Python exception
+ default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True)
diff --git a/jina/clients/request/asyncio.py b/jina/clients/request/asyncio.py
new file mode 100644
index 0000000000000..bba77bfab61f3
--- /dev/null
+++ b/jina/clients/request/asyncio.py
@@ -0,0 +1,34 @@
+__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
+__license__ = "Apache-2.0"
+
+from typing import Iterator, Union, AsyncIterator
+
+from .helper import _new_request_from_batch
+from .. import GeneratorSourceType
+from ... import Request
+from ...enums import RequestType, DataInputType
+from ...importer import ImportExtensions
+from ...logging import default_logger
+from ...types.sets.querylang import AcceptQueryLangType
+
+
+async def request_generator(data: GeneratorSourceType,
+ request_size: int = 0,
+ mode: RequestType = RequestType.INDEX,
+ mime_type: str = None,
+ queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
+ data_type: DataInputType = DataInputType.AUTO,
+ **kwargs # do not remove this, add on purpose to suppress unknown kwargs
+ ) -> AsyncIterator['Request']:
+
+ _kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)
+
+ try:
+ with ImportExtensions(required=True):
+ import aiostream
+
+ async for batch in aiostream.stream.chunks(data, request_size):
+ yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset)
+ except Exception as ex:
+ # must be handled here, as grpc channel wont handle Python exception
+ default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True)
diff --git a/jina/clients/request/helper.py b/jina/clients/request/helper.py
new file mode 100644
index 0000000000000..b48d47c486950
--- /dev/null
+++ b/jina/clients/request/helper.py
@@ -0,0 +1,50 @@
+from typing import Tuple, Sequence
+
+from ... import Document, Request
+from ...enums import DataInputType
+from ...excepts import BadDocType
+
+
+def _new_doc_from_data(data, data_type: DataInputType, **kwargs) -> Tuple['Document', 'DataInputType']:
+ def _build_doc_from_content():
+ with Document(**kwargs) as d:
+ d.content = data
+ return d, DataInputType.CONTENT
+
+ if data_type == DataInputType.AUTO or data_type == DataInputType.DOCUMENT:
+ if isinstance(data, Document):
+ # if incoming is already primitive type Document, then all good, best practice!
+ return data, DataInputType.DOCUMENT
+ try:
+ d = Document(data, **kwargs)
+ return d, DataInputType.DOCUMENT
+ except BadDocType:
+ # AUTO has a fallback, now reconsider it as content
+ if data_type == DataInputType.AUTO:
+ return _build_doc_from_content()
+ else:
+ raise
+ elif data_type == DataInputType.CONTENT:
+ return _build_doc_from_content()
+
+
+def _new_request_from_batch(_kwargs, batch, data_type, mode, queryset):
+ req = Request()
+ req.request_type = str(mode)
+ for content in batch:
+ if isinstance(content, tuple) and len(content) == 2:
+ # content comes in pair, will take the first as the input and the second as the groundtruth
+
+ # note how data_type is cached
+ d, data_type = _new_doc_from_data(content[0], data_type, **_kwargs)
+ gt, _ = _new_doc_from_data(content[1], data_type, **_kwargs)
+ req.docs.append(d)
+ req.groundtruths.append(gt)
+ else:
+ d, data_type = _new_doc_from_data(content, data_type, **_kwargs)
+ req.docs.append(d)
+ if isinstance(queryset, Sequence):
+ req.queryset.extend(queryset)
+ elif queryset is not None:
+ req.queryset.append(queryset)
+ return req
diff --git a/jina/clients/websockets.py b/jina/clients/websocket.py
similarity index 96%
rename from jina/clients/websockets.py
rename to jina/clients/websocket.py
index 1162859075854..3ee309710a4a8 100644
--- a/jina/clients/websockets.py
+++ b/jina/clients/websocket.py
@@ -1,4 +1,5 @@
import asyncio
+from abc import ABC
from typing import Callable, List
from .base import BaseClient
@@ -8,7 +9,7 @@
from ..types.request import Request, Response
-class WebSocketClientMixin(BaseClient):
+class WebSocketClientMixin(BaseClient, ABC):
async def _get_results(self,
input_fn: Callable,
on_done: Callable,
@@ -33,7 +34,9 @@ async def _get_results(self,
result = [] # type: List['Response']
self.input_fn = input_fn
- req_iter, tname = self._get_requests(**kwargs)
+
+ tname = self._get_task_name(kwargs)
+ req_iter = self._get_requests(**kwargs)
try:
client_info = f'{self.args.host}:{self.args.port_expose}'
# setting `max_size` as None to avoid connection closure due to size of message
diff --git a/jina/peapods/runtimes/asyncio/rest/app.py b/jina/peapods/runtimes/asyncio/rest/app.py
index d5fa9ddfdde6d..265732684661e 100644
--- a/jina/peapods/runtimes/asyncio/rest/app.py
+++ b/jina/peapods/runtimes/asyncio/rest/app.py
@@ -6,14 +6,13 @@
from ..grpc.async_call import AsyncPrefetchCall
from ....zmq import AsyncZmqlet
-from ..... import clients
+from .....clients.request import request_generator
from .....enums import RequestType
from .....importer import ImportExtensions
from .....logging import JinaLogger
from .....types.message import Message
from .....types.request import Request
-
def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'):
with ImportExtensions(required=True):
from fastapi import FastAPI, WebSocket, Body
@@ -21,8 +20,7 @@ def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'):
from fastapi.middleware.cors import CORSMiddleware
from starlette.endpoints import WebSocketEndpoint
from starlette import status
- if False:
- from starlette.types import Receive, Scope, Send
+ from starlette.types import Receive, Scope, Send
app = FastAPI(title='RESTRuntime')
app.add_middleware(
@@ -50,7 +48,9 @@ async def api(mode: str, body: Any = Body(...)):
return error('"data" field is empty', 406)
body['mode'] = RequestType.from_string(mode)
- req_iter = getattr(clients.request, mode)(**body)
+ from .....clients import BaseClient
+ BaseClient.add_default_kwargs(body)
+ req_iter = request_generator(**body)
results = await get_result_in_json(req_iter=req_iter)
return JSONResponse(content=results[0], status_code=200)
diff --git a/tests/unit/clients/python/test_request.py b/tests/unit/clients/python/test_request.py
index b0e8ea19dd9c7..0c294e86f1913 100644
--- a/tests/unit/clients/python/test_request.py
+++ b/tests/unit/clients/python/test_request.py
@@ -5,7 +5,8 @@
from google.protobuf.json_format import MessageToJson, MessageToDict
from jina import Document, Flow
-from jina.clients.request import _generate, _build_doc
+from jina.clients.request import request_generator
+from jina.clients.request.helper import _new_doc_from_data
from jina.enums import DataInputType
from jina.excepts import BadDocType
from jina.proto import jina_pb2
@@ -30,7 +31,7 @@ def test_on_bad_iterator():
def test_data_type_builder_doc(builder):
a = DocumentProto()
a.id = 'a236cbb0eda62d58'
- d, t = _build_doc(builder(a), DataInputType.DOCUMENT)
+ d, t = _new_doc_from_data(builder(a), DataInputType.DOCUMENT)
assert d.id == a.id
assert t == DataInputType.DOCUMENT
@@ -39,13 +40,13 @@ def test_data_type_builder_doc_bad():
a = DocumentProto()
a.id = 'a236cbb0eda62d58'
with pytest.raises(BadDocType):
- _build_doc(b'BREAKIT!' + a.SerializeToString(), DataInputType.DOCUMENT)
+ _new_doc_from_data(b'BREAKIT!' + a.SerializeToString(), DataInputType.DOCUMENT)
with pytest.raises(BadDocType):
- _build_doc(MessageToJson(a) + '🍔', DataInputType.DOCUMENT)
+ _new_doc_from_data(MessageToJson(a) + '🍔', DataInputType.DOCUMENT)
with pytest.raises(BadDocType):
- _build_doc({'🍔': '🍔'}, DataInputType.DOCUMENT)
+ _new_doc_from_data({'🍔': '🍔'}, DataInputType.DOCUMENT)
@pytest.mark.parametrize('input_type', [DataInputType.AUTO, DataInputType.CONTENT])
@@ -54,20 +55,20 @@ def test_data_type_builder_auto(input_type):
print(f'quant is on: {os.environ["JINA_ARRAY_QUANT"]}')
del os.environ['JINA_ARRAY_QUANT']
- d, t = _build_doc('123', input_type)
+ d, t = _new_doc_from_data('123', input_type)
assert d.text == '123'
assert t == DataInputType.CONTENT
- d, t = _build_doc(b'45678', input_type)
+ d, t = _new_doc_from_data(b'45678', input_type)
assert t == DataInputType.CONTENT
assert d.buffer == b'45678'
- d, t = _build_doc(b'123', input_type)
+ d, t = _new_doc_from_data(b'123', input_type)
assert t == DataInputType.CONTENT
assert d.buffer == b'123'
c = np.random.random([10, 10])
- d, t = _build_doc(c, input_type)
+ d, t = _new_doc_from_data(c, input_type)
np.testing.assert_equal(d.blob, c)
assert t == DataInputType.CONTENT
@@ -77,7 +78,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'
- req = _generate(data=random_lines(100), request_size=100)
+ req = request_generator(data=random_lines(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -91,7 +92,7 @@ def test_request_generate_lines_from_list():
def random_lines(num_lines):
return [f'i\'m dummy doc {j}' for j in range(1, num_lines + 1)]
- req = _generate(data=random_lines(100), request_size=100)
+ req = request_generator(data=random_lines(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -106,7 +107,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'https://github.com i\'m dummy doc {j}'
- req = _generate(data=random_lines(100), request_size=100)
+ req = request_generator(data=random_lines(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -121,7 +122,7 @@ def random_lines(num_lines):
for j in range(1, num_lines + 1):
yield f'i\'m dummy doc {j}'
- req = _generate(data=random_lines(100), request_size=100)
+ req = request_generator(data=random_lines(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -141,7 +142,7 @@ def random_docs(num_docs):
doc.mime_type = 'mime_type'
yield doc
- req = _generate(data=random_docs(100), request_size=100)
+ req = request_generator(data=random_docs(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -174,7 +175,7 @@ def random_docs(num_docs):
}
yield doc
- req = _generate(data=random_docs(100), request_size=100)
+ req = request_generator(data=random_docs(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -213,7 +214,7 @@ def random_docs(num_docs):
}
yield json.dumps(doc)
- req = _generate(data=random_docs(100), request_size=100)
+ req = request_generator(data=random_docs(100), request_size=100)
request = next(req)
assert len(request.index.docs) == 100
@@ -231,7 +232,7 @@ def random_docs(num_docs):
def test_request_generate_numpy_arrays():
input_array = np.random.random([10, 10])
- req = _generate(data=input_array, request_size=5)
+ req = request_generator(data=input_array, request_size=5)
request = next(req)
assert len(request.index.docs) == 5
@@ -253,7 +254,7 @@ def generator():
for array in input_array:
yield array
- req = _generate(data=generator(), request_size=5)
+ req = request_generator(data=generator(), request_size=5)
request = next(req)
assert len(request.index.docs) == 5
diff --git a/tests/unit/flow/test_asyncflow.py b/tests/unit/flow/test_asyncflow.py
index 14eb140fd5221..35f55485fc624 100644
--- a/tests/unit/flow/test_asyncflow.py
+++ b/tests/unit/flow/test_asyncflow.py
@@ -5,12 +5,14 @@
from jina import Document
from jina.flow.asyncio import AsyncFlow
-from jina.types.request import Response
from jina.logging.profile import TimeContext
+from jina.types.request import Response
+
+num_docs = 5
def validate(req):
- assert len(req.docs) == 5
+ assert len(req.docs) == num_docs
assert req.docs[0].blob.ndim == 1
@@ -32,15 +34,38 @@ def documents(start_index, end_index):
@pytest.mark.asyncio
@pytest.mark.parametrize('restful', [False])
-async def test_run_async_flow(restful):
+async def test_run_async_flow(restful, mocker):
+ r_val = mocker.Mock(wrap=validate)
+ with AsyncFlow(restful=restful).add() as f:
+ await f.index_ndarray(np.random.random([num_docs, 4]), on_done=r_val)
+ r_val.assert_called()
+
+
+async def ainput_fn():
+ for _ in range(num_docs):
+ yield np.random.random([4])
+ await asyncio.sleep(0.1)
+
+async def ainput_fn2():
+ for _ in range(num_docs):
+ yield Document(content=np.random.random([4]))
+ await asyncio.sleep(0.1)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('restful', [False])
+@pytest.mark.parametrize('input_fn', [ainput_fn, ainput_fn(), ainput_fn2(), ainput_fn2])
+async def test_run_async_flow_async_input(restful, input_fn, mocker):
+ r_val = mocker.Mock(wrap=validate)
with AsyncFlow(restful=restful).add() as f:
- await f.index_ndarray(np.random.random([5, 4]), on_done=validate)
+ await f.index(input_fn, on_done=r_val)
+ r_val.assert_called()
async def run_async_flow_5s(restful):
# WaitDriver pause 5s makes total roundtrip ~5s
with AsyncFlow(restful=restful).add(uses='- !WaitDriver {}') as f:
- await f.index_ndarray(np.random.random([5, 4]), on_done=validate)
+ await f.index_ndarray(np.random.random([num_docs, 4]), on_done=validate)
async def sleep_print():
diff --git a/tests/unit/types/message/test_compression.py b/tests/unit/types/message/test_compression.py
index 6243982f5b57b..2776ef617a74a 100644
--- a/tests/unit/types/message/test_compression.py
+++ b/tests/unit/types/message/test_compression.py
@@ -1,7 +1,7 @@
import pytest
from jina import Message
-from jina.clients.request import _generate
+from jina.clients.request import request_generator
from jina.enums import CompressAlgo
from jina.logging.profile import TimeContext
from tests import random_docs
@@ -19,7 +19,7 @@ def test_compression(compress_algo, low_bytes, high_ratio):
compress_min_ratio=10 if high_ratio else 1)
with TimeContext(f'no compress'):
- for r in _generate(docs):
+ for r in request_generator(docs):
m = Message(None, r, compress=CompressAlgo.NONE, **kwargs)
m.dump()
no_comp_sizes.append(m.size)
@@ -28,7 +28,7 @@ def test_compression(compress_algo, low_bytes, high_ratio):
compress_min_bytes=2 * sum(no_comp_sizes) if low_bytes else 0,
compress_min_ratio=10 if high_ratio else 1)
with TimeContext(f'compressing with {str(compress_algo)}') as tc:
- for r in _generate(docs):
+ for r in request_generator(docs):
m = Message(None, r, compress=compress_algo, **kwargs)
m.dump()
sizes.append(m.size)
diff --git a/tests/unit/types/message/test_message.py b/tests/unit/types/message/test_message.py
index fdf6c5bada40c..8f7b62c52322c 100644
--- a/tests/unit/types/message/test_message.py
+++ b/tests/unit/types/message/test_message.py
@@ -4,7 +4,7 @@
import pytest
from jina import Request, QueryLang, Document
-from jina.clients.request import _generate
+from jina.clients.request import request_generator
from jina.proto import jina_pb2
from jina.proto.jina_pb2 import EnvelopeProto
from jina.types.message import Message
@@ -14,7 +14,7 @@
@pytest.mark.parametrize('field', _trigger_fields.difference({'command', 'args', 'flush'}))
def test_lazy_access(field):
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
@@ -26,7 +26,7 @@ def test_lazy_access(field):
def test_multiple_access():
- reqs = [Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10))]
+ reqs = [Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10))]
for r in reqs:
assert not r.is_used
assert r
@@ -39,7 +39,7 @@ def test_multiple_access():
def test_lazy_nest_access():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
# write access r.train
@@ -50,7 +50,7 @@ def test_lazy_nest_access():
def test_lazy_change_message_type():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
# write access r.train
@@ -61,7 +61,7 @@ def test_lazy_change_message_type():
def test_lazy_append_access():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
# write access r.train
@@ -71,7 +71,7 @@ def test_lazy_append_access():
def test_lazy_clear_access():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
# write access r.train
@@ -81,7 +81,7 @@ def test_lazy_clear_access():
def test_lazy_nested_clear_access():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert not r.is_used
# write access r.train
@@ -92,7 +92,7 @@ def test_lazy_nested_clear_access():
def test_lazy_msg_access():
reqs = [Message(None, r.SerializeToString(), 'test', '123',
- request_id='123', request_type='IndexRequest') for r in _generate(random_docs(10))]
+ request_id='123', request_type='IndexRequest') for r in request_generator(random_docs(10))]
for r in reqs:
assert not r.request.is_used
assert r.envelope
@@ -113,7 +113,7 @@ def test_lazy_msg_access():
def test_message_size():
- reqs = [Message(None, r, 'test', '123') for r in _generate(random_docs(10))]
+ reqs = [Message(None, r, 'test', '123') for r in request_generator(random_docs(10))]
for r in reqs:
assert r.size == 0
assert sys.getsizeof(r.envelope.SerializeToString())
@@ -124,7 +124,7 @@ def test_message_size():
def test_lazy_request_fields():
- reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in _generate(random_docs(10)))
+ reqs = (Request(r.SerializeToString(), EnvelopeProto()) for r in request_generator(random_docs(10)))
for r in reqs:
assert list(r.DESCRIPTOR.fields_by_name.keys())