Skip to content

Commit

Permalink
Fix sub_configs=True not working for callable types that return a cla…
Browse files Browse the repository at this point in the history
…ss (#419).
  • Loading branch information
mauvilsa committed Nov 23, 2023
1 parent bd4d840 commit acfed44
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Fixed
- Confusing error message when adding signature parameters that conflict with
existing arguments.
- Deprecation warnings not printing the correct file and line of code.
- ``sub_configs=True`` not working for callable types that return a class (`#419
<https://github.com/omni-us/jsonargparse/issues/419>`__).


v4.27.0 (2023-11-02)
Expand Down
5 changes: 4 additions & 1 deletion jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,13 @@ def _add_signature_parameter(
sub_add_kwargs["skip"] = subclass_skip
else:
register_pydantic_type(annotation)
enable_path = sub_configs and (
is_subclass_typehint or ActionTypeHint.is_return_subclass_typehint(annotation)
)
args = ActionTypeHint.prepare_add_argument(
args=args,
kwargs=kwargs,
enable_path=is_subclass_typehint and sub_configs,
enable_path=enable_path,
container=group,
logger=self.logger,
sub_add_kwargs=sub_add_kwargs,
Expand Down
9 changes: 9 additions & 0 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
and not is_subclass(typehint, (Path, Enum))
)

@staticmethod
def is_return_subclass_typehint(typehint):
typehint_origin = get_typehint_origin(typehint)
if typehint_origin in callable_origin_types:
return_type = get_callable_return_type(typehint)
if ActionTypeHint.is_subclass_typehint(return_type):
return True
return False

@staticmethod
def is_mapping_typehint(typehint):
typehint_origin = get_typehint_origin(typehint) or typehint
Expand Down
22 changes: 22 additions & 0 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import uuid
from calendar import Calendar
from enum import Enum
from pathlib import Path
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -784,6 +785,27 @@ def test_callable_args_return_type_union_of_classes(parser, subtests):
assert f"{__name__}.{name}" in help_str


class CallableSubconfig:
def __init__(self, o: Callable[[int], Optimizer]):
self.o = o


def test_callable_args_return_type_class_subconfig(parser, tmp_cwd):
config = {
"class_path": "Adam",
"init_args": {"momentum": 0.8},
}
Path("optimizer.yaml").write_text(yaml.safe_dump(config))

parser.add_class_arguments(CallableSubconfig, "m", sub_configs=True)
cfg = parser.parse_args(["--m.o=optimizer.yaml"])
assert cfg.m.o.class_path == f"{__name__}.Adam"
init = parser.instantiate_classes(cfg)
optimizer = init.m.o(1)
assert isinstance(optimizer, Adam)
assert optimizer.momentum == 0.8


# lazy_instance tests


Expand Down

0 comments on commit acfed44

Please sign in to comment.