diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4b32e060..9482571f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Fixed - Confusing error message when adding signature parameters that conflict with existing arguments. - Deprecation warnings not printing the correct file and line of code. +- ``sub_configs=True`` not working for callable types that return a class (`#419 + `__). v4.27.0 (2023-11-02) diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index 455e2ae5..cc1aa464 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -365,10 +365,13 @@ def _add_signature_parameter( sub_add_kwargs["skip"] = subclass_skip else: register_pydantic_type(annotation) + enable_path = sub_configs and ( + is_subclass_typehint or ActionTypeHint.is_return_subclass_typehint(annotation) + ) args = ActionTypeHint.prepare_add_argument( args=args, kwargs=kwargs, - enable_path=is_subclass_typehint and sub_configs, + enable_path=enable_path, container=group, logger=self.logger, sub_add_kwargs=sub_add_kwargs, diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 3042de38..4cdb2859 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -278,6 +278,15 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False): and not is_subclass(typehint, (Path, Enum)) ) + @staticmethod + def is_return_subclass_typehint(typehint): + typehint_origin = get_typehint_origin(typehint) + if typehint_origin in callable_origin_types: + return_type = get_callable_return_type(typehint) + if ActionTypeHint.is_subclass_typehint(return_type): + return True + return False + @staticmethod def is_mapping_typehint(typehint): typehint_origin = get_typehint_origin(typehint) or typehint diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 00a52064..fd714a0f 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -6,6 +6,7 @@ import uuid from calendar import Calendar from enum import Enum +from pathlib import Path from typing import ( Any, Callable, @@ -784,6 +785,27 @@ def test_callable_args_return_type_union_of_classes(parser, subtests): assert f"{__name__}.{name}" in help_str +class CallableSubconfig: + def __init__(self, o: Callable[[int], Optimizer]): + self.o = o + + +def test_callable_args_return_type_class_subconfig(parser, tmp_cwd): + config = { + "class_path": "Adam", + "init_args": {"momentum": 0.8}, + } + Path("optimizer.yaml").write_text(yaml.safe_dump(config)) + + parser.add_class_arguments(CallableSubconfig, "m", sub_configs=True) + cfg = parser.parse_args(["--m.o=optimizer.yaml"]) + assert cfg.m.o.class_path == f"{__name__}.Adam" + init = parser.instantiate_classes(cfg) + optimizer = init.m.o(1) + assert isinstance(optimizer, Adam) + assert optimizer.momentum == 0.8 + + # lazy_instance tests