From 7aa5e5c9f383c2b30dcaa0a4873e7d5254466c02 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 25 Oct 2023 07:09:02 +0200 Subject: [PATCH] Fix unable link two deep level arguments sharing the same root class (#297). --- CHANGELOG.rst | 2 + jsonargparse/_common.py | 4 +- jsonargparse/_core.py | 8 +- jsonargparse/_link_arguments.py | 82 ++++++++++++++----- jsonargparse/_typehints.py | 7 +- jsonargparse_tests/test_link_arguments.py | 97 +++++++++++++++++++++++ 6 files changed, 176 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 18b805f5..b9ee5147 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Fixed - Failure to parse subclass added via add_argument and required arg as link target. - ``choices`` working incorrectly when ``nargs`` is ``+``, ``*`` or number. +- Unable link two deep level arguments sharing the same root class (`#297 + `__). v4.26.1 (2023-10-23) diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index da59bce5..9c17c7ad 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -3,7 +3,7 @@ import sys from contextlib import contextmanager from contextvars import ContextVar -from typing import Dict, Generic, Optional, Tuple, Type, TypeVar, Union, _GenericAlias # type: ignore +from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, _GenericAlias # type: ignore from ._namespace import Namespace from ._type_checking import ArgumentParser @@ -31,6 +31,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType: lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False) load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None) class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators") +nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[]) parser_context_vars = dict( @@ -40,6 +41,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType: lenient_check=lenient_check, load_value_mode=load_value_mode, class_instantiators=class_instantiators, + nested_links=nested_links, ) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 43cc172f..0cad3fb6 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -955,7 +955,7 @@ def get_defaults(self, skip_check: bool = False) -> Namespace: ): cfg[action.dest] = recreate_branches(action.default) - self._logger.debug("Loaded default values from parser") + self._logger.debug("Loaded parser defaults: %s", cfg) default_config_files = self._get_default_config_files() for key, default_config_file in default_config_files: @@ -1165,7 +1165,11 @@ def instantiate_classes( pass else: if value is not None: - with parser_context(parent_parser=self, class_instantiators=self._get_instantiators()): + with parser_context( + parent_parser=self, + nested_links=ActionLink.get_nested_links(self, component), + class_instantiators=self._get_instantiators(), + ): parent[key] = component.instantiate_classes(value) else: with parser_context(load_value_mode=self.parser_mode, class_instantiators=self._get_instantiators()): diff --git a/jsonargparse/_link_arguments.py b/jsonargparse/_link_arguments.py index 831291be..a54acc76 100644 --- a/jsonargparse/_link_arguments.py +++ b/jsonargparse/_link_arguments.py @@ -118,13 +118,14 @@ def __init__( if not hasattr(parser, "_links_group"): parser._links_group = parser.add_argument_group("Linked arguments") self.parser = parser + self._target = target + self._source = source = (source,) if isinstance(source, str) else source self.apply_on = apply_on self.compute_fn = compute_fn self._initial_input_checks(source, target) # Set and check source actions or group exclude = (ActionLink, _ActionConfigLoad, _ActionSubCommands, ActionConfigFile) - source = (source,) if isinstance(source, str) else source if apply_on == "instantiate": self.source = [(s, find_subclass_action_or_class_group(parser, s, exclude=exclude)) for s in source] for key, action in self.source: @@ -219,13 +220,21 @@ def __init__( help=help_str, ) + def get_kwargs(self) -> dict: + return { + "source": self._source, + "target": self._target, + "apply_on": self.apply_on, + "compute_fn": self.compute_fn, + } + def _initial_input_checks(self, source, target): # Check apply_on if self.apply_on not in {"parse", "instantiate"}: raise ValueError("apply_on must be 'parse' or 'instantiate'.") # Check compute function - if self.compute_fn is None and not isinstance(source, str): + if self.compute_fn is None and not (isinstance(source, str) or len(source) == 1): raise ValueError("Multiple source keys requires a compute function.") if self.apply_on == "parse": @@ -266,21 +275,19 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None: ActionLink.apply_parsing_links(subparser, cfg[subcommand]) # type: ignore if not hasattr(parser, "_links_group"): return - for action in parser._links_group._group_actions: - if action.apply_on != "parse": - continue + for action in get_link_actions(parser, "parse"): from ._typehints import ActionTypeHint args = [] skip_link = False for source_key, source_action in action.source: - if ActionTypeHint.is_subclass_typehint(source_action[0]) and source_key not in cfg: + if ActionTypeHint.is_subclass_typehint(source_action[0]) and source_key not in cfg: # type: ignore parser.logger.debug( f"Link '{action.option_strings[0]}' ignored since source '{source_key}' not found in namespace." ) skip_link = True break - for source_action_n in [a for a in source_action if a.dest in cfg]: + for source_action_n in [a for a in source_action if a.dest in cfg]: # type: ignore parser._check_value_key(source_action_n, cfg[source_action_n.dest], source_action_n.dest, None) args.append(cfg[source_key]) if skip_link: @@ -309,6 +316,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None: # Compute value value = action.call_compute_fn(args) ActionLink.set_target_value(action, value, cfg, parser.logger) + parser.logger.debug(f"Applied link '{action.option_strings[0]}'.") @staticmethod def apply_instantiation_links(parser, cfg, target=None, order=None): @@ -317,14 +325,15 @@ def apply_instantiation_links(parser, cfg, target=None, order=None): applied_key = "__applied_instantiation_links__" applied_links = cfg.pop(applied_key) if applied_key in cfg else set() - link_actions = [ - a for a in parser._links_group._group_actions if a.apply_on == "instantiate" and a not in applied_links - ] + link_actions = get_link_actions(parser, "instantiate", skip=applied_links) if order and link_actions: link_actions = ActionLink.reorder(order, link_actions) for action in link_actions: - if not (order or action.target[0] == target or action.target[0].startswith(target + ".")): + target_key = action.target[0] + if not ( + order or target_key == target or target_key.startswith(f"{target}.") + ) or is_nested_instantiation_link(action): continue source_objects = [] for source_key, source_action in action.source: @@ -350,10 +359,25 @@ def apply_instantiation_links(parser, cfg, target=None, order=None): value = action.call_compute_fn(source_objects) ActionLink.set_target_value(action, value, cfg, parser.logger) applied_links.add(action) + parser.logger.debug(f"Applied link '{action.option_strings[0]}'.") if target: cfg[applied_key] = applied_links + @staticmethod + def get_nested_links(parser, action): + def trim_param_keys(params: dict): + params = params.copy() + params["source"] = tuple(k[len(f"{action.dest}.") :] for k in params["source"]) + params["target"] = params["target"][len(f"{action.dest}.init_args.") :] + return params + + links = [] + for link in get_link_actions(parser, "instantiate"): + if link.target[1] is action and is_nested_instantiation_link(link): + links.append(trim_param_keys(link.get_kwargs())) + return links + @staticmethod def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -> None: target_key, target_action = action.target @@ -378,15 +402,14 @@ def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) - @staticmethod def instantiation_order(parser): - if hasattr(parser, "_links_group"): - actions = [a for a in parser._links_group._group_actions if a.apply_on == "instantiate"] - if len(actions) > 0: - graph = DirectedGraph() - for action in actions: - target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0]) - for _, source_action in action.source: - graph.add_edge(source_action.dest, target) - return graph.get_topological_order() + actions = get_link_actions(parser, "instantiate") + if actions: + graph = DirectedGraph() + for action in actions: + target = re.sub(r"\.init_args$", "", split_key_leaf(action.target[0])[0]) + for _, source_action in action.source: + graph.add_edge(source_action.dest, target) + return graph.get_topological_order() return [] @staticmethod @@ -428,6 +451,25 @@ def del_target_key(target_key): ActionLink.strip_link_target_keys(subparsers[num], cfg[subcommand]) +def get_link_actions(parser: "ArgumentParser", apply_on: str, skip=set()) -> List[ActionLink]: + if not hasattr(parser, "_links_group"): + return [] + return [a for a in parser._links_group._group_actions if a.apply_on == apply_on and a not in skip] + + +def is_nested_instantiation_link(action: ActionLink) -> bool: + from ._typehints import ActionTypeHint + + target_key, target_action = action.target + assert target_action + return ( + target_key.startswith(f"{target_action.dest}.init_args.") + and ActionTypeHint.is_subclass_typehint(target_action) + and all(a is target_action for _, a in action.source) + and all(k.startswith(f"{target_action.dest}.") for k, _ in action.source) + ) + + class ArgumentLinking: """Method for linking arguments.""" diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index b6e31ff3..0cd7afe0 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -41,7 +41,7 @@ _is_action_value_list, remove_actions, ) -from ._common import get_class_instantiator, is_dataclass_like, is_subclass, parent_parser, parser_context +from ._common import get_class_instantiator, is_dataclass_like, is_subclass, nested_links, parent_parser, parser_context from ._loaders_dumpers import ( get_loader_exceptions, load_value, @@ -528,10 +528,15 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0): parser = type(parser)(exit_on_error=False, logger=parser.logger) remove_actions(parser, (ActionConfigFile, _ActionPrintConfig)) parser.add_class_arguments(val_class, **kwargs) + if "linked_targets" in kwargs and parser.required_args: for key in kwargs["linked_targets"]: if key in parser.required_args: parser.required_args.remove(key) + + for link_kwargs in nested_links.get(): + parser.link_arguments(**link_kwargs) + return parser def extra_help(self): diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index 6096745c..53b44a5a 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -669,6 +669,103 @@ def test_on_instantiate_add_argument_subclass_required_params(parser): assert init.cls2.a == 1 +class WithinDeepSource: + def __init__(self, model_name: str): + self.output_channels = dict( + modelA=16, + modelB=32, + )[model_name] + + +class WithinDeepTarget: + def __init__(self, input_channels: int): + self.input_channels = input_channels + + +class WithinDeepModel: + def __init__( + self, + encoder: WithinDeepSource, + decoder: WithinDeepTarget, + ): + self.encoder = encoder + self.decoder = decoder + + +within_deep_config = { + "model": { + "class_path": f"{__name__}.WithinDeepModel", + "init_args": { + "encoder": { + "class_path": f"{__name__}.WithinDeepSource", + "init_args": { + "model_name": "modelA", + }, + }, + "decoder": { + "class_path": f"{__name__}.WithinDeepTarget", + }, + }, + }, +} + + +def test_on_instantiate_within_deep_subclass(parser, caplog): + parser.logger = {"level": "DEBUG"} + parser.logger.handlers = [caplog.handler] + + parser.add_argument("--cfg", action=ActionConfigFile) + parser.add_argument("--model", type=WithinDeepModel) + parser.link_arguments( + "model.encoder.output_channels", + "model.init_args.decoder.init_args.input_channels", + apply_on="instantiate", + ) + + cfg = parser.parse_args([f"--cfg={within_deep_config}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.model, WithinDeepModel) + assert isinstance(init.model.encoder, WithinDeepSource) + assert isinstance(init.model.decoder, WithinDeepTarget) + assert init.model.decoder.input_channels == 16 + assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text + + +class WithinDeeperSystem: + def __init__(self, model: WithinDeepModel): + self.model = model + + +within_deeper_config = { + "system": { + "class_path": f"{__name__}.WithinDeeperSystem", + "init_args": within_deep_config, + }, +} + + +def test_on_instantiate_within_deeper_subclass(parser, caplog): + parser.logger = {"level": "DEBUG"} + parser.logger.handlers = [caplog.handler] + + parser.add_argument("--cfg", action=ActionConfigFile) + parser.add_subclass_arguments(WithinDeeperSystem, "system") + parser.link_arguments( + "system.model.encoder.output_channels", + "system.init_args.model.init_args.decoder.init_args.input_channels", + apply_on="instantiate", + ) + + cfg = parser.parse_args([f"--cfg={within_deeper_config}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.system, WithinDeeperSystem) + assert isinstance(init.system.model, WithinDeepModel) + assert isinstance(init.system.model.encoder, WithinDeepSource) + assert isinstance(init.system.model.decoder, WithinDeepTarget) + assert init.system.model.decoder.input_channels == 16 + assert "Applied link 'encoder.output_channels --> decoder.init_args.input_channels'" in caplog.text + + # link creation failures