Skip to content

Commit

Permalink
AST resolver now supports some local import cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa committed Oct 18, 2023
1 parent bbeb9d1 commit 4e23ac2
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Added
- Support for on parse argument links with target subclasses in a list (`#394
<https://github.com/omni-us/jsonargparse/issues/394>`__, `lightning#18161
<https://github.com/Lightning-AI/lightning/issues/18161>`__).
- AST resolver now supports some local import cases.

Fixed
^^^^^
Expand Down
5 changes: 5 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,11 @@ unrelated to these variables.
SomeClass.a_class_method(*args, **kwargs)

def calls_local_import(**kwargs):
import some_module
some_module.a_callable(**kwargs)

def pops_from_kwargs(**kwargs):
val = kwargs.pop("name", "default")
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def error(self, message: str, ex: Optional[Exception] = None) -> NoReturn:
elif debug_mode_active():
self._logger.debug("Debug enabled, thus raising exception instead of exit.")
raise argument_error(message) from ex
self.print_usage()
self.print_usage(sys.stderr)
sys.stderr.write(f"error: {message}\n")
self.exit(2)

Expand Down
65 changes: 51 additions & 14 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,11 @@ def visit_Assign(self, node):
do_generic_visit = True
for key, value in self.find_values.items():
if ast_is_assign_with_value(node, value):
self.values_found.append((key, node))
self.add_value(key, node)
do_generic_visit = False
break
elif ast_is_dict_assign_with_value(node, value):
self.values_found.append((key, node))
self.add_value(key, node)
do_generic_visit = False
if do_generic_visit:
if ast_is_dict_assign(node):
Expand All @@ -531,11 +531,11 @@ def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
value_dump = ast.dump(node.func.value)
if value_dump in self.dict_assigns:
self.values_found.append((key, self.dict_assigns[value_dump]))
self.add_value(key, self.dict_assigns[value_dump])
continue
self.values_found.append((key, node))
self.add_value(key, node)
elif ast_is_kwargs_pop_or_get(node, value_dump):
self.values_found.append((key, node))
self.add_value(key, node)
self.generic_visit(node)

def visit_If(self, node):
Expand All @@ -550,27 +550,64 @@ def visit_If(self, node):
node = ast.If(test=ast.Constant(value=True), body=body, orelse=[])
self.generic_visit(node)

def visit_Import(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
for alias in node.names:
self.import_names[alias.asname or alias.name] = node

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.visit_Import(node)

def add_value(self, key, node):
source = None
if isinstance(node, ast.Call):
name = False
if isinstance(node.func, ast.Name):
name = node.func.id
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
name = node.func.value.id
if name and name in self.import_names:
source = self.import_names[name]
self.values_found.append((key, node, source))

def find_values_usage(self, values):
self.find_values = values
self.values_found = []
self.dict_assigns = {}
self.import_names = {}
self.visit(self.component_node)
return self.values_found

def get_node_component(self, node) -> Optional[Tuple[Type, Optional[str]]]:
def get_component_from_source(self, name, source):
aliases = {}
ast_exec = ast.parse("")
ast_exec.body = [source]
try:
exec(compile(ast_exec, filename="<ast>", mode="exec"), aliases, aliases)
except Exception as ex:
if self.logger:
self.logger.debug(f"Failed to get '{name}' from '{ast_str(source)}'", exc_info=ex)
return aliases.get(name)

def get_node_component(self, node, source) -> Optional[Tuple[Type, Optional[str]]]:
function_or_class = method_or_property = None
module = inspect.getmodule(self.component)
if isinstance(node.func, ast.Name):
if is_classmethod(self.parent, self.component) and node.func.id == self.self_name:
function_or_class = self.parent
elif hasattr(module, node.func.id):
function_or_class = getattr(module, node.func.id)
elif source:
function_or_class = self.get_component_from_source(node.func.id, source)
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
if self.parent and ast.dump(node.func.value) == ast.dump(ast_variable_load(self.self_name)):
function_or_class = self.parent
method_or_property = node.func.attr
elif hasattr(module, node.func.value.id):
container = getattr(module, node.func.value.id)
else:
container = None
if hasattr(module, node.func.value.id):
container = getattr(module, node.func.value.id)
elif source:
container = self.get_component_from_source(node.func.value.id, source)
if inspect.isclass(container):
function_or_class = container
method_or_property = node.func.attr
Expand All @@ -581,7 +618,7 @@ def get_node_component(self, node) -> Optional[Tuple[Type, Optional[str]]]:
return None
return function_or_class, method_or_property

def match_call_that_uses_attr(self, node, attr_name):
def match_call_that_uses_attr(self, node, source, attr_name):
params = None
if isinstance(node, ast.Call):
params = []
Expand All @@ -591,7 +628,7 @@ def match_call_that_uses_attr(self, node, attr_name):
if kwarg.arg:
self.log_debug(f"kwargs attribute given as keyword parameter not supported: {ast_str(node)}")
else:
get_param_args = self.get_node_component(node)
get_param_args = self.get_node_component(node, source)
if get_param_args:
try:
params = get_signature_parameters(*get_param_args, logger=self.logger)
Expand Down Expand Up @@ -686,7 +723,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
params_list = []
kwargs_value = kwargs_name and values_to_find[kwargs_name]
kwargs_value_dump = kwargs_value and ast.dump(kwargs_value)
for node in [v for k, v in values_found if k == kwargs_name]:
for node, source in [(v, s) for k, v, s in values_found if k == kwargs_name]:
if isinstance(node, ast.Call):
if ast_is_kwargs_pop_or_get(node, kwargs_value_dump):
param = self.get_kwargs_pop_or_get_parameter(node, self.component, self.parent, self.doc_params)
Expand All @@ -704,7 +741,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
self.logger,
)
else:
get_param_args = self.get_node_component(node)
get_param_args = self.get_node_component(node, source)
if get_param_args:
params = get_signature_parameters(*get_param_args, logger=self.logger)
params = remove_given_parameters(node, params)
Expand Down Expand Up @@ -757,8 +794,8 @@ def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optio
values_found = self.find_values_usage(values_to_find)
matched = []
if values_found:
for _, node in values_found:
match = self.match_call_that_uses_attr(node, attr_name)
for _, node, source in values_found:
match = self.match_call_that_uses_attr(node, source, attr_name)
if match:
self.add_node_origins(match, node)
matched.append(match)
Expand Down
79 changes: 79 additions & 0 deletions jsonargparse_tests/test_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,56 @@ def start(self):
return self.fn(**self._kwd)


class AttributeLocalImport1:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader

return set_loader(**self._kwd)


class AttributeLocalImport2:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
import jsonargparse as ja

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note test

Module 'jsonargparse' is imported with both 'import' and 'import from'.

return ja.set_loader(**self._kwd)


class AttributeLocalImport3:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader

return set_loader(**self._kwd)


class AttributeLocalImport4:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader as sl

return sl(**self._kwd)


class AttributeLocalImportFailure:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import does_not_exist

return does_not_exist(**self._kwd)


class ClassF1:
def __init__(self, **kw):
self._ini = dict(k2=4)
Expand Down Expand Up @@ -421,6 +471,12 @@ def function_module_class(**kwds):
return calendar.Calendar(**kwds)


def function_local_import(**kwds):
from jsonargparse import set_loader

return set_loader(**kwds)


constant_boolean_1 = True
constant_boolean_2 = False

Expand Down Expand Up @@ -527,6 +583,24 @@ def test_get_params_class_with_kwargs_in_dict_attribute():
assert_params(get_params(ClassF1), [])


@pytest.mark.parametrize(
"cls",
[AttributeLocalImport1, AttributeLocalImport2, AttributeLocalImport3, AttributeLocalImport4],
)
def test_get_params_local_import_with_kwargs_in_dict_attribute(cls):
params = get_params(cls)
assert ["mode", "loader_fn", "exceptions"] == [p.name for p in params]
with source_unavailable():
assert get_params(cls) == []


def test_get_params_local_import_failure_with_kwargs_in_dict_attribute(logger):
with capture_logs(logger) as logs:
params = get_params(AttributeLocalImportFailure, logger=logger)
assert params == []
assert "Failed to get 'does_not_exist'" in logs.getvalue()


def test_get_params_class_kwargs_in_attr_method_conditioned_on_arg():
params = get_params(ClassG)
assert_params(
Expand Down Expand Up @@ -706,6 +780,11 @@ def test_get_params_function_module_class():
assert ["firstweekday"] == [p.name for p in params]


def test_get_params_function_local_import():
params = get_params(function_local_import)
assert ["mode", "loader_fn", "exceptions"] == [p.name for p in params]


def test_get_params_function_constant_boolean():
assert_params(get_params(function_constant_boolean), ["k1", "pk1", "k2"])
with patch.dict(function_constant_boolean.__globals__, {"constant_boolean_1": False}):
Expand Down

0 comments on commit 4e23ac2

Please sign in to comment.