diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4ebc193d..ff1ac63f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,15 @@ The semantic versioning only considers the public API as described in paths are considered internals and can change in minor and patch releases. +v4.35.1 (2024-12-??) +-------------------- + +Fixed +^^^^^ +- Help for ``Protocol`` types not working correctly (`#645 + `__). + + v4.35.0 (2024-12-16) -------------------- diff --git a/jsonargparse/_actions.py b/jsonargparse/_actions.py index a3a55ed5..47061fa9 100644 --- a/jsonargparse/_actions.py +++ b/jsonargparse/_actions.py @@ -350,18 +350,31 @@ def __init__(self, typehint=None, **kwargs): super().__init__(**kwargs) def update_init_kwargs(self, kwargs): - from ._typehints import get_optional_arg, get_subclass_names, get_unaliased_type + from ._typehints import ( + get_optional_arg, + get_subclass_names, + get_subclass_types, + get_unaliased_type, + is_protocol, + ) typehint = get_unaliased_type(get_optional_arg(self._typehint)) if get_typehint_origin(typehint) is not Union: assert "nargs" not in kwargs kwargs["nargs"] = "?" self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True)) + self._baseclasses = get_subclass_types(typehint, callable_return=True) + assert self._baseclasses + + self._kind = "subclass of" + if any(is_protocol(b) for b in self._baseclasses): + self._kind = "subclass or implementer of protocol" + kwargs.update( { "metavar": "CLASS_PATH_OR_NAME", "default": SUPPRESS, - "help": f"Show the help for the given subclass of {self._basename} and exit.", + "help": f"Show the help for the given {self._kind} {self._basename} and exit.", } ) @@ -375,23 +388,22 @@ def print_help(self, call_args): from ._typehints import ( ActionTypeHint, get_optional_arg, - get_subclass_types, get_unaliased_type, + implements_protocol, resolve_class_path_by_name, ) parser, _, value, option_string = call_args try: typehint = get_unaliased_type(get_optional_arg(self._typehint)) - baseclasses = get_subclass_types(typehint, callable_return=True) if self.nargs == "?" and value is None: val_class = typehint else: val_class = import_object(resolve_class_path_by_name(typehint, value)) except Exception as ex: raise TypeError(f"{option_string}: {ex}") from ex - if not any(is_subclass(val_class, b) for b in baseclasses): - raise TypeError(f'{option_string}: Class "{value}" is not a subclass of {self._basename}') + if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses): + raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}') dest = re.sub("\\.help$", "", self.dest) subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}") if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"): diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 4ac35b3a..986d1d04 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1103,7 +1103,7 @@ def implements_protocol(value, protocol) -> bool: from jsonargparse._parameter_resolvers import get_signature_parameters from jsonargparse._postponed_annotations import get_return_type - if not inspect.isclass(value) or value is object: + if not inspect.isclass(value) or value is object or not is_protocol(protocol): return False members = 0 for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction): diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index a29e3431..e5a8c872 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -1443,6 +1443,14 @@ def predict(self, items: List[float]) -> List[float]: return items +class SubclassImplementsInterface(Interface): + def __init__(self, max_items: int): + self.max_items = max_items + + def predict(self, items: List[float]) -> List[float]: + return items + + class NotImplementsInterface1: def predict(self, items: str) -> List[float]: return [] @@ -1462,6 +1470,7 @@ def predict(self, items: List[float]) -> None: "expected, value", [ (True, ImplementsInterface), + (True, SubclassImplementsInterface), (False, ImplementsInterface(1)), (False, NotImplementsInterface1), (False, NotImplementsInterface2), @@ -1488,7 +1497,6 @@ def test_is_instance_or_supports_protocol(expected, value): def test_parse_implements_protocol(parser): parser.add_argument("--cls", type=Interface) - assert "known subclasses:" not in get_parser_help(parser) cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"]) assert cfg.cls.class_path == f"{__name__}.ImplementsInterface" assert cfg.cls.init_args == Namespace(batch_size=5) @@ -1496,6 +1504,15 @@ def test_parse_implements_protocol(parser): assert isinstance(init.cls, ImplementsInterface) assert init.cls.batch_size == 5 assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0] + + help_str = get_parser_help(parser) + assert "known subclasses:" in help_str + assert f"{__name__}.SubclassImplementsInterface" in help_str + help_str = get_parse_args_stdout(parser, ["--cls.help=SubclassImplementsInterface"]) + assert "--cls.max_items" in help_str + with pytest.raises(ArgumentError, match="not a subclass or implementer of protocol"): + parser.parse_args([f"--cls.help={__name__}.NotImplementsInterface1"]) + with pytest.raises(ArgumentError, match="is a protocol"): parser.parse_args([f"--cls={__name__}.Interface"]) with pytest.raises(ArgumentError, match="does not implement protocol"): @@ -1551,13 +1568,17 @@ def test_implements_callable_protocol(expected, value): def test_parse_implements_callable_protocol(parser): parser.add_argument("--cls", type=CallableInterface) - assert "known subclasses:" not in get_parser_help(parser) cfg = parser.parse_args([f"--cls={__name__}.ImplementsCallableInterface", "--cls.batch_size=7"]) assert cfg.cls.class_path == f"{__name__}.ImplementsCallableInterface" assert cfg.cls.init_args == Namespace(batch_size=7) init = parser.instantiate_classes(cfg) assert isinstance(init.cls, ImplementsCallableInterface) assert init.cls([1.0, 2.0]) == [1.0, 2.0] + + assert "known subclasses:" not in get_parser_help(parser) + help_str = get_parse_args_stdout(parser, [f"--cls.help={__name__}.ImplementsCallableInterface"]) + assert "--cls.batch_size" in help_str + with pytest.raises(ArgumentError, match="is a protocol"): parser.parse_args([f"--cls={__name__}.CallableInterface"]) with pytest.raises(ArgumentError, match="does not implement protocol"):