From 2b20bc85e511bba2ee531d9981ab5e1654761b6e Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 12 Oct 2023 07:06:59 +0200 Subject: [PATCH] Add support for on parse argument links with target subclasses in a list (#394, lightning#18161). --- CHANGELOG.rst | 8 +++- DOCUMENTATION.rst | 34 +++++++++++++++- jsonargparse/_link_arguments.py | 26 ++++++++---- jsonargparse/_typehints.py | 4 +- jsonargparse_tests/test_link_arguments.py | 49 ++++++++++++++++++++++- 5 files changed, 107 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 79e1c54f..ee689ed8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,9 +12,15 @@ The semantic versioning only considers the public API as described in paths are considered internals and can change in minor and patch releases. -v4.25.1 (2023-10-??) +v4.26.0 (2023-10-??) -------------------- +Added +^^^^^ +- Support for on parse argument links with target subclasses in a list (`#394 + `__, `lightning#18161 + `__). + Fixed ^^^^^ - Failures with subcommands and default_config_files when keys are repeated diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index 20bc653b..f634fd5c 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -2071,9 +2071,12 @@ There are two types of links, defined with ``apply_on='parse'`` or calling one of the parse methods and the latter are set when calling :py:meth:`.ArgumentParser.instantiate_classes`. +Applied on parse +---------------- + For parsing links, source keys can be individual arguments or nested groups. The -target key has to be a single argument. The keys can be inside init_args of a -subclass. The compute function should accept as many positional arguments as +target key has to be a single argument. The keys can be inside ``init_args`` of +a subclass. The compute function should accept as many positional arguments as there are sources and return a value of type compatible with the target. An example would be the following: @@ -2097,6 +2100,33 @@ example would be the following: As argument and in config files only ``data.batch_size`` should be specified. Then whatever value it has will be propagated to ``model.batch_size``. +An example of a target being in a subclass is: + +.. testcode:: + + class Logger: + def __init__(self, save_dir: Optional[str] = None): + self.save_dir = save_dir + + class Trainer: + def __init__( + self, + save_dir: Optional[str] = None, + logger: Union[bool, Logger, List[Logger]] = False, + ): + self.logger = logger + + parser = ArgumentParser() + parser.add_class_arguments(Trainer, "trainer") + parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir") + +The link gets applied to the ``logger`` parameter when it is a single subclass +and applied to all elements of a list of subclasses. If a subclass does not +define the targeted ``init_args`` parameter, the link is ignored. + +Applied on instantiate +---------------------- + For instantiation links, sources can be class groups (added with :py:meth:`.SignatureArguments.add_class_arguments`) or subclass arguments (see :ref:`sub-classes`). The source key can be the entire instantiated object or an diff --git a/jsonargparse/_link_arguments.py b/jsonargparse/_link_arguments.py index 8ccabcac..f5f0af2b 100644 --- a/jsonargparse/_link_arguments.py +++ b/jsonargparse/_link_arguments.py @@ -146,7 +146,7 @@ def __init__( from ._typehints import ActionTypeHint - is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False) + is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False, also_lists=True) valid_target_init_arg = is_target_subclass and target.startswith(f"{self.target[1].dest}.init_args.") valid_target_leaf = self.target[1].dest == target if not valid_target_leaf and is_target_subclass and not valid_target_init_arg: @@ -163,7 +163,10 @@ def __init__( group._group_actions.remove(self.target[1]) if is_target_subclass: help_dest = f"{self.target[1].dest}.help" - group._group_actions.remove(next(a for a in group._group_actions if a.dest == help_dest)) + for action in group._group_actions: + if action.dest == help_dest: # type: ignore + group._group_actions.remove(action) + break if group._group_actions and all(isinstance(a, _ActionConfigLoad) for a in group._group_actions): group.description = ( f"Group '{group._group_actions[0].dest}': All arguments are derived from links." @@ -354,14 +357,23 @@ def apply_instantiation_links(parser, cfg, target=None, order=None): @staticmethod def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -> None: target_key, target_action = action.target + assert target_action from ._typehints import ActionTypeHint - if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False): - if target_key == target_action.dest: # type: ignore + if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False, also_lists=True): + if target_key == target_action.dest: target_action._check_type(value) # type: ignore - elif target_key not in cfg: - logger.debug(f"Link '{action.option_strings[0]}' ignored since target not found in namespace.") - return + else: + parent = cfg.get(target_action.dest) + child_key = target_key[len(target_action.dest) + 1 :] + if isinstance(parent, list) and any(isinstance(i, Namespace) and child_key in i for i in parent): + for item in parent: + if child_key in item: + item[child_key] = value + return + if target_key not in cfg: + logger.debug(f"Link '{action.option_strings[0]}' ignored since target not found.") + return cfg[target_key] = value @staticmethod diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index b5436914..f318ddb3 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -254,12 +254,12 @@ def is_supported_typehint(typehint, full=False): return supported @staticmethod - def is_subclass_typehint(typehint, all_subtypes=True): + def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False): typehint = typehint_from_action(typehint) if typehint is None: return False typehint_origin = get_typehint_origin(typehint) - if typehint_origin == Union: + if typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types): subtypes = [a for a in typehint.__args__ if a != NoneType] test = all if all_subtypes else any return test(ActionTypeHint.is_subclass_typehint(s) for s in subtypes) diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index 538c8664..982a1f17 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -217,7 +217,7 @@ def __init__(self, save_dir: Optional[str] = None): pass -class Trainer: +class TrainerLoggerUnion: def __init__( self, save_dir: Optional[str] = None, @@ -227,7 +227,7 @@ def __init__( def test_on_parse_subclass_target_in_union(parser): - parser.add_class_arguments(Trainer, "trainer") + parser.add_class_arguments(TrainerLoggerUnion, "trainer") parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir") cfg = parser.parse_args([]) assert cfg.trainer == Namespace(logger=False, save_dir=None) @@ -236,6 +236,51 @@ def test_on_parse_subclass_target_in_union(parser): assert cfg.trainer.logger.init_args == Namespace(save_dir="logs") +class TrainerLoggerList: + def __init__( + self, + save_dir: Optional[str] = None, + logger: List[Logger] = [], + ): + pass + + +def test_on_parse_subclass_target_in_list(parser): + parser.add_class_arguments(TrainerLoggerList, "trainer") + parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir") + cfg = parser.parse_args([]) + assert cfg.trainer == Namespace(logger=[], save_dir=None) + cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger]"]) + assert cfg.trainer.save_dir == "logs" + assert len(cfg.trainer.logger) == 1 + assert cfg.trainer.logger[0].init_args == Namespace(save_dir="logs") + cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger, Logger]"]) + assert len(cfg.trainer.logger) == 2 + assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger) + + +class TrainerLoggerUnionList: + def __init__( + self, + save_dir: Optional[str] = None, + logger: Union[bool, Logger, List[Logger]] = False, + ): + pass + + +def test_on_parse_subclass_target_in_union_list(parser): + parser.add_class_arguments(TrainerLoggerUnionList, "trainer") + parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir") + cfg = parser.parse_args([]) + assert cfg.trainer == Namespace(logger=False, save_dir=None) + cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=Logger"]) + assert cfg.trainer.save_dir == "logs" + assert cfg.trainer.logger.init_args == Namespace(save_dir="logs") + cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger, Logger]"]) + assert len(cfg.trainer.logger) == 2 + assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger) + + class ClassF: def __init__( self,