Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix help for Protocol types not working correctly #645

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.35.1 (2024-12-??)
--------------------

Fixed
^^^^^
- Help for ``Protocol`` types not working correctly (`#645
<https://github.com/omni-us/jsonargparse/pull/645>`__).


v4.35.0 (2024-12-16)
--------------------

Expand Down
24 changes: 18 additions & 6 deletions jsonargparse/_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,18 +350,31 @@ def __init__(self, typehint=None, **kwargs):
super().__init__(**kwargs)

def update_init_kwargs(self, kwargs):
from ._typehints import get_optional_arg, get_subclass_names, get_unaliased_type
from ._typehints import (
get_optional_arg,
get_subclass_names,
get_subclass_types,
get_unaliased_type,
is_protocol,
)

typehint = get_unaliased_type(get_optional_arg(self._typehint))
if get_typehint_origin(typehint) is not Union:
assert "nargs" not in kwargs
kwargs["nargs"] = "?"
self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True))
self._baseclasses = get_subclass_types(typehint, callable_return=True)
assert self._baseclasses

self._kind = "subclass of"
if any(is_protocol(b) for b in self._baseclasses):
self._kind = "subclass or implementer of protocol"

kwargs.update(
{
"metavar": "CLASS_PATH_OR_NAME",
"default": SUPPRESS,
"help": f"Show the help for the given subclass of {self._basename} and exit.",
"help": f"Show the help for the given {self._kind} {self._basename} and exit.",
}
)

Expand All @@ -375,23 +388,22 @@ def print_help(self, call_args):
from ._typehints import (
ActionTypeHint,
get_optional_arg,
get_subclass_types,
get_unaliased_type,
implements_protocol,
resolve_class_path_by_name,
)

parser, _, value, option_string = call_args
try:
typehint = get_unaliased_type(get_optional_arg(self._typehint))
baseclasses = get_subclass_types(typehint, callable_return=True)
if self.nargs == "?" and value is None:
val_class = typehint
else:
val_class = import_object(resolve_class_path_by_name(typehint, value))
except Exception as ex:
raise TypeError(f"{option_string}: {ex}") from ex
if not any(is_subclass(val_class, b) for b in baseclasses):
raise TypeError(f'{option_string}: Class "{value}" is not a subclass of {self._basename}')
if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses):
raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}')
dest = re.sub("\\.help$", "", self.dest)
subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}")
if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"):
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def implements_protocol(value, protocol) -> bool:
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value) or value is object:
if not inspect.isclass(value) or value is object or not is_protocol(protocol):
return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
Expand Down
25 changes: 23 additions & 2 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,14 @@ def predict(self, items: List[float]) -> List[float]:
return items


class SubclassImplementsInterface(Interface):
def __init__(self, max_items: int):
self.max_items = max_items

def predict(self, items: List[float]) -> List[float]:
return items


class NotImplementsInterface1:
def predict(self, items: str) -> List[float]:
return []
Expand All @@ -1462,6 +1470,7 @@ def predict(self, items: List[float]) -> None:
"expected, value",
[
(True, ImplementsInterface),
(True, SubclassImplementsInterface),
(False, ImplementsInterface(1)),
(False, NotImplementsInterface1),
(False, NotImplementsInterface2),
Expand All @@ -1488,14 +1497,22 @@ def test_is_instance_or_supports_protocol(expected, value):

def test_parse_implements_protocol(parser):
parser.add_argument("--cls", type=Interface)
assert "known subclasses:" not in get_parser_help(parser)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
assert cfg.cls.init_args == Namespace(batch_size=5)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsInterface)
assert init.cls.batch_size == 5
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]

help_str = get_parser_help(parser)
assert "known subclasses:" in help_str
assert f"{__name__}.SubclassImplementsInterface" in help_str
help_str = get_parse_args_stdout(parser, ["--cls.help=SubclassImplementsInterface"])
assert "--cls.max_items" in help_str
with pytest.raises(ArgumentError, match="not a subclass or implementer of protocol"):
parser.parse_args([f"--cls.help={__name__}.NotImplementsInterface1"])

with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.Interface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
Expand Down Expand Up @@ -1551,13 +1568,17 @@ def test_implements_callable_protocol(expected, value):

def test_parse_implements_callable_protocol(parser):
parser.add_argument("--cls", type=CallableInterface)
assert "known subclasses:" not in get_parser_help(parser)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsCallableInterface", "--cls.batch_size=7"])
assert cfg.cls.class_path == f"{__name__}.ImplementsCallableInterface"
assert cfg.cls.init_args == Namespace(batch_size=7)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsCallableInterface)
assert init.cls([1.0, 2.0]) == [1.0, 2.0]

assert "known subclasses:" not in get_parser_help(parser)
help_str = get_parse_args_stdout(parser, [f"--cls.help={__name__}.ImplementsCallableInterface"])
assert "--cls.batch_size" in help_str

with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.CallableInterface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
Expand Down
Loading