Skip to content

Commit

Permalink
fix(helper): fix yaml path check (#1850)
Browse files Browse the repository at this point in the history
* fix(helper): fix yaml path check
  • Loading branch information
hanxiao authored Feb 3, 2021
1 parent 8ca3ba9 commit 5c68831
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
5 changes: 5 additions & 0 deletions jina/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,8 @@ def change_env(key, val):
os.environ[key] = old_var
else:
os.environ.pop(key)


def is_yaml_filepath(val) -> bool:
r = r'^[/\w\-\_\.]+.ya?ml$'
return re.match(r, val.strip()) is not None
5 changes: 3 additions & 2 deletions jina/jaml/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from yaml.scanner import Scanner

from jina.excepts import BadConfigSource
from jina.helper import is_yaml_filepath
from jina.importer import PathImporter


Expand Down Expand Up @@ -102,7 +103,7 @@ def parse_config_source(path: Union[str, TextIO, Dict],
elif allow_stream and hasattr(path, 'read'):
# already a readable stream
return path, None
elif allow_yaml_file and (path.rstrip().endswith('.yml') or path.rstrip().endswith('.yaml')):
elif allow_yaml_file and is_yaml_filepath(path):
comp_path = complete_path(path)
return open(comp_path, encoding='utf8'), comp_path
elif allow_builtin_resource and path.lstrip().startswith('_') and os.path.exists(
Expand Down Expand Up @@ -133,7 +134,7 @@ def parse_config_source(path: Union[str, TextIO, Dict],
tmp = JAML.dump(tmp)
return io.StringIO(tmp), None
except json.JSONDecodeError:
raise BadConfigSource
raise BadConfigSource(path)
else:
raise BadConfigSource(f'{path} can not be resolved, it should be a readable stream,'
' or a valid file path, or a supported class name.')
Expand Down
6 changes: 3 additions & 3 deletions jina/peapods/runtimes/jinad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..asyncio.base import AsyncZMQRuntime
from ...zmq import Zmqlet
from ....excepts import DaemonConnectivityError
from ....helper import cached_property, colored
from ....helper import cached_property, colored, is_yaml_filepath


class JinadRuntime(AsyncZMQRuntime):
Expand Down Expand Up @@ -59,10 +59,10 @@ def teardown(self):
def _remote_id(self) -> Optional[str]:
if self.api.is_alive:
upload_files = []
if self.args.uses.endswith('.yml') or self.args.uses.endswith('.yaml'):
if is_yaml_filepath(self.args.uses):
upload_files.append(self.args.uses)

if self.args.uses_internal.endswith('.yml') or self.args.uses_internal.endswith('.yaml'):
if is_yaml_filepath(self.args.uses_internal):
upload_files.append(self.args.uses_internal)

if self.args.upload_files:
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jina.clients.helper import _safe_callback, pprint_routes
from jina.drivers.querylang.queryset.dunderkey import dunder_get
from jina.excepts import BadClientCallback, NotSupportedError
from jina.helper import cached_property, convert_tuple_to_list, deprecated_alias
from jina.helper import cached_property, convert_tuple_to_list, deprecated_alias, is_yaml_filepath
from jina.jaml.helper import complete_path
from jina.logging import default_logger
from jina.logging.profile import TimeContext
Expand Down Expand Up @@ -209,3 +209,32 @@ def dummy(bar, foo):
# deprecated HARD
with pytest.raises(NotSupportedError):
dummy(bar=1, foofoo=2)


@pytest.mark.parametrize('val', ['merge_and_topk.yml',
'merge_and_topk.yaml',
'da.yaml',
'd.yml',
'/da/da.yml',
'das/das.yaml',
'1234.yml',
'1234.yml ',
' 1234.yml '])
def test_yaml_filepath_validate_good(val):
assert is_yaml_filepath(val)


@pytest.mark.parametrize('val', [' .yml',
'a',
' uses: yaml',
'ayaml',
'''
shards: $JINA_SHARDS_INDEXERS
host: $JINA_REDIS_INDEXER_HOST
port_expose: 8000
polling: all
timeout_ready: 100000 # larger timeout as in query time will read all the data
uses_after: merge_and_topk.yml
'''])
def test_yaml_filepath_validate_bad(val):
assert not is_yaml_filepath(val)

0 comments on commit 5c68831

Please sign in to comment.