diff --git a/jina/jaml/__init__.py b/jina/jaml/__init__.py index 0347db7c4cd16..2b2967efc20a1 100644 --- a/jina/jaml/__init__.py +++ b/jina/jaml/__init__.py @@ -91,7 +91,7 @@ def load(stream, def load_no_tags(stream, **kwargs): """Load yaml object but ignore all customized tags, e.g. !Executor, !Driver, !Flow """ - safe_yml = '\n'.join(v if not re.match(r'^[\s-]*?!\b', v) else v.replace('!', '__tag: ') for v in stream) + safe_yml = '\n'.join(v if not re.match(r'^[\s-]*?!\b', v) else v.replace('!', '__cls: ') for v in stream) return JAML.load(safe_yml, **kwargs) @staticmethod @@ -321,7 +321,7 @@ def save_config(self, filename: Optional[str] = None): @classmethod def load_config(cls, - source: Union[str, TextIO], *, + source: Union[str, TextIO, Dict], *, allow_py_modules: bool = True, substitute: bool = True, context: Dict[str, Any] = None, @@ -390,7 +390,7 @@ def load_config(cls, load_py_modules(no_tag_yml, extra_search_paths=(os.path.dirname(s_path),) if s_path else None) # revert yaml's tag and load again, this time with substitution - revert_tag_yml = JAML.dump(no_tag_yml).replace('__tag: ', '!') + revert_tag_yml = JAML.dump(no_tag_yml).replace('__cls: ', '!') # load into object, no more substitute return JAML.load(revert_tag_yml, substitute=False) diff --git a/jina/jaml/helper.py b/jina/jaml/helper.py index 841215e1f1aad..4054920dae3ea 100644 --- a/jina/jaml/helper.py +++ b/jina/jaml/helper.py @@ -10,6 +10,7 @@ from yaml.resolver import Resolver from yaml.scanner import Scanner +import json from jina.excepts import BadConfigSource from jina.importer import PathImporter @@ -75,22 +76,29 @@ def __init__(self, stream): JinaResolver.yaml_implicit_resolvers.pop('O') -def parse_config_source(path: Union[str, TextIO], +def parse_config_source(path: Union[str, TextIO, Dict], allow_stream: bool = True, allow_yaml_file: bool = True, allow_builtin_resource: bool = True, allow_raw_yaml_content: bool = True, allow_raw_driver_yaml_content: bool = True, - allow_class_type: bool = True, *args, **kwargs) -> Tuple[TextIO, Optional[str]]: + allow_class_type: bool = True, + allow_dict: bool = True, + allow_json: bool = True, + *args, **kwargs) -> Tuple[TextIO, Optional[str]]: """ Check if the text or text stream is valid :return: a tuple, the first element is the text stream, the second element is the file path associate to it if available. """ import io - from pkg_resources import resource_filename, resource_stream + from pkg_resources import resource_filename if not path: raise BadConfigSource + elif allow_dict and isinstance(path, dict): + from . import JAML + tmp = JAML.dump(path) + return io.StringIO(tmp), None elif allow_stream and hasattr(path, 'read'): # already a readable stream return path, None @@ -117,6 +125,14 @@ def parse_config_source(path: Union[str, TextIO], elif allow_class_type and path.isidentifier(): # possible class name return io.StringIO(f'!{path}'), None + elif allow_json and isinstance(path, str): + try: + from . import JAML + tmp = json.loads(path) + tmp = JAML.dump(tmp) + return io.StringIO(tmp), None + except json.JSONDecodeError: + raise BadConfigSource else: raise BadConfigSource(f'{path} can not be resolved, it should be a readable stream,' ' or a valid file path, or a supported class name.') diff --git a/tests/unit/flow/test_flow_yaml_parser.py b/tests/unit/flow/test_flow_yaml_parser.py index ef790e6062e06..4929a94e8d4dc 100644 --- a/tests/unit/flow/test_flow_yaml_parser.py +++ b/tests/unit/flow/test_flow_yaml_parser.py @@ -6,6 +6,7 @@ from jina import Flow, AsyncFlow from jina.enums import FlowOptimizeLevel from jina.excepts import BadFlowYAMLVersion +from jina.executors.encoders import BaseEncoder from jina.flow import BaseFlow from jina.jaml import JAML from jina.jaml.parsers import get_supported_versions @@ -78,6 +79,7 @@ def test_load_flow_from_yaml(): with open(cur_dir.parent / 'yaml' / 'test-flow.yml') as fp: a = Flow.load_config(fp) + def test_flow_yaml_dump(): f = Flow(logserver_config=str(cur_dir.parent / 'yaml' / 'test-server-config.yml'), optimize_level=FlowOptimizeLevel.IGNORE_GATEWAY, @@ -88,3 +90,27 @@ def test_flow_yaml_dump(): assert f.args.logserver_config == fl.args.logserver_config assert f.args.optimize_level == fl.args.optimize_level rm_files(['test1.yml']) + + +def test_flow_yaml_from_string(): + f1 = Flow.load_config('yaml/flow-v1.0-syntax.yml') + with open(str(cur_dir / 'yaml' / 'flow-v1.0-syntax.yml')) as fp: + str_yaml = fp.read() + assert isinstance(str_yaml, str) + f2 = Flow.load_config(str_yaml) + assert f1 == f2 + + f3 = Flow.load_config('!Flow\nversion: 1.0\npods: [{name: ppp0, uses: _merge}, name: aaa1]') + assert 'ppp0' in f3._pod_nodes.keys() + assert 'aaa1' in f3._pod_nodes.keys() + assert f3.num_pods == 2 + + +def test_flow_uses_from_dict(): + class DummyEncoder(BaseEncoder): + pass + + d1 = {'__cls': 'DummyEncoder', + 'metas': {'name': 'dummy1'}} + with Flow().add(uses=d1): + pass diff --git a/tests/unit/test_yamlparser.py b/tests/unit/test_yamlparser.py index b7399cceba11a..8039b81648df4 100644 --- a/tests/unit/test_yamlparser.py +++ b/tests/unit/test_yamlparser.py @@ -6,6 +6,7 @@ from jina.enums import SocketType from jina.executors import BaseExecutor +from jina.executors.compound import CompoundExecutor from jina.executors.indexers.vector import NumpyIndexer from jina.executors.metas import fill_metas_with_defaults from jina.helper import expand_dict @@ -178,3 +179,59 @@ def test_encoder_name_dict_replace(): def test_encoder_inject_config_via_kwargs(): with BaseExecutor.load_config('yaml/test-encoder-env.yml', pea_id=345) as be: assert be.pea_id == 345 + + +def test_load_from_dict(): + # !BaseEncoder + # metas: + # name: ${{BE_TEST_NAME}} + # batch_size: ${{BATCH_SIZE}} + # pea_id: ${{pea_id}} + # workspace: ${{this.name}}-${{this.batch_size}} + + d1 = { + '__cls': 'BaseEncoder', + 'metas': {'name': '${{BE_TEST_NAME}}', + 'batch_size': '${{BATCH_SIZE}}', + 'pea_id': '${{pea_id}}', + 'workspace': '${{this.name}} -${{this.batch_size}}'} + } + + # !CompoundExecutor + # components: + # - !BinaryPbIndexer + # with: + # index_filename: tmp1 + # metas: + # name: test1 + # - !BinaryPbIndexer + # with: + # index_filename: tmp2 + # metas: + # name: test2 + # metas: + # name: compound1 + + d2 = { + '__cls': 'CompoundExecutor', + 'components': + [ + { + '__cls': 'BinaryPbIndexer', + 'with': {'index_filename': 'tmp1'}, + 'metas': {'name': 'test1'} + }, + { + '__cls': 'BinaryPbIndexer', + 'with': {'index_filename': 'tmp2'}, + 'metas': {'name': 'test2'} + }, + ] + } + d = {'BE_TEST_NAME': 'hello123', 'BATCH_SIZE': 256} + b1 = BaseExecutor.load_config(d1, context=d) + b2 = BaseExecutor.load_config(d2, context=d) + assert isinstance(b1, BaseExecutor) + assert isinstance(b2, CompoundExecutor) + assert b1.batch_size == 256 + assert b1.name == 'hello123'