Skip to content

Commit

Permalink
feat(jaml): enables loading config from dict (#1581)
Browse files Browse the repository at this point in the history
* test: show 1517 is doable already

* feat(jaml): allows reading from dict #1505
  • Loading branch information
hanxiao authored Jan 3, 2021
1 parent 7d42f28 commit 8ac6a63
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
6 changes: 3 additions & 3 deletions jina/jaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def load(stream,
def load_no_tags(stream, **kwargs):
    """Load a YAML object but ignore all customized tags, e.g. !Executor, !Driver, !Flow.

    Each line of ``stream`` that begins (after optional whitespace or dashes) with a
    ``!`` tag has every ``!`` on that line replaced by the plain-key prefix ``__cls: ``.
    The rewritten lines are joined with newlines and handed to ``JAML.load``.

    :param stream: an iterable yielding the YAML source line by line
    :param kwargs: extra keyword arguments forwarded unchanged to ``JAML.load``
    :return: the result of ``JAML.load`` on the tag-free YAML text
    """
    # ``stream`` is consumed exactly once. A second pass over a file-like stream
    # would see no lines, so the text is built a single time with the ``__cls: ``
    # marker (the one ``load_config`` later turns back into ``!``).
    safe_yml = '\n'.join(
        v if not re.match(r'^[\s-]*?!\b', v) else v.replace('!', '__cls: ')
        for v in stream
    )
    return JAML.load(safe_yml, **kwargs)

@staticmethod
Expand Down Expand Up @@ -321,7 +321,7 @@ def save_config(self, filename: Optional[str] = None):

@classmethod
def load_config(cls,
source: Union[str, TextIO], *,
source: Union[str, TextIO, Dict], *,
allow_py_modules: bool = True,
substitute: bool = True,
context: Dict[str, Any] = None,
Expand Down Expand Up @@ -390,7 +390,7 @@ def load_config(cls,
load_py_modules(no_tag_yml, extra_search_paths=(os.path.dirname(s_path),) if s_path else None)

# revert yaml's tag and load again, this time with substitution
revert_tag_yml = JAML.dump(no_tag_yml).replace('__tag: ', '!')
revert_tag_yml = JAML.dump(no_tag_yml).replace('__cls: ', '!')

# load into object, no more substitute
return JAML.load(revert_tag_yml, substitute=False)
Expand Down
22 changes: 19 additions & 3 deletions jina/jaml/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from yaml.resolver import Resolver
from yaml.scanner import Scanner

import json
from jina.excepts import BadConfigSource
from jina.importer import PathImporter

Expand Down Expand Up @@ -75,22 +76,29 @@ def __init__(self, stream):
JinaResolver.yaml_implicit_resolvers.pop('O')


def parse_config_source(path: Union[str, TextIO],
def parse_config_source(path: Union[str, TextIO, Dict],
allow_stream: bool = True,
allow_yaml_file: bool = True,
allow_builtin_resource: bool = True,
allow_raw_yaml_content: bool = True,
allow_raw_driver_yaml_content: bool = True,
allow_class_type: bool = True, *args, **kwargs) -> Tuple[TextIO, Optional[str]]:
allow_class_type: bool = True,
allow_dict: bool = True,
allow_json: bool = True,
*args, **kwargs) -> Tuple[TextIO, Optional[str]]:
""" Check if the text or text stream is valid
:return: a tuple, the first element is the text stream, the second element is the file path associate to it
if available.
"""
import io
from pkg_resources import resource_filename, resource_stream
from pkg_resources import resource_filename
if not path:
raise BadConfigSource
elif allow_dict and isinstance(path, dict):
from . import JAML
tmp = JAML.dump(path)
return io.StringIO(tmp), None
elif allow_stream and hasattr(path, 'read'):
# already a readable stream
return path, None
Expand All @@ -117,6 +125,14 @@ def parse_config_source(path: Union[str, TextIO],
elif allow_class_type and path.isidentifier():
# possible class name
return io.StringIO(f'!{path}'), None
elif allow_json and isinstance(path, str):
try:
from . import JAML
tmp = json.loads(path)
tmp = JAML.dump(tmp)
return io.StringIO(tmp), None
except json.JSONDecodeError:
raise BadConfigSource
else:
raise BadConfigSource(f'{path} can not be resolved, it should be a readable stream,'
' or a valid file path, or a supported class name.')
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/flow/test_flow_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jina import Flow, AsyncFlow
from jina.enums import FlowOptimizeLevel
from jina.excepts import BadFlowYAMLVersion
from jina.executors.encoders import BaseEncoder
from jina.flow import BaseFlow
from jina.jaml import JAML
from jina.jaml.parsers import get_supported_versions
Expand Down Expand Up @@ -78,6 +79,7 @@ def test_load_flow_from_yaml():
with open(cur_dir.parent / 'yaml' / 'test-flow.yml') as fp:
a = Flow.load_config(fp)


def test_flow_yaml_dump():
f = Flow(logserver_config=str(cur_dir.parent / 'yaml' / 'test-server-config.yml'),
optimize_level=FlowOptimizeLevel.IGNORE_GATEWAY,
Expand All @@ -88,3 +90,27 @@ def test_flow_yaml_dump():
assert f.args.logserver_config == fl.args.logserver_config
assert f.args.optimize_level == fl.args.optimize_level
rm_files(['test1.yml'])


def test_flow_yaml_from_string():
    """Load a Flow from raw YAML text and compare it with the file-loaded Flow."""
    flow_from_path = Flow.load_config('yaml/flow-v1.0-syntax.yml')

    yaml_path = str(cur_dir / 'yaml' / 'flow-v1.0-syntax.yml')
    with open(yaml_path) as yaml_file:
        raw_yaml = yaml_file.read()
        assert isinstance(raw_yaml, str)
        flow_from_text = Flow.load_config(raw_yaml)
        assert flow_from_path == flow_from_text

    # inline YAML content, given directly as a string
    inline_flow = Flow.load_config('!Flow\nversion: 1.0\npods: [{name: ppp0, uses: _merge}, name: aaa1]')
    for pod_name in ('ppp0', 'aaa1'):
        assert pod_name in inline_flow._pod_nodes.keys()
    assert inline_flow.num_pods == 2


def test_flow_uses_from_dict():
    """Add a pod whose ``uses`` is a plain dict and enter/exit the Flow context."""

    class DummyEncoder(BaseEncoder):
        pass

    # dict form of the executor config: the class name sits under the ``__cls`` key
    encoder_config = {
        '__cls': 'DummyEncoder',
        'metas': {'name': 'dummy1'},
    }
    flow = Flow().add(uses=encoder_config)
    with flow:
        pass
57 changes: 57 additions & 0 deletions tests/unit/test_yamlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jina.enums import SocketType
from jina.executors import BaseExecutor
from jina.executors.compound import CompoundExecutor
from jina.executors.indexers.vector import NumpyIndexer
from jina.executors.metas import fill_metas_with_defaults
from jina.helper import expand_dict
Expand Down Expand Up @@ -178,3 +179,59 @@ def test_encoder_name_dict_replace():
def test_encoder_inject_config_via_kwargs():
    """Check that ``pea_id`` passed to ``load_config`` is readable on the loaded executor."""
    expected_pea_id = 345
    with BaseExecutor.load_config('yaml/test-encoder-env.yml', pea_id=expected_pea_id) as executor:
        assert executor.pea_id == expected_pea_id


def test_load_from_dict():
    """Load executors from plain dicts instead of YAML text.

    The class of each executor is given under the ``__cls`` key, and
    ``${{...}}`` placeholders are filled from the ``context`` mapping.
    """
    # Dict counterpart of:
    #   !BaseEncoder
    #   metas:
    #     name: ${{BE_TEST_NAME}}
    #     batch_size: ${{BATCH_SIZE}}
    #     pea_id: ${{pea_id}}
    #     workspace: ${{this.name}} -${{this.batch_size}}
    encoder_metas = {
        'name': '${{BE_TEST_NAME}}',
        'batch_size': '${{BATCH_SIZE}}',
        'pea_id': '${{pea_id}}',
        'workspace': '${{this.name}} -${{this.batch_size}}',
    }
    encoder_spec = {'__cls': 'BaseEncoder', 'metas': encoder_metas}

    # Dict counterpart of:
    #   !CompoundExecutor
    #   components:
    #     - !BinaryPbIndexer
    #       with:
    #         index_filename: tmp1
    #       metas:
    #         name: test1
    #     - !BinaryPbIndexer
    #       with:
    #         index_filename: tmp2
    #       metas:
    #         name: test2
    indexer_specs = [
        {
            '__cls': 'BinaryPbIndexer',
            'with': {'index_filename': f'tmp{idx}'},
            'metas': {'name': f'test{idx}'},
        }
        for idx in (1, 2)
    ]
    compound_spec = {'__cls': 'CompoundExecutor', 'components': indexer_specs}

    substitutions = {'BE_TEST_NAME': 'hello123', 'BATCH_SIZE': 256}
    encoder = BaseExecutor.load_config(encoder_spec, context=substitutions)
    compound = BaseExecutor.load_config(compound_spec, context=substitutions)

    assert isinstance(encoder, BaseExecutor)
    assert isinstance(compound, CompoundExecutor)
    assert encoder.batch_size == 256
    assert encoder.name == 'hello123'

0 comments on commit 8ac6a63

Please sign in to comment.