Skip to content

Commit

Permalink
Merge branch 'main' into nested_attrs_dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored Dec 17, 2024
2 parents 8d5b125 + 8f22f69 commit 9bd7fba
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.35.1 (2024-12-??)
--------------------

Fixed
^^^^^
- Help for ``Protocol`` types not working correctly (`#645
<https://github.com/omni-us/jsonargparse/pull/645>`__).


v4.35.0 (2024-12-16)
--------------------

Expand Down
24 changes: 18 additions & 6 deletions jsonargparse/_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,31 @@ def __init__(self, typehint=None, **kwargs):
super().__init__(**kwargs)

def update_init_kwargs(self, kwargs):
from ._typehints import get_optional_arg, get_subclass_names, get_unaliased_type
from ._typehints import (
get_optional_arg,
get_subclass_names,
get_subclass_types,
get_unaliased_type,
is_protocol,
)

typehint = get_unaliased_type(get_optional_arg(self._typehint))
if get_typehint_origin(typehint) is not Union:
assert "nargs" not in kwargs
kwargs["nargs"] = "?"
self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True))
self._baseclasses = get_subclass_types(typehint, callable_return=True)
assert self._baseclasses

self._kind = "subclass of"
if any(is_protocol(b) for b in self._baseclasses):
self._kind = "subclass or implementer of protocol"

kwargs.update(
{
"metavar": "CLASS_PATH_OR_NAME",
"default": SUPPRESS,
"help": f"Show the help for the given subclass of {self._basename} and exit.",
"help": f"Show the help for the given {self._kind} {self._basename} and exit.",
}
)

Expand All @@ -375,23 +388,22 @@ def print_help(self, call_args):
from ._typehints import (
ActionTypeHint,
get_optional_arg,
get_subclass_types,
get_unaliased_type,
implements_protocol,
resolve_class_path_by_name,
)

parser, _, value, option_string = call_args
try:
typehint = get_unaliased_type(get_optional_arg(self._typehint))
baseclasses = get_subclass_types(typehint, callable_return=True)
if self.nargs == "?" and value is None:
val_class = typehint
else:
val_class = import_object(resolve_class_path_by_name(typehint, value))
except Exception as ex:
raise TypeError(f"{option_string}: {ex}") from ex
if not any(is_subclass(val_class, b) for b in baseclasses):
raise TypeError(f'{option_string}: Class "{value}" is not a subclass of {self._basename}')
if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses):
raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}')
dest = re.sub("\\.help$", "", self.dest)
subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}")
if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"):
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def implements_protocol(value, protocol) -> bool:
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value) or value is object:
if not inspect.isclass(value) or value is object or not is_protocol(protocol):
return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
Expand Down
25 changes: 23 additions & 2 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,14 @@ def predict(self, items: List[float]) -> List[float]:
return items


class SubclassImplementsInterface(Interface):
def __init__(self, max_items: int):
self.max_items = max_items

def predict(self, items: List[float]) -> List[float]:
return items


class NotImplementsInterface1:
def predict(self, items: str) -> List[float]:
return []
Expand All @@ -1462,6 +1470,7 @@ def predict(self, items: List[float]) -> None:
"expected, value",
[
(True, ImplementsInterface),
(True, SubclassImplementsInterface),
(False, ImplementsInterface(1)),
(False, NotImplementsInterface1),
(False, NotImplementsInterface2),
Expand All @@ -1488,14 +1497,22 @@ def test_is_instance_or_supports_protocol(expected, value):

def test_parse_implements_protocol(parser):
parser.add_argument("--cls", type=Interface)
assert "known subclasses:" not in get_parser_help(parser)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
assert cfg.cls.init_args == Namespace(batch_size=5)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsInterface)
assert init.cls.batch_size == 5
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]

help_str = get_parser_help(parser)
assert "known subclasses:" in help_str
assert f"{__name__}.SubclassImplementsInterface" in help_str
help_str = get_parse_args_stdout(parser, ["--cls.help=SubclassImplementsInterface"])
assert "--cls.max_items" in help_str
with pytest.raises(ArgumentError, match="not a subclass or implementer of protocol"):
parser.parse_args([f"--cls.help={__name__}.NotImplementsInterface1"])

with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.Interface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
Expand Down Expand Up @@ -1551,13 +1568,17 @@ def test_implements_callable_protocol(expected, value):

def test_parse_implements_callable_protocol(parser):
parser.add_argument("--cls", type=CallableInterface)
assert "known subclasses:" not in get_parser_help(parser)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsCallableInterface", "--cls.batch_size=7"])
assert cfg.cls.class_path == f"{__name__}.ImplementsCallableInterface"
assert cfg.cls.init_args == Namespace(batch_size=7)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsCallableInterface)
assert init.cls([1.0, 2.0]) == [1.0, 2.0]

assert "known subclasses:" not in get_parser_help(parser)
help_str = get_parse_args_stdout(parser, [f"--cls.help={__name__}.ImplementsCallableInterface"])
assert "--cls.batch_size" in help_str

with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.CallableInterface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
Expand Down

0 comments on commit 9bd7fba

Please sign in to comment.