Skip to content

Commit 534a287

Browse files
authored
Fix: Add function signature failing when conditionally calling different functions (#468)
1 parent bca1588 commit 534a287

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Fixed
2626
produces an invalid string default.
2727
- dataclass single parameter change incorrectly resetting previous values (`#464
2828
<https://github.com/omni-us/jsonargparse/issues/464>`__).
29+
- Add function signature failing when conditionally calling different functions
30+
(`#467 <https://github.com/omni-us/jsonargparse/issues/467>`__).
2931

3032

3133
v4.27.5 (2024-02-12)

jsonargparse/_signatures.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import dataclasses
44
import inspect
55
import re
6-
from argparse import SUPPRESS
6+
from argparse import SUPPRESS, ArgumentParser
77
from contextlib import suppress
88
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union
99

@@ -255,7 +255,7 @@ def _add_signature_arguments(
255255
## Create group if requested ##
256256
doc_group = get_doc_short_description(function_or_class, method_name, logger=self.logger)
257257
component = getattr(function_or_class, method_name) if method_name else function_or_class
258-
group = self._create_group_if_requested(
258+
container = self._create_group_if_requested(
259259
component,
260260
nested_key,
261261
as_group,
@@ -268,7 +268,7 @@ def _add_signature_arguments(
268268
added_args: List[str] = []
269269
for param in params:
270270
self._add_signature_parameter(
271-
group,
271+
container,
272272
nested_key,
273273
param,
274274
added_args,
@@ -283,7 +283,7 @@ def _add_signature_arguments(
283283

284284
def _add_signature_parameter(
285285
self,
286-
group,
286+
container,
287287
nested_key: Optional[str],
288288
param,
289289
added_args: List[str],
@@ -339,11 +339,14 @@ def _add_signature_parameter(
339339
dest = (nested_key + "." if nested_key else "") + name
340340
args = [dest if is_required and as_positional else "--" + dest]
341341
if param.origin:
342+
parser = container
343+
if not isinstance(container, ArgumentParser):
344+
parser = getattr(container, "parser")
342345
group_name = "; ".join(str(o) for o in param.origin)
343-
if group_name in group.parser.groups:
344-
group = group.parser.groups[group_name]
346+
if group_name in parser.groups:
347+
container = parser.groups[group_name]
345348
else:
346-
group = group.parser.add_argument_group(
349+
container = parser.add_argument_group(
347350
f"Conditional arguments [origins: {group_name}]",
348351
name=group_name,
349352
)
@@ -372,7 +375,7 @@ def _add_signature_parameter(
372375
args=args,
373376
kwargs=kwargs,
374377
enable_path=enable_path,
375-
container=group,
378+
container=container,
376379
logger=self.logger,
377380
sub_add_kwargs=sub_add_kwargs,
378381
)
@@ -387,7 +390,7 @@ def _add_signature_parameter(
387390
if is_dataclass_like_typehint:
388391
kwargs.update(sub_add_kwargs)
389392
with ActionTypeHint.allow_default_instance_context():
390-
action = group.add_argument(*args, **kwargs)
393+
action = container.add_argument(*args, **kwargs)
391394
action.sub_add_kwargs = sub_add_kwargs
392395
if is_subclass_typehint and len(subclass_skip) > 0:
393396
action.sub_add_kwargs["skip"] = subclass_skip

jsonargparse_tests/test_cli.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from jsonargparse import CLI, capture_parser, lazy_instance
1616
from jsonargparse._optionals import docstring_parser_support, ruyaml_support
17+
from jsonargparse._typehints import Literal
1718
from jsonargparse.typing import final
1819
from jsonargparse_tests.conftest import skip_if_docstring_parser_unavailable
1920

@@ -120,6 +121,31 @@ def test_multiple_functions_subcommand_help():
120121
assert "--a2 A2" in out
121122

122123

124+
def conditionalA(foo: int = 1):
125+
return foo
126+
127+
128+
def conditionalB(bar: int = 2):
129+
return bar
130+
131+
132+
def conditional_function(fn: "Literal['A', 'B']", *args, **kwargs):
133+
if fn == "A":
134+
return conditionalA(*args, **kwargs)
135+
elif fn == "B":
136+
return conditionalB(*args, **kwargs)
137+
raise NotImplementedError(fn)
138+
139+
140+
@pytest.mark.skipif(condition=sys.version_info < (3, 9), reason="python>=3.9 is required")
141+
@pytest.mark.skipif(condition=not Literal, reason="Literal is required")
142+
def test_literal_conditional_function():
143+
out = get_cli_stdout(conditional_function, args=["--help"])
144+
assert "Conditional arguments" in out
145+
assert "--foo FOO (type: int, default: Conditional<ast-resolver> {1, NOT_ACCEPTED})" in out
146+
assert "--bar BAR (type: int, default: Conditional<ast-resolver> {2, NOT_ACCEPTED})" in out
147+
148+
123149
# single class tests
124150

125151

0 commit comments

Comments
 (0)