Skip to content

Commit

Permalink
feat(client): input_fn can now be async generator (#1543) (#1808)
Browse files Browse the repository at this point in the history
* refactor(client): prepare for #1543

* feat(client): client now supports async input_fn
  • Loading branch information
hanxiao authored Jan 29, 2021
1 parent 300931a commit 72aa907
Show file tree
Hide file tree
Showing 16 changed files with 262 additions and 186 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
# Han Xiao owns CICD and README.md
.github @hanxiao
setup.py @hanxiao
extra-requirements.txt @hanxiao
26 changes: 21 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,24 @@ To visualize the Flow, simply chain it with `.plot('my-flow.svg')`. If you are u
#### Feed Data
<a href="https://mybinder.org/v2/gh/jina-ai/jupyter-notebooks/main?filepath=basic-feed-data.ipynb"><img align="right" src="https://github.com/jina-ai/jina/blob/master/.github/badges/run-badge.svg?raw=true"/></a>

Let's create some random data and index it:
To use a Flow, open it via `with` context manager, like you would open a file in Python. Now let's create some empty document and index it:

```python
import numpy
from jina import Document

with Flow().add() as f:
f.index((Document() for _ in range(10))) # index raw Jina Documents
f.index((Document() for _ in range(10)))
```

Flow supports CRUD operations: `index`, `search`, `update`, `delete`. Besides, it also provides sugary syntax on common data type such as files, text, and `ndarray`.

```python
with f:
f.index_ndarray(numpy.random.random([4,2]), on_done=print) # index ndarray data, document sliced on first dimension
f.index_lines(['hello world!', 'goodbye world!']) # index textual data, each element is a document
f.index_files(['/tmp/*.mp4', '/tmp/*.pdf']) # index files and wildcard globs, each file is a document
```

To use a Flow, open it using the `with` context manager, like you would a file in Python. You can call `index` and `search` with nearly all types of data. The whole data stream is asynchronous and efficient.

#### Fetch Result
<a href="https://mybinder.org/v2/gh/jina-ai/jupyter-notebooks/main?filepath=basic-fetch-result.ipynb"><img align="right" src="https://github.com/jina-ai/jina/blob/master/.github/badges/run-badge.svg?raw=true"/></a>

Expand Down Expand Up @@ -345,6 +348,19 @@ with AsyncFlow().add() as f:
await f.index_ndarray(numpy.random.random([5, 4]), on_done=print)
```

#### Asynchronous Input

`AsyncFlow`'s CRUD operations support async generator as the input function. This is particular useful when your data sources involves other asynchronous libraries (e.g. motor for Mongodb):

```python
async def input_fn():
for _ in range(10):
yield Document()
await asyncio.sleep(0.1)
with AsyncFlow().add() as f:
await f.index(input_fn)
```

That's all you need to know for understanding the magic behind `hello-world`. Now let's dive into it!

Expand Down
3 changes: 2 additions & 1 deletion extra-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ pydantic: http, devel, test, daemon
python-multipart: http, devel, test, daemon
aiofiles: devel, cicd, http, test, daemon
pytest-custom_exit_code: cicd, test
bs4: test
bs4: test
aiostream: devel, cicd
3 changes: 2 additions & 1 deletion jina/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .base import BaseClient, CallbackFnType, InputFnType
from .helper import callback_exec
from .request import GeneratorSourceType
from .websockets import WebSocketClientMixin
from .websocket import WebSocketClientMixin
from ..enums import RequestType
from ..helper import run_async, deprecated_alias

Expand Down Expand Up @@ -49,6 +49,7 @@ def search(self, input_fn: InputFnType = None,
:return:
"""
self.mode = RequestType.SEARCH
self.add_default_kwargs(kwargs)
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
Expand Down
3 changes: 2 additions & 1 deletion jina/clients/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import InputFnType, BaseClient, CallbackFnType
from .websockets import WebSocketClientMixin
from .websocket import WebSocketClientMixin
from ..enums import RequestType
from ..helper import deprecated_alias

Expand Down Expand Up @@ -79,6 +79,7 @@ async def search(self, input_fn: InputFnType = None,
:return:
"""
self.mode = RequestType.SEARCH
self.add_default_kwargs(kwargs)
return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
Expand Down
41 changes: 33 additions & 8 deletions jina/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import argparse
import os
from typing import Callable, Union, Optional, Iterator, List
from typing import Callable, Union, Optional, Iterator, List, Dict, AsyncIterator

import grpc

from . import request
import inspect
from .helper import callback_exec
from .request import GeneratorSourceType
from ..enums import RequestType
Expand Down Expand Up @@ -70,8 +69,12 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None:

kwargs['data'] = input_fn

if inspect.isasyncgenfunction(input_fn) or inspect.isasyncgen(input_fn):
raise NotImplementedError('checking the validity of an async generator is not implemented yet')

try:
r = next(getattr(request, 'index')(**kwargs))
from .request import request_generator
r = next(request_generator(**kwargs))
if isinstance(r, Request):
default_logger.success(f'input_fn is valid')
else:
Expand All @@ -80,18 +83,25 @@ def check_input(input_fn: Optional[InputFnType] = None, **kwargs) -> None:
default_logger.error(f'input_fn is not valid!')
raise BadClientInput from ex

def _get_requests(self, **kwargs) -> Iterator['Request']:
def _get_requests(self, **kwargs) -> Union[Iterator['Request'], AsyncIterator['Request']]:
"""Get request in generator"""
_kwargs = vars(self.args)
_kwargs['data'] = self.input_fn
# override by the caller-specific kwargs
_kwargs.update(kwargs)

if inspect.isasyncgen(self.input_fn):
from .request.asyncio import request_generator
return request_generator(**_kwargs)
else:
from .request import request_generator
return request_generator(**_kwargs)

def _get_task_name(self, kwargs: Dict) -> str:
tname = str(self.mode).lower()
if 'mode' in kwargs:
tname = str(kwargs['mode']).lower()

return getattr(request, tname)(**_kwargs), tname
return tname

@property
def input_fn(self) -> InputFnType:
Expand All @@ -118,7 +128,8 @@ async def _get_results(self,
result = [] # type: List['Response']
try:
self.input_fn = input_fn
req_iter, tname = self._get_requests(**kwargs)
tname = self._get_task_name(kwargs)
req_iter = self._get_requests(**kwargs)
async with grpc.aio.insecure_channel(f'{self.args.host}:{self.args.port_expose}',
options=[('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1)]) as channel:
Expand Down Expand Up @@ -167,3 +178,17 @@ def search(self):

def train(self):
raise NotImplementedError

@staticmethod
def add_default_kwargs(kwargs: Dict):
# TODO: refactor it into load from config file
if ('top_k' in kwargs) and (kwargs['top_k'] is not None):
# associate all VectorSearchDriver and SliceQL driver to use top_k
from jina import QueryLang
topk_ql = [QueryLang({'name': 'SliceQL', 'priority': 1, 'parameters': {'end': kwargs['top_k']}}),
QueryLang(
{'name': 'VectorSearchDriver', 'priority': 1, 'parameters': {'top_k': kwargs['top_k']}})]
if 'queryset' not in kwargs:
kwargs['queryset'] = topk_ql
else:
kwargs['queryset'].extend(topk_ql)
126 changes: 0 additions & 126 deletions jina/clients/request.py

This file was deleted.

44 changes: 44 additions & 0 deletions jina/clients/request/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from typing import Iterator, Union, Tuple, AsyncIterator

from .helper import _new_request_from_batch
from ... import Request
from ...enums import RequestType, DataInputType
from ...helper import batch_iterator
from ...logging import default_logger
from ...types.document import DocumentSourceType, DocumentContentType
from ...types.sets.querylang import AcceptQueryLangType

SingletonDataType = Union[DocumentContentType,
DocumentSourceType,
Tuple[DocumentContentType, DocumentContentType],
Tuple[DocumentSourceType, DocumentSourceType]]

GeneratorSourceType = Union[Iterator[SingletonDataType], AsyncIterator[SingletonDataType]]


def request_generator(data: GeneratorSourceType,
request_size: int = 0,
mode: RequestType = RequestType.INDEX,
mime_type: str = None,
queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
data_type: DataInputType = DataInputType.AUTO,
**kwargs # do not remove this, add on purpose to suppress unknown kwargs
) -> Iterator['Request']:
"""
:param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`;
or an interator over possible Document content (set to text, blob and buffer).
:return:
"""

_kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)

try:
for batch in batch_iterator(data, request_size):
yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset)

except Exception as ex:
# must be handled here, as grpc channel wont handle Python exception
default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True)
34 changes: 34 additions & 0 deletions jina/clients/request/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from typing import Iterator, Union, AsyncIterator

from .helper import _new_request_from_batch
from .. import GeneratorSourceType
from ... import Request
from ...enums import RequestType, DataInputType
from ...importer import ImportExtensions
from ...logging import default_logger
from ...types.sets.querylang import AcceptQueryLangType


async def request_generator(data: GeneratorSourceType,
request_size: int = 0,
mode: RequestType = RequestType.INDEX,
mime_type: str = None,
queryset: Union[AcceptQueryLangType, Iterator[AcceptQueryLangType]] = None,
data_type: DataInputType = DataInputType.AUTO,
**kwargs # do not remove this, add on purpose to suppress unknown kwargs
) -> AsyncIterator['Request']:

_kwargs = dict(mime_type=mime_type, length=request_size, weight=1.0)

try:
with ImportExtensions(required=True):
import aiostream

async for batch in aiostream.stream.chunks(data, request_size):
yield _new_request_from_batch(_kwargs, batch, data_type, mode, queryset)
except Exception as ex:
# must be handled here, as grpc channel wont handle Python exception
default_logger.critical(f'input_fn is not valid! {ex!r}', exc_info=True)
Loading

0 comments on commit 72aa907

Please sign in to comment.