Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AST resolver now supports some local import cases #403

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Added
- Support for on parse argument links with target subclasses in a list (`#394
<https://github.com/omni-us/jsonargparse/issues/394>`__, `lightning#18161
<https://github.com/Lightning-AI/lightning/issues/18161>`__).
- AST resolver now supports some local import cases.

Fixed
^^^^^
Expand Down
5 changes: 5 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,11 @@ unrelated to these variables.
SomeClass.a_class_method(*args, **kwargs)


def calls_local_import(**kwargs):
import some_module
some_module.a_callable(**kwargs)


def pops_from_kwargs(**kwargs):
val = kwargs.pop("name", "default")

Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def error(self, message: str, ex: Optional[Exception] = None) -> NoReturn:
elif debug_mode_active():
self._logger.debug("Debug enabled, thus raising exception instead of exit.")
raise argument_error(message) from ex
self.print_usage()
self.print_usage(sys.stderr)
sys.stderr.write(f"error: {message}\n")
self.exit(2)

Expand Down
65 changes: 51 additions & 14 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,11 @@ def visit_Assign(self, node):
do_generic_visit = True
for key, value in self.find_values.items():
if ast_is_assign_with_value(node, value):
self.values_found.append((key, node))
self.add_value(key, node)
do_generic_visit = False
break
elif ast_is_dict_assign_with_value(node, value):
self.values_found.append((key, node))
self.add_value(key, node)
do_generic_visit = False
if do_generic_visit:
if ast_is_dict_assign(node):
Expand All @@ -531,11 +531,11 @@ def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
value_dump = ast.dump(node.func.value)
if value_dump in self.dict_assigns:
self.values_found.append((key, self.dict_assigns[value_dump]))
self.add_value(key, self.dict_assigns[value_dump])
continue
self.values_found.append((key, node))
self.add_value(key, node)
elif ast_is_kwargs_pop_or_get(node, value_dump):
self.values_found.append((key, node))
self.add_value(key, node)
self.generic_visit(node)

def visit_If(self, node):
Expand All @@ -550,27 +550,64 @@ def visit_If(self, node):
node = ast.If(test=ast.Constant(value=True), body=body, orelse=[])
self.generic_visit(node)

def visit_Import(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
for alias in node.names:
self.import_names[alias.asname or alias.name] = node

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.visit_Import(node)

def add_value(self, key, node):
source = None
if isinstance(node, ast.Call):
name = False
if isinstance(node.func, ast.Name):
name = node.func.id
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
name = node.func.value.id
if name and name in self.import_names:
source = self.import_names[name]
self.values_found.append((key, node, source))

def find_values_usage(self, values):
self.find_values = values
self.values_found = []
self.dict_assigns = {}
self.import_names = {}
self.visit(self.component_node)
return self.values_found

def get_node_component(self, node) -> Optional[Tuple[Type, Optional[str]]]:
def get_component_from_source(self, name, source):
aliases = {}
ast_exec = ast.parse("")
ast_exec.body = [source]
try:
exec(compile(ast_exec, filename="<ast>", mode="exec"), aliases, aliases)
except Exception as ex:
if self.logger:
self.logger.debug(f"Failed to get '{name}' from '{ast_str(source)}'", exc_info=ex)
return aliases.get(name)

def get_node_component(self, node, source) -> Optional[Tuple[Type, Optional[str]]]:
function_or_class = method_or_property = None
module = inspect.getmodule(self.component)
if isinstance(node.func, ast.Name):
if is_classmethod(self.parent, self.component) and node.func.id == self.self_name:
function_or_class = self.parent
elif hasattr(module, node.func.id):
function_or_class = getattr(module, node.func.id)
elif source:
function_or_class = self.get_component_from_source(node.func.id, source)
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
if self.parent and ast.dump(node.func.value) == ast.dump(ast_variable_load(self.self_name)):
function_or_class = self.parent
method_or_property = node.func.attr
elif hasattr(module, node.func.value.id):
container = getattr(module, node.func.value.id)
else:
container = None
if hasattr(module, node.func.value.id):
container = getattr(module, node.func.value.id)
elif source:
container = self.get_component_from_source(node.func.value.id, source)
if inspect.isclass(container):
function_or_class = container
method_or_property = node.func.attr
Expand All @@ -581,7 +618,7 @@ def get_node_component(self, node) -> Optional[Tuple[Type, Optional[str]]]:
return None
return function_or_class, method_or_property

def match_call_that_uses_attr(self, node, attr_name):
def match_call_that_uses_attr(self, node, source, attr_name):
params = None
if isinstance(node, ast.Call):
params = []
Expand All @@ -591,7 +628,7 @@ def match_call_that_uses_attr(self, node, attr_name):
if kwarg.arg:
self.log_debug(f"kwargs attribute given as keyword parameter not supported: {ast_str(node)}")
else:
get_param_args = self.get_node_component(node)
get_param_args = self.get_node_component(node, source)
if get_param_args:
try:
params = get_signature_parameters(*get_param_args, logger=self.logger)
Expand Down Expand Up @@ -686,7 +723,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
params_list = []
kwargs_value = kwargs_name and values_to_find[kwargs_name]
kwargs_value_dump = kwargs_value and ast.dump(kwargs_value)
for node in [v for k, v in values_found if k == kwargs_name]:
for node, source in [(v, s) for k, v, s in values_found if k == kwargs_name]:
if isinstance(node, ast.Call):
if ast_is_kwargs_pop_or_get(node, kwargs_value_dump):
param = self.get_kwargs_pop_or_get_parameter(node, self.component, self.parent, self.doc_params)
Expand All @@ -704,7 +741,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
self.logger,
)
else:
get_param_args = self.get_node_component(node)
get_param_args = self.get_node_component(node, source)
if get_param_args:
params = get_signature_parameters(*get_param_args, logger=self.logger)
params = remove_given_parameters(node, params)
Expand Down Expand Up @@ -757,8 +794,8 @@ def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optio
values_found = self.find_values_usage(values_to_find)
matched = []
if values_found:
for _, node in values_found:
match = self.match_call_that_uses_attr(node, attr_name)
for _, node, source in values_found:
match = self.match_call_that_uses_attr(node, source, attr_name)
if match:
self.add_node_origins(match, node)
matched.append(match)
Expand Down
79 changes: 79 additions & 0 deletions jsonargparse_tests/test_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,56 @@
return self.fn(**self._kwd)


class AttributeLocalImport1:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader

return set_loader(**self._kwd)


class AttributeLocalImport2:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
import jsonargparse as ja
Dismissed Show dismissed Hide dismissed

return ja.set_loader(**self._kwd)


class AttributeLocalImport3:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader

return set_loader(**self._kwd)


class AttributeLocalImport4:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import set_loader as sl

return sl(**self._kwd)


class AttributeLocalImportFailure:
def __init__(self, **kwargs):
self._kwd = dict(**kwargs)

def run(self):
from jsonargparse import does_not_exist

return does_not_exist(**self._kwd)


class ClassF1:
def __init__(self, **kw):
self._ini = dict(k2=4)
Expand Down Expand Up @@ -421,6 +471,12 @@
return calendar.Calendar(**kwds)


def function_local_import(**kwds):
from jsonargparse import set_loader

return set_loader(**kwds)


constant_boolean_1 = True
constant_boolean_2 = False

Expand Down Expand Up @@ -527,6 +583,24 @@
assert_params(get_params(ClassF1), [])


@pytest.mark.parametrize(
"cls",
[AttributeLocalImport1, AttributeLocalImport2, AttributeLocalImport3, AttributeLocalImport4],
)
def test_get_params_local_import_with_kwargs_in_dict_attribute(cls):
params = get_params(cls)
assert ["mode", "loader_fn", "exceptions"] == [p.name for p in params]
with source_unavailable():
assert get_params(cls) == []


def test_get_params_local_import_failure_with_kwargs_in_dict_attribute(logger):
with capture_logs(logger) as logs:
params = get_params(AttributeLocalImportFailure, logger=logger)
assert params == []
assert "Failed to get 'does_not_exist'" in logs.getvalue()


def test_get_params_class_kwargs_in_attr_method_conditioned_on_arg():
params = get_params(ClassG)
assert_params(
Expand Down Expand Up @@ -706,6 +780,11 @@
assert ["firstweekday"] == [p.name for p in params]


def test_get_params_function_local_import():
params = get_params(function_local_import)
assert ["mode", "loader_fn", "exceptions"] == [p.name for p in params]


def test_get_params_function_constant_boolean():
assert_params(get_params(function_constant_boolean), ["k1", "pk1", "k2"])
with patch.dict(function_constant_boolean.__globals__, {"constant_boolean_1": False}):
Expand Down