Skip to content

Commit

Permalink
fix(helper): improve deprecated decor (#1761)
Browse files Browse the repository at this point in the history
* fix(helper): improve deprecated decor
  • Loading branch information
hanxiao authored Jan 22, 2021
1 parent 60f80b2 commit 780d626
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 28 deletions.
7 changes: 6 additions & 1 deletion jina/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from .helper import callback_exec
from .request import GeneratorSourceType
from ..enums import RequestType
from ..helper import run_async
from ..helper import run_async, deprecated_alias


class Client(BaseClient):
"""A simple Python client for connecting to the gRPC gateway.
It manages the asyncio eventloop internally, so all interfaces are synchronous from the outside.
"""

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def train(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -32,6 +33,7 @@ def train(self, input_fn: InputFnType = None,
self.mode = RequestType.TRAIN
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def search(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -49,6 +51,7 @@ def search(self, input_fn: InputFnType = None,
self.mode = RequestType.SEARCH
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def index(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -66,6 +69,7 @@ def index(self, input_fn: InputFnType = None,
self.mode = RequestType.INDEX
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def update(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -83,6 +87,7 @@ def update(self, input_fn: InputFnType = None,
self.mode = RequestType.UPDATE
return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def delete(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down
4 changes: 4 additions & 0 deletions jina/clients/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import InputFnType, BaseClient, CallbackFnType
from .websockets import WebSocketClientMixin
from ..enums import RequestType
from ..helper import deprecated_alias


class AsyncClient(BaseClient):
Expand Down Expand Up @@ -44,6 +45,7 @@ async def concurrent_main():
One can think of :class:`Client` as Jina-managed eventloop, whereas :class:`AsyncClient` is self-managed eventloop.
"""

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def train(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -61,6 +63,7 @@ async def train(self, input_fn: InputFnType = None,
self.mode = RequestType.TRAIN
return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def search(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -78,6 +81,7 @@ async def search(self, input_fn: InputFnType = None,
self.mode = RequestType.SEARCH
return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def index(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down
4 changes: 4 additions & 0 deletions jina/excepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,7 @@ class HubLoginRequired(Exception):

class DaemonConnectivityError(Exception):
""" Exception to raise when jina daemon is not connectable"""


class NotSupportedError(Exception):
""" Exeception when user accidentally using a retired argument """
25 changes: 9 additions & 16 deletions jina/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .base import BaseFlow
from ..clients.base import InputFnType, CallbackFnType
from ..enums import DataInputType
from ..helper import deprecated_alias


class Flow(BaseFlow):
Expand Down Expand Up @@ -42,6 +43,7 @@ def my_reader():
"""
return self._get_client(**kwargs).train(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def index_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, shuffle: bool = False,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -60,6 +62,7 @@ def index_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, sh
return self._get_client(**kwargs).index(_input_ndarray(array, axis, size, shuffle),
on_done, on_error, on_always, data_type=DataInputType.CONTENT, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def search_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, shuffle: bool = False,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -80,6 +83,7 @@ def search_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, s
self._get_client(**kwargs).search(_input_ndarray(array, axis, size, shuffle),
on_done, on_error, on_always, data_type=DataInputType.CONTENT, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def index_lines(self, lines: Iterator[str] = None, filepath: str = None, size: int = None,
sampling_rate: float = None, read_mode='r',
on_done: CallbackFnType = None,
Expand All @@ -102,6 +106,7 @@ def index_lines(self, lines: Iterator[str] = None, filepath: str = None, size: i
return self._get_client(**kwargs).index(_input_lines(lines, filepath, size, sampling_rate, read_mode),
on_done, on_error, on_always, data_type=DataInputType.AUTO, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def index_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
on_done: CallbackFnType = None,
Expand All @@ -125,6 +130,7 @@ def index_files(self, patterns: Union[str, List[str]], recursive: bool = True,
return self._get_client(**kwargs).index(_input_files(patterns, recursive, size, sampling_rate, read_mode),
on_done, on_error, on_always, data_type=DataInputType.CONTENT, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def search_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
on_done: CallbackFnType = None,
Expand All @@ -149,6 +155,7 @@ def search_files(self, patterns: Union[str, List[str]], recursive: bool = True,
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def search_lines(self, filepath: str = None, lines: Iterator[str] = None, size: int = None,
sampling_rate: float = None, read_mode='r',
on_done: CallbackFnType = None,
Expand All @@ -171,6 +178,7 @@ def search_lines(self, filepath: str = None, lines: Iterator[str] = None, size:
return self._get_client(**kwargs).search(_input_lines(lines, filepath, size, sampling_rate, read_mode),
on_done, on_error, on_always, data_type=DataInputType.AUTO, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def index(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down Expand Up @@ -264,6 +272,7 @@ def my_reader():
"""
self._get_client(**kwargs).delete(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
def search(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down Expand Up @@ -294,19 +303,3 @@ def my_reader():
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
return self._get_client(**kwargs).search(input_fn, on_done, on_error, on_always, **kwargs)

@property
def workspace_id(self) -> Dict[str, str]:
"""Get all Pods' ``workspace_id`` values in a dict """
return {k: p.args.workspace_id for k, p in self if hasattr(p.args, 'workspace_id')}

@workspace_id.setter
def workspace_id(self, value: str):
"""Set all Pods' ``workspace_id`` to ``value``
:param value: a hexadecimal UUID string
"""
uuid.UUID(value)
for k, p in self:
if hasattr(p.args, 'workspace_id'):
p.args.workspace_id = value
14 changes: 12 additions & 2 deletions jina/flow/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _update_client(self):
if self._pod_nodes['gateway'].args.restful:
self._cls_client = AsyncWebSocketClient

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def train(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down Expand Up @@ -118,6 +119,7 @@ def my_reader():
"""
return await self._get_client(**kwargs).train(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def index_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, shuffle: bool = False,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -139,6 +141,7 @@ async def index_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = No
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def search_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = None, shuffle: bool = False,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand All @@ -156,9 +159,11 @@ async def search_ndarray(self, array: 'np.ndarray', axis: int = 0, size: int = N
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
from ..clients.sugary_io import _input_ndarray
await self._get_client(**kwargs).search(_input_ndarray(array, axis, size, shuffle),
on_done, on_error, on_always, data_type=DataInputType.CONTENT, **kwargs)
return await self._get_client(**kwargs).search(_input_ndarray(array, axis, size, shuffle),
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def index_lines(self, lines: Iterator[str] = None, filepath: str = None, size: int = None,
sampling_rate: float = None, read_mode='r',
on_done: CallbackFnType = None,
Expand All @@ -183,6 +188,7 @@ async def index_lines(self, lines: Iterator[str] = None, filepath: str = None, s
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def index_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
on_done: CallbackFnType = None,
Expand All @@ -208,6 +214,7 @@ async def index_files(self, patterns: Union[str, List[str]], recursive: bool = T
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def search_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
on_done: CallbackFnType = None,
Expand All @@ -233,6 +240,7 @@ async def search_files(self, patterns: Union[str, List[str]], recursive: bool =
_input_files(patterns, recursive, size, sampling_rate, read_mode),
on_done, on_error, on_always, data_type=DataInputType.CONTENT, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def search_lines(self, filepath: str = None, lines: Iterator[str] = None, size: int = None,
sampling_rate: float = None, read_mode='r',
on_done: CallbackFnType = None,
Expand All @@ -257,6 +265,7 @@ async def search_lines(self, filepath: str = None, lines: Iterator[str] = None,
on_done, on_error, on_always, data_type=DataInputType.CONTENT,
**kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def index(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down Expand Up @@ -300,6 +309,7 @@ def my_reader():
"""
return await self._get_client(**kwargs).index(input_fn, on_done, on_error, on_always, **kwargs)

@deprecated_alias(buffer=('input_fn', 1), callback=('on_done', 1), output_fn=('on_done', 1))
async def search(self, input_fn: InputFnType = None,
on_done: CallbackFnType = None,
on_error: CallbackFnType = None,
Expand Down
17 changes: 17 additions & 0 deletions jina/flow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import re
import threading
import uuid
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from typing import Optional, Union, Tuple, List, Set, Dict, TextIO, TypeVar
Expand Down Expand Up @@ -733,6 +734,22 @@ def _update_client(self):
if self._pod_nodes['gateway'].args.restful:
self._cls_client = WebSocketClient

@property
def workspace_id(self) -> Dict[str, str]:
"""Get all Pods' ``workspace_id`` values in a dict """
return {k: p.args.workspace_id for k, p in self if hasattr(p.args, 'workspace_id')}

@workspace_id.setter
def workspace_id(self, value: str):
"""Set all Pods' ``workspace_id`` to ``value``
:param value: a hexadecimal UUID string
"""
uuid.UUID(value)
for k, p in self:
if hasattr(p.args, 'workspace_id'):
p.args.workspace_id = value

def index(self):
raise NotImplementedError

Expand Down
34 changes: 27 additions & 7 deletions jina/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,37 @@
'typename', 'get_public_ip', 'get_internal_ip', 'convert_tuple_to_list',
'run_async', 'deprecated_alias']

from jina.excepts import NotSupportedError


def deprecated_alias(**aliases):
"""
Usage, kwargs with key as the deprecated arg name and value be a tuple,
(new_name, deprecate_level). With level 0 means warning, level 1 means exception
Example:
@deprecated_alias(buffer=('input_fn', 0), callback=('on_done', 1), output_fn=('on_done', 1))
"""

def rename_kwargs(func_name: str, kwargs, aliases):
for alias, new in aliases.items():
for alias, new_arg in aliases.items():
if not isinstance(new_arg, tuple):
raise ValueError(f'{new_arg} must be a tuple, with first element as the new name, '
f'second element as the deprecated level: 0 as warning, 1 as exception')
if alias in kwargs:
if new in kwargs:
raise TypeError(f'{func_name} received both {alias} and {new}')
warnings.warn(
f'"{alias}" is renamed to {new}" in "{func_name}()" '
f'and "{alias}" will be removed in the next version', DeprecationWarning)
kwargs[new] = kwargs.pop(alias)
new_name, dep_level = new_arg
if new_name in kwargs:
raise NotSupportedError(f'{func_name} received both {alias} and {new_name}')

if dep_level == 0:
warnings.warn(
f'`{alias}` is renamed to `{new_name}` in `{func_name}()`, the usage of `{alias}` is '
f'deprecated and will be removed in the next version.',
DeprecationWarning)
kwargs[new_name] = kwargs.pop(alias)
elif dep_level == 1:
raise NotSupportedError(f'{alias} has been renamed to `{new_name}`')

def deco(f):
@functools.wraps(f)
Expand Down
21 changes: 19 additions & 2 deletions tests/unit/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from jina import NdArray, Request
from jina.clients.helper import _safe_callback, pprint_routes
from jina.drivers.querylang.queryset.dunderkey import dunder_get
from jina.excepts import BadClientCallback
from jina.helper import cached_property, convert_tuple_to_list
from jina.excepts import BadClientCallback, NotSupportedError
from jina.helper import cached_property, convert_tuple_to_list, deprecated_alias
from jina.jaml.helper import complete_path
from jina.logging import default_logger
from jina.logging.profile import TimeContext
Expand Down Expand Up @@ -192,3 +192,20 @@ def test_complete_path_success():
def test_complete_path_not_found():
with pytest.raises(FileNotFoundError):
assert complete_path('unknown.yaml')


def test_deprecated_decor():
@deprecated_alias(barbar=('bar', 0), foofoo=('foo', 1))
def dummy(bar, foo):
return bar, foo

# normal
assert dummy(bar=1, foo=2) == (1, 2)

# deprecated warn
with pytest.deprecated_call():
assert dummy(barbar=1, foo=2) == (1, 2)

# deprecated HARD
with pytest.raises(NotSupportedError):
dummy(bar=1, foofoo=2)

0 comments on commit 780d626

Please sign in to comment.