Skip to content

Commit

Permalink
Fix argument links not working for target init_args in an optional li…
Browse files Browse the repository at this point in the history
…st (#434)
  • Loading branch information
mauvilsa authored Jan 23, 2024
1 parent a9fc1cd commit fd6a44b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.27.3 (2024-01-??)
--------------------

Fixed
^^^^^
- Argument links not working for target ``init_args`` in an optional list (`#433
<https://github.com/omni-us/jsonargparse/issues/433>`__).


v4.27.2 (2024-01-18)
--------------------

Expand Down
3 changes: 2 additions & 1 deletion jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
if typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types):
subtypes = [a for a in typehint.__args__ if a != NoneType]
test = all if all_subtypes else any
return test(ActionTypeHint.is_subclass_typehint(s) for s in subtypes)
k = {"also_lists": also_lists}
return test(ActionTypeHint.is_subclass_typehint(s, **k) for s in subtypes)
return (
inspect.isclass(typehint)
and typehint not in leaf_or_root_types
Expand Down
22 changes: 22 additions & 0 deletions jsonargparse_tests/test_link_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,28 @@ def test_on_parse_subclass_target_in_union_list(parser):
assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger)


class TrainerLoggerOptionalList:
def __init__(
self,
save_dir: Optional[str] = None,
logger: Optional[List[Logger]] = None,
):
pass


def test_on_parse_subclass_target_in_optional_list(parser):
parser.add_class_arguments(TrainerLoggerOptionalList, "trainer")
parser.link_arguments("trainer.save_dir", "trainer.logger.init_args.save_dir")
cfg = parser.parse_args([])
assert cfg.trainer == Namespace(logger=None, save_dir=None)
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger+=Logger"])
assert cfg.trainer.save_dir == "logs"
assert cfg.trainer.logger[0].init_args == Namespace(save_dir="logs")
cfg = parser.parse_args(["--trainer.save_dir=logs", "--trainer.logger=[Logger, Logger]"])
assert len(cfg.trainer.logger) == 2
assert all(x.init_args == Namespace(save_dir="logs") for x in cfg.trainer.logger)


class ClassF:
def __init__(
self,
Expand Down

0 comments on commit fd6a44b

Please sign in to comment.