Skip to content

Commit

Permalink
Add support for on parse argument links with target subclasses in a l…
Browse files Browse the repository at this point in the history
…ist (#394, lightning#18161).
  • Loading branch information
mauvilsa committed Oct 16, 2023
1 parent dd0f9e3 commit 2b20bc8
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 14 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.25.1 (2023-10-??)
v4.26.0 (2023-10-??)
--------------------

Added
^^^^^
- Support for on parse argument links with target subclasses in a list (`#394
<https://github.com/omni-us/jsonargparse/issues/394>`__, `lightning#18161
<https://github.com/Lightning-AI/lightning/issues/18161>`__).

Fixed
^^^^^
- Failures with subcommands and default_config_files when keys are repeated
Expand Down
34 changes: 32 additions & 2 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2071,9 +2071,12 @@ There are two types of links, defined with ``apply_on='parse'`` or
calling one of the parse methods and the latter are set when calling
:py:meth:`.ArgumentParser.instantiate_classes`.

Applied on parse
----------------

For parsing links, source keys can be individual arguments or nested groups. The
target key has to be a single argument. The keys can be inside init_args of a
subclass. The compute function should accept as many positional arguments as
target key has to be a single argument. The keys can be inside ``init_args`` of
a subclass. The compute function should accept as many positional arguments as
there are sources and return a value of type compatible with the target. An
example would be the following:

Expand All @@ -2097,6 +2100,33 @@ example would be the following:
As argument and in config files only ``data.batch_size`` should be specified.
Then whatever value it has will be propagated to ``model.batch_size``.

An example of a target being in a subclass is:

.. testcode::

class Logger:
def __init__(self, save_dir: Optional[str] = None):
self.save_dir = save_dir

class Trainer:
def __init__(
self,
save_dir: Optional[str] = None,
logger: Union[bool, Logger, List[Logger]] = False,
):
self.logger = logger

parser = ArgumentParser()
parser.add_class_arguments(Trainer, "trainer")
parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir")

The link gets applied to the ``logger`` parameter when it is a single subclass
and applied to all elements of a list of subclasses. If a subclass does not
define the targeted ``init_args`` parameter, the link is ignored.

Applied on instantiate
----------------------

For instantiation links, sources can be class groups (added with
:py:meth:`.SignatureArguments.add_class_arguments`) or subclass arguments (see
:ref:`sub-classes`). The source key can be the entire instantiated object or an
Expand Down
26 changes: 19 additions & 7 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(

from ._typehints import ActionTypeHint

is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False)
is_target_subclass = ActionTypeHint.is_subclass_typehint(self.target[1], all_subtypes=False, also_lists=True)
valid_target_init_arg = is_target_subclass and target.startswith(f"{self.target[1].dest}.init_args.")
valid_target_leaf = self.target[1].dest == target
if not valid_target_leaf and is_target_subclass and not valid_target_init_arg:
Expand All @@ -163,7 +163,10 @@ def __init__(
group._group_actions.remove(self.target[1])
if is_target_subclass:
help_dest = f"{self.target[1].dest}.help"
group._group_actions.remove(next(a for a in group._group_actions if a.dest == help_dest))
for action in group._group_actions:
if action.dest == help_dest: # type: ignore
group._group_actions.remove(action)
break
if group._group_actions and all(isinstance(a, _ActionConfigLoad) for a in group._group_actions):
group.description = (
f"Group '{group._group_actions[0].dest}': All arguments are derived from links."
Expand Down Expand Up @@ -354,14 +357,23 @@ def apply_instantiation_links(parser, cfg, target=None, order=None):
@staticmethod
def set_target_value(action: "ActionLink", value: Any, cfg: Namespace, logger) -> None:
target_key, target_action = action.target
assert target_action
from ._typehints import ActionTypeHint

if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False):
if target_key == target_action.dest: # type: ignore
if ActionTypeHint.is_subclass_typehint(target_action, all_subtypes=False, also_lists=True):
if target_key == target_action.dest:
target_action._check_type(value) # type: ignore
elif target_key not in cfg:
logger.debug(f"Link '{action.option_strings[0]}' ignored since target not found in namespace.")
return
else:
parent = cfg.get(target_action.dest)
child_key = target_key[len(target_action.dest) + 1 :]
if isinstance(parent, list) and any(isinstance(i, Namespace) and child_key in i for i in parent):
for item in parent:
if child_key in item:
item[child_key] = value
return
if target_key not in cfg:
logger.debug(f"Link '{action.option_strings[0]}' ignored since target not found.")
return
cfg[target_key] = value

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def is_supported_typehint(typehint, full=False):
return supported

@staticmethod
def is_subclass_typehint(typehint, all_subtypes=True):
def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
typehint = typehint_from_action(typehint)
if typehint is None:
return False
typehint_origin = get_typehint_origin(typehint)
if typehint_origin == Union:
if typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types):
subtypes = [a for a in typehint.__args__ if a != NoneType]
test = all if all_subtypes else any
return test(ActionTypeHint.is_subclass_typehint(s) for s in subtypes)
Expand Down
49 changes: 47 additions & 2 deletions jsonargparse_tests/test_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(self, save_dir: Optional[str] = None):
pass


class Trainer:
class TrainerLoggerUnion:
def __init__(
self,
save_dir: Optional[str] = None,
Expand All @@ -227,7 +227,7 @@ def __init__(


def test_on_parse_subclass_target_in_union(parser):
parser.add_class_arguments(Trainer, "trainer")
parser.add_class_arguments(TrainerLoggerUnion, "trainer")
parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir")
cfg = parser.parse_args([])
assert cfg.trainer == Namespace(logger=False, save_dir=None)
Expand All @@ -236,6 +236,51 @@ def test_on_parse_subclass_target_in_union(parser):
assert cfg.trainer.logger.init_args == Namespace(save_dir="logs")


class TrainerLoggerList:
def __init__(
self,
save_dir: Optional[str] = None,
logger: List[Logger] = [],
):
pass


def test_on_parse_subclass_target_in_list(parser):
parser.add_class_arguments(TrainerLoggerList, "trainer")
parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir")
cfg = parser.parse_args([])
assert cfg.trainer == Namespace(logger=[], save_dir=None)
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger]"])
assert cfg.trainer.save_dir == "logs"
assert len(cfg.trainer.logger) == 1
assert cfg.trainer.logger[0].init_args == Namespace(save_dir="logs")
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger, Logger]"])
assert len(cfg.trainer.logger) == 2
assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger)


class TrainerLoggerUnionList:
def __init__(
self,
save_dir: Optional[str] = None,
logger: Union[bool, Logger, List[Logger]] = False,
):
pass


def test_on_parse_subclass_target_in_union_list(parser):
parser.add_class_arguments(TrainerLoggerUnionList, "trainer")
parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir")
cfg = parser.parse_args([])
assert cfg.trainer == Namespace(logger=False, save_dir=None)
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=Logger"])
assert cfg.trainer.save_dir == "logs"
assert cfg.trainer.logger.init_args == Namespace(save_dir="logs")
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger, Logger]"])
assert len(cfg.trainer.logger) == 2
assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger)


class ClassF:
def __init__(
self,
Expand Down

0 comments on commit 2b20bc8

Please sign in to comment.