diff --git a/docs/tutorial/parameter-types/enum.md b/docs/tutorial/parameter-types/enum.md index 185ad420c3..fd442fbf21 100644 --- a/docs/tutorial/parameter-types/enum.md +++ b/docs/tutorial/parameter-types/enum.md @@ -19,11 +19,11 @@ Check it: ```console $ python main.py --help -// Notice the predefined values [simple|conv|lstm] +// Notice the predefined values Usage: main.py [OPTIONS] Options: - --network [simple|conv|lstm] [default: simple] + --network [default: simple] --help Show this message and exit. // Try it @@ -91,8 +91,8 @@ $ python main.py --help Usage: main.py [OPTIONS] Options: - --groceries [Eggs|Bacon|Cheese] [default: Eggs, Cheese] - --help Show this message and exit. + --groceries [default: Eggs, Cheese] + --help Show this message and exit. // Try it with the default values $ python main.py @@ -123,11 +123,11 @@ You can also use `Literal` to represent a set of possible predefined choices, wi ```console $ python main.py --help -// Notice the predefined values [simple|conv|lstm] +// Notice the predefined values Usage: main.py [OPTIONS] Options: - --network [simple|conv|lstm] [default: simple] + --network [default: simple] --help Show this message and exit. // Try it diff --git a/docs_src/parameter_types/number/tutorial001_an_py310.py b/docs_src/parameter_types/number/tutorial001_an_py310.py index 1784b13f06..995f64b67f 100644 --- a/docs_src/parameter_types/number/tutorial001_an_py310.py +++ b/docs_src/parameter_types/number/tutorial001_an_py310.py @@ -7,11 +7,11 @@ @app.command() def main( - id: Annotated[int, typer.Argument(min=0, max=1000)], + ID: Annotated[int, typer.Argument(min=0, max=1000)], age: Annotated[int, typer.Option(min=18)] = 20, score: Annotated[float, typer.Option(max=100)] = 0, ): - print(f"ID is {id}") + print(f"ID is {ID}") print(f"--age is {age}") print(f"--score is {score}") diff --git a/docs_src/parameter_types/number/tutorial001_py310.py b/docs_src/parameter_types/number/tutorial001_py310.py index fc4fe0d30e..37e9e39ff8 100644 --- a/docs_src/parameter_types/number/tutorial001_py310.py +++ b/docs_src/parameter_types/number/tutorial001_py310.py @@ -5,11 +5,11 @@ @app.command() def main( - id: int = typer.Argument(..., min=0, max=1000), + ID: int = typer.Argument(..., min=0, max=1000), age: int = typer.Option(20, min=18), score: float = typer.Option(0, max=100), ): - print(f"ID is {id}") + print(f"ID is {ID}") print(f"--age is {age}") print(f"--score is {score}") diff --git a/docs_src/parameter_types/number/tutorial002_an_py310.py b/docs_src/parameter_types/number/tutorial002_an_py310.py index 5d3835817c..9df5e4838e 100644 --- a/docs_src/parameter_types/number/tutorial002_an_py310.py +++ b/docs_src/parameter_types/number/tutorial002_an_py310.py @@ -7,11 +7,11 @@ @app.command() def main( - id: Annotated[int, typer.Argument(min=0, max=1000)], + ID: Annotated[int, typer.Argument(min=0, max=1000)], rank: Annotated[int, typer.Option(max=10, clamp=True)] = 0, score: Annotated[float, typer.Option(min=0, max=100, clamp=True)] = 0, ): - print(f"ID is {id}") + print(f"ID is {ID}") print(f"--rank is {rank}") print(f"--score is {score}") diff --git a/docs_src/parameter_types/number/tutorial002_py310.py b/docs_src/parameter_types/number/tutorial002_py310.py index c0daadfbd5..f4a624619f 100644 --- a/docs_src/parameter_types/number/tutorial002_py310.py +++ b/docs_src/parameter_types/number/tutorial002_py310.py @@ -5,11 +5,11 @@ @app.command() def main( - id: int = typer.Argument(..., min=0, max=1000), + ID: int = typer.Argument(..., min=0, max=1000), rank: int = typer.Option(0, max=10, clamp=True), score: float = typer.Option(0, min=0, max=100, clamp=True), ): - print(f"ID is {id}") + print(f"ID is {ID}") print(f"--rank is {rank}") print(f"--score is {score}") diff --git a/pyproject.toml b/pyproject.toml index f4d0af0ebb..198f92e4da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ + "pydantic >=2.5.3", "shellingham >=1.3.0", "rich >=13.8.0", "annotated-doc >=0.0.2", diff --git a/tests/assets/cli/multiapp-docs-title.md b/tests/assets/cli/multiapp-docs-title.md index ffde843736..368275838f 100644 --- a/tests/assets/cli/multiapp-docs-title.md +++ b/tests/assets/cli/multiapp-docs-title.md @@ -65,8 +65,8 @@ $ multiapp sub hello [OPTIONS] **Options**: -* `--name TEXT`: [default: World] -* `--age INTEGER`: The age of the user [default: 0] +* `--name `: [default: World] +* `--age `: The age of the user [default: 0] * `--help`: Show this message and exit. ### `multiapp sub hi` @@ -76,12 +76,12 @@ Say Hi **Usage**: ```console -$ multiapp sub hi [OPTIONS] [USER] +$ multiapp sub hi [OPTIONS] [user] ``` **Arguments**: -* `[USER]`: The name of the user to greet [default: World] +* `[user]`: The name of the user to greet [default: World] **Options**: diff --git a/tests/assets/cli/multiapp-docs.md b/tests/assets/cli/multiapp-docs.md index 67d02568db..e1dad5f804 100644 --- a/tests/assets/cli/multiapp-docs.md +++ b/tests/assets/cli/multiapp-docs.md @@ -65,8 +65,8 @@ $ multiapp sub hello [OPTIONS] **Options**: -* `--name TEXT`: [default: World] -* `--age INTEGER`: The age of the user [default: 0] +* `--name `: [default: World] +* `--age `: The age of the user [default: 0] * `--help`: Show this message and exit. ### `multiapp sub hi` @@ -76,12 +76,12 @@ Say Hi **Usage**: ```console -$ multiapp sub hi [OPTIONS] [USER] +$ multiapp sub hi [OPTIONS] [user] ``` **Arguments**: -* `[USER]`: The name of the user to greet [default: World] +* `[user]`: The name of the user to greet [default: World] **Options**: diff --git a/tests/assets/cli/richformattedapp-docs.md b/tests/assets/cli/richformattedapp-docs.md index 678a2daf6f..a6d1ab1dc3 100644 --- a/tests/assets/cli/richformattedapp-docs.md +++ b/tests/assets/cli/richformattedapp-docs.md @@ -5,13 +5,13 @@ Say cool name of the user [required] -* `[USER_2]`: The world [default: The World] +* `user_1`: The cool name of the user [required] +* `[user_2]`: The world [default: The World] **Options**: diff --git a/tests/assets/completion_argument.py b/tests/assets/completion_argument.py index e2754c4357..ad064b9166 100644 --- a/tests/assets/completion_argument.py +++ b/tests/assets/completion_argument.py @@ -1,10 +1,11 @@ import typer from typer import _click +from typer.core import TyperParameter app = typer.Typer() -def shell_complete(ctx: _click.Context, param: _click.Parameter, incomplete: str): +def shell_complete(ctx: _click.Context, param: TyperParameter, incomplete: str): typer.echo(f"ctx: {ctx.info_name}", err=True) typer.echo(f"arg is: {param.name}", err=True) typer.echo(f"incomplete is: {incomplete}", err=True) diff --git a/tests/test_cli/test_help.py b/tests/test_cli/test_help.py index e829c5801b..3238c0658b 100644 --- a/tests/test_cli/test_help.py +++ b/tests/test_cli/test_help.py @@ -121,7 +121,7 @@ def cmd(value: str) -> None: output_lines = result.output.splitlines() usage_idx = output_lines.index("Usage: very-long-program-name-that-forces-wrap ") args_line = output_lines[usage_idx + 1] - assert args_line.lstrip() == "[OPTIONS] VALUE" + assert args_line.lstrip() == "[OPTIONS] {value}" assert args_line.startswith(" ") diff --git a/tests/test_coercion.py b/tests/test_coercion.py new file mode 100644 index 0000000000..e43f94c320 --- /dev/null +++ b/tests/test_coercion.py @@ -0,0 +1,129 @@ +from pathlib import Path + +import typer +from typer.testing import CliRunner + +runner = CliRunner() + + +def test_coercion() -> None: + app = typer.Typer() + seen: dict[str, object] = {} + + @app.command() + def main(items: list[int], active: bool = False, val=42): + seen["items"] = items + seen["active"] = active + seen["val"] = val + + result = runner.invoke(app, ["1", "2", "--active", "--val", "7"]) + assert result.exit_code == 0, result.output + assert seen == {"items": [1, 2], "active": True, "val": 7} + + +def test_coercion_invalid() -> None: + app = typer.Typer() + + @app.command() + def main(age: int): + pass + + result = runner.invoke(app, ["not-an-int"]) + assert "Input should be a valid integer" in result.stderr + assert result.exit_code == 2 + + +def test_coercion_path(tmp_path: Path) -> None: + target = tmp_path / "config.txt" + target.write_text("hello\n", encoding="utf-8") + app = typer.Typer() + seen: list[Path] = [] + + @app.command() + def main(config: Path = typer.Option(..., exists=True)): + seen.append(config) + + result = runner.invoke(app, ["--config", str(target)]) + assert result.exit_code == 0 + assert seen == [target] + + +def test_coercion_tuple_files(tmp_path: Path) -> None: + first = tmp_path / "first.txt" + second = tmp_path / "second.txt" + first.write_text("first-content\n", encoding="utf-8") + second.write_text("second-content\n", encoding="utf-8") + app = typer.Typer() + seen: list[str] = [] + + @app.command() + def main(files: tuple[typer.FileText, typer.FileText]): + seen.append(files[0].read()) + seen.append(files[1].read()) + + result = runner.invoke(app, [str(first), str(second)]) + assert result.exit_code == 0, result.output + assert seen == ["first-content\n", "second-content\n"] + + +def test_passthrough_runtime_param_default() -> None: + class Widget: + def __init__(self, value: int) -> None: + self.value = value + + def __repr__(self) -> str: + return f"Widget({self.value})" + + app = typer.Typer() + seen: dict[str, Widget] = {} + + @app.command() + def main(val=Widget(42)): + seen["val"] = val + + param = next(p for p in typer.main.get_command(app).params if p.name == "val") + assert param.runtime_param is not None + assert param.runtime_param.annotation is Widget + + result = runner.invoke(app) + assert result.exit_code == 0 + assert isinstance(seen["val"], Widget) + assert seen["val"].value == 42 + + result = runner.invoke(app, ["--val", "666"]) + assert result.exit_code == 2 + # This doesn't work because there's no parser + assert "is not a valid Widget" in result.output + + +def test_widget_parsed_from_cli_with_parser() -> None: + class Widget: + def __init__(self, value: int) -> None: + self.value = value + + def __repr__(self) -> str: + return f"Widget({self.value})" + + def parse_widget(value: str) -> Widget: + return Widget(int(value)) + + app = typer.Typer() + seen: dict[str, Widget] = {} + + @app.command() + def main(val: Widget = typer.Option("42", parser=parse_widget)): + seen["val"] = val + + param = next(p for p in typer.main.get_command(app).params if p.name == "val") + assert param.runtime_param is not None + assert param.runtime_param.annotation is Widget + + result = runner.invoke(app) + assert result.exit_code == 0 + assert isinstance(seen["val"], Widget) + assert seen["val"].value == 42 + + result = runner.invoke(app, ["--val", "666"]) + assert result.exit_code == 0 + assert isinstance(seen["val"], Widget) + assert seen["val"].value == 666 diff --git a/tests/test_core.py b/tests/test_core.py index 2cb759918b..8f47f77af0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,31 +5,12 @@ import typer._completion_shared import typer.completion from typer import _click -from typer.core import TyperArgument, TyperCommand, TyperGroup, TyperOption, _split_opt +from typer.core import TyperCommand, TyperGroup, _split_opt from typer.testing import CliRunner runner = CliRunner() -def test_human_readable_name() -> None: - app = typer.Typer() - - @app.command() - def main( - my_arg_1: Annotated[str, typer.Argument()], - my_arg_2: Annotated[str, typer.Argument(metavar="META_ARG")], - my_opt: Annotated[str, typer.Option()], - ): - pass # pragma: no cover - - command = typer.main.get_command(app) - params = {param.name: param for param in command.params} - - assert params["my_arg_1"].human_readable_name == "MY_ARG_1" - assert params["my_arg_2"].human_readable_name == "META_ARG" - assert params["my_opt"].human_readable_name == "my_opt" - - def test_parameter_metavar() -> None: app = typer.Typer(rich_markup_mode=None) @@ -42,70 +23,62 @@ def cmd(name: Annotated[str, typer.Option(metavar="CUSTOM")]) -> None: assert "--name CUSTOM" in result.output -def test_parameter_nargs_gt_1() -> None: - param = TyperArgument(param_decls=["value"], type=str, nargs=2) - ctx = _click.Context(TyperCommand(name="cmd")) +def test_tuple_argument_wrong_arity() -> None: + app = typer.Typer() - assert param.type_cast_value(ctx, ("one", "two")) == ("one", "two") + @app.command() + def cmd(value: tuple[str, str]): + pass # pragma: no cover - with pytest.raises( - _click.exceptions.BadParameter, match="Takes 2 values but 1 given." - ): - param.type_cast_value(ctx, ("one",)) + result = runner.invoke(app, ["only-one"]) + assert result.exit_code == 2 + assert "takes 2 values" in result.output -def test_parameter_constructor() -> None: - # no param_decl and expose_value is False: sets name to None - arg = TyperArgument(param_decls=[], expose_value=False) - assert arg.name is None - assert arg.opts == [] - assert arg.secondary_opts == [] +def test_count_option() -> None: + app = typer.Typer() - # no param_decl and expose_value is True: raises - with pytest.raises(TypeError, match="does not have a name."): - TyperArgument(param_decls=[], expose_value=True) + @app.command() + def main(verbose: int = typer.Option(0, "--verbose", "-v", count=True)): + print(verbose) - # len(param_decl) > 1: raises - with pytest.raises(TypeError, match="take exactly one parameter declaration"): - TyperArgument(param_decls=["first", "second"]) + result = runner.invoke(app, ["-vvv"]) + assert result.exit_code == 0 + assert "3" in result.stdout + + +def test_duplicate_declaration_raises() -> None: + app = typer.Typer() + + @app.command() + def main(name: str = typer.Option(..., "name", "name")): + pass # pragma: no cover - # duplicated identifier in option declarations: raises with pytest.raises(TypeError, match="Name 'name' defined twice"): - TyperOption(param_decls=["name", "name"], required=False) + typer.main.get_command(app) + + +def test_invalid_boolean_flag_declaration_raises() -> None: + app = typer.Typer() + + @app.command() + def main(flag: bool = typer.Option(False, "--flag/--flag")): + pass # pragma: no cover - # same true/false flag in boolean option declaration: raises with pytest.raises(ValueError, match="cannot use the same flag for true/false"): - TyperOption(param_decls=["flag", "--flag/--flag"], required=False, is_flag=True) - - # inferred name is not a valid identifier: sets name to None - unnamed_option = TyperOption(param_decls=["--123"], required=False) - assert unnamed_option.name is None - - # no param_decl and prompt=True: raises - with pytest.raises(TypeError, match="'name' is required with 'prompt=True'."): - TyperOption(param_decls=[], expose_value=False, prompt=True, required=False) - - # count works - option = TyperOption( - param_decls=["verbose", "--verbose", "-v"], - type=None, - default=0, - required=False, - count=True, - ) - assert isinstance(option.type, _click.types.IntRange) - assert option.type.min == 0 + typer.main.get_command(app) def test_option_error_hint() -> None: - option = TyperOption( - param_decls=["name", "--name"], - required=False, - show_envvar=True, - envvar="APP_NAME", - ) - hint = option.get_error_hint(_click.Context(TyperCommand(name="cmd"))) - assert "(env var: 'APP_NAME')" in hint + app = typer.Typer() + + @app.command() + def main(age: int = typer.Option(..., envvar="APP_NAME", show_envvar=True)): + pass # pragma: no cover + + result = runner.invoke(app, ["--age", "not-int"]) + assert result.exit_code == 2 + assert "(env var: 'APP_NAME')" in result.output def test_group_init() -> None: @@ -181,31 +154,41 @@ def test_option_resolve_envvar( set_env: bool, expected: str | None, ) -> None: - option = TyperOption( - param_decls=["name", "--name"], - required=False, - envvar=envvar, - ) + context_settings = {"auto_envvar_prefix": auto_prefix} if auto_prefix else {} + app = typer.Typer(context_settings=context_settings) + + @app.command() + def main(name: str = typer.Option("fallback", envvar=envvar)): + print(name) + if set_env: monkeypatch.setenv("APP_NAME", "my-precious") - ctx = _click.Context(TyperCommand(name="cmd"), auto_envvar_prefix=auto_prefix) - assert option.resolve_envvar_value(ctx) == expected + result = runner.invoke(app, []) + assert result.exit_code == 0 + if expected is None: + assert "fallback" in result.stdout + else: + assert expected in result.stdout def test_option_resolve_envvar_list( monkeypatch: pytest.MonkeyPatch, ) -> None: - option = TyperOption( - param_decls=["name", "--name"], - required=False, - envvar=["APP_NAME_1", "APP_NAME_2"], - ) + app = typer.Typer() + + @app.command() + def main( + name: str = typer.Option("fallback", envvar=["APP_NAME_1", "APP_NAME_2"]), + ): + print(name) + monkeypatch.delenv("APP_NAME_1", raising=False) monkeypatch.delenv("APP_NAME_2", raising=False) - ctx = _click.Context(TyperCommand(name="cmd")) - assert option.resolve_envvar_value(ctx) is None + result = runner.invoke(app, []) + assert result.exit_code == 0 + assert "fallback" in result.stdout def test_context_auto_envvar() -> None: @@ -377,3 +360,63 @@ def main(names: list[str] = typer.Option(None)): result = runner.invoke(app, [], default_map={"names": "not-a-list"}) assert result.exit_code == 2 assert "Invalid value" in result.output + + +def test_parameter_name_casing(): + app = typer.Typer() + + @app.command() + def main( + arg1: int, + arg2: int = 42, + arg3: int = typer.Argument(...), + ARG4: int = typer.Argument(42), + ARG5: int = typer.Option(...), + arg6: int = typer.Option(42), + arg7: int = typer.Argument(42, metavar="meta7"), + arg8: int = typer.Argument(metavar="ARG8"), + arg9: int = typer.Option(metavar="ARG9"), + ): + print( + f"arg1={arg1} arg2={arg2} arg3={arg3} ARG4={ARG4} ARG5={ARG5} " + f"arg6={arg6} arg7={arg7} arg8={arg8} arg9={arg9}" + ) + + result = runner.invoke( + app, + [ + "1", + "3", + "4", + "7", + "8", + "--arg2", + "2", + "--ARG5", + "5", + "--arg6", + "6", + "--ARG9", + "9", + ], + ) + assert result.exit_code == 0 + assert ( + "arg1=1 arg2=2 arg3=3 ARG4=4 ARG5=5 arg6=6 arg7=7 arg8=8 arg9=9" + in result.output + ) + + result = runner.invoke(app, ["1", "3", "4", "7", "8", "--ARG5", "5", "--ARG9", "9"]) + assert result.exit_code == 0 + assert ( + "arg1=1 arg2=42 arg3=3 ARG4=4 ARG5=5 arg6=42 arg7=7 arg8=8 arg9=9" + in result.output + ) + + result = runner.invoke(app, ["1", "3", "4", "7", "8", "--arg5", "5", "--ARG9", "9"]) + assert result.exit_code != 0 + assert "No such option: --arg5" in result.output + + result = runner.invoke(app, ["1", "3", "4", "7", "8", "--ARG5", "5", "--arg9", "9"]) + assert result.exit_code != 0 + assert "No such option: --arg9" in result.output diff --git a/tests/test_others.py b/tests/test_others.py index d2cc8696f1..f6790ac2e3 100644 --- a/tests/test_others.py +++ b/tests/test_others.py @@ -12,7 +12,7 @@ import typer.completion from typer import _click from typer.main import solve_typer_info_defaults, solve_typer_info_help -from typer.models import ParameterInfo, TyperInfo +from typer.models import TyperInfo from typer.testing import CliRunner from .utils import requires_completion_permission @@ -32,50 +32,6 @@ def test_defaults_from_info(): assert value -def test_too_many_parsers(): - def custom_parser(value: str) -> int: - return int(value) # pragma: no cover - - class CustomClickParser(_click.types.ParamType): - name = "custom_parser" - - def convert( - self, - value: str, - param: _click.Parameter | None, - ctx: _click.Context | None, - ) -> typing.Any: - return int(value) # pragma: no cover - - expected_error = ( - "Multiple custom type parsers provided. " - "`parser` and `click_type` may not both be provided." - ) - - with pytest.raises(ValueError, match=expected_error): - ParameterInfo(parser=custom_parser, click_type=CustomClickParser()) - - -def test_valid_parser_permutations(): - def custom_parser(value: str) -> int: - return int(value) # pragma: no cover - - class CustomClickParser(_click.types.ParamType): - name = "custom_parser" - - def convert( - self, - value: str, - param: _click.Parameter | None, - ctx: _click.Context | None, - ) -> typing.Any: - return int(value) # pragma: no cover - - ParameterInfo() - ParameterInfo(parser=custom_parser) - ParameterInfo(click_type=CustomClickParser()) - - @requires_completion_permission def test_install_invalid_shell(): app = typer.Typer() @@ -450,7 +406,8 @@ def main(arg1, arg2: int, arg3: "int", arg4: bool = False, arg5: "bool" = False) result = runner.invoke(app, ["Hello", "2", "invalid"]) - assert "Invalid value for 'ARG3': 'invalid' is not a valid integer" in result.output + assert "Invalid value for 'arg3'" in result.output + assert "Input should be a valid integer" in result.output result = runner.invoke(app, ["Hello", "2", "3", "--arg4", "--arg5"]) assert ( "arg1: Hello\narg2: 2\narg3: 3\narg4: True\narg5: True\n" diff --git a/tests/test_prog_name.py b/tests/test_prog_name.py index cfb5a3464f..5de6fd3f7a 100644 --- a/tests/test_prog_name.py +++ b/tests/test_prog_name.py @@ -10,4 +10,4 @@ def test_custom_prog_name(): capture_output=True, encoding="utf-8", ) - assert "Usage: custom-name [OPTIONS] I" in result.stdout + assert "Usage: custom-name [OPTIONS] {i}" in result.stdout diff --git a/tests/test_rich_markup_mode.py b/tests/test_rich_markup_mode.py index abfae82790..94fb4c2ae8 100644 --- a/tests/test_rich_markup_mode.py +++ b/tests/test_rich_markup_mode.py @@ -25,7 +25,7 @@ def main(arg: str): assert "Hello World" in result.stdout result = runner.invoke(app, ["--help"]) - assert "ARG [required]" in result.stdout + assert "arg [required]" in result.stdout assert all(c not in result.stdout for c in rounded) diff --git a/tests/test_rich_utils.py b/tests/test_rich_utils.py index beb914b961..6d11652d77 100644 --- a/tests/test_rich_utils.py +++ b/tests/test_rich_utils.py @@ -101,7 +101,7 @@ def main(bar: str): result = runner.invoke(app, ["--help"]) assert "Usage" in result.stdout - assert "BAR" in result.stdout + assert "{bar}" in result.stdout @needs_rich @@ -233,8 +233,8 @@ def main( arg1: int, arg2: int = 42, arg3: int = typer.Argument(...), - arg4: int = typer.Argument(42), - arg5: int = typer.Option(...), + ARG4: int = typer.Argument(42), + ARG5: int = typer.Option(...), arg6: int = typer.Option(42), arg7: int = typer.Argument(42, metavar="meta7"), arg8: int = typer.Argument(metavar="ARG8"), @@ -243,23 +243,26 @@ def main( pass # pragma: no cover result = runner.invoke(app, ["--help"]) - assert "Usage: main [OPTIONS] ARG1 ARG3 [ARG4] [meta7] ARG8 arg9" in result.stdout + assert ( + "Usage: main [OPTIONS] {arg1} {arg3} [ARG4] [meta7] {ARG8} {arg9}" + in result.stdout + ) out_nospace = result.stdout.replace(" ", "") # arguments - assert "arg1INTEGER" in out_nospace - assert "arg3INTEGER" in out_nospace - assert "[arg4]INTEGER" in out_nospace - assert "[meta7]INTEGER" in out_nospace - assert "ARG8INTEGER" in out_nospace - assert "arg9INTEGER" in out_nospace + assert "arg1" in out_nospace + assert "arg3" in out_nospace + assert "[ARG4]" in out_nospace + assert "[meta7]" in out_nospace + assert "ARG8" in out_nospace + assert "arg9" in out_nospace assert "arg7" not in result.stdout.lower() assert "arg8" not in result.stdout assert "ARG9" not in result.stdout # options - assert "arg2INTEGER" in out_nospace - assert "arg5INTEGER" in out_nospace - assert "arg6INTEGER" in out_nospace + assert "--arg2" in out_nospace + assert "--ARG5" in out_nospace + assert "--arg6" in out_nospace diff --git a/tests/test_tutorial/test_app_dir/test_tutorial001.py b/tests/test_tutorial/test_app_dir/test_tutorial001.py index a0f5e001dc..00ed43a79c 100644 --- a/tests/test_tutorial/test_app_dir/test_tutorial001.py +++ b/tests/test_tutorial/test_app_dir/test_tutorial001.py @@ -11,20 +11,20 @@ runner = CliRunner() +@pytest.fixture(name="app_dir") +def isolated_app_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + monkeypatch.setattr(typer, "get_app_dir", lambda app_name: str(tmp_path)) + return tmp_path + + @pytest.fixture(name="config_file") -def create_config_file(): - app_dir = Path(typer.get_app_dir("my-super-cli-app")) - app_dir.mkdir(parents=True, exist_ok=True) +def create_config_file(app_dir: Path) -> Path: config_path = app_dir / "config.json" config_path.touch(exist_ok=True) - - yield config_path - - config_path.unlink() - app_dir.rmdir() + return config_path -def test_cli_config_doesnt_exist(): +def test_cli_config_doesnt_exist(app_dir: Path): result = runner.invoke(mod.app) assert result.exit_code == 0 assert "Config file doesn't exist yet" in result.output diff --git a/tests/test_tutorial/test_arguments/test_default/test_tutorial001.py b/tests/test_tutorial/test_arguments/test_default/test_tutorial001.py index 75104a7d9b..81212ce98b 100644 --- a/tests/test_tutorial/test_arguments/test_default/test_tutorial001.py +++ b/tests/test_tutorial/test_arguments/test_default/test_tutorial001.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "[default: Wade Wilson]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_default/test_tutorial002.py b/tests/test_tutorial/test_arguments/test_default/test_tutorial002.py index f68b61df82..37eedd779d 100644 --- a/tests/test_tutorial/test_arguments/test_default/test_tutorial002.py +++ b/tests/test_tutorial/test_arguments/test_default/test_tutorial002.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "[default: (dynamic)]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial001.py b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial001.py index 689030a258..179f6936fc 100644 --- a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial001.py +++ b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial001.py @@ -27,7 +27,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "env var: AWESOME_NAME" in result.output assert "default: World" in result.output @@ -37,7 +37,7 @@ def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): monkeypatch.setattr(typer.core, "HAS_RICH", False) result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "env var: AWESOME_NAME" in result.output assert "default: World" in result.output diff --git a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial002.py b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial002.py index 23679b04b5..3d06a1c869 100644 --- a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial002.py +++ b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial002.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "env var: AWESOME_NAME, GOD_NAME" in result.output assert "default: World" in result.output diff --git a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial003.py b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial003.py index ee979e4762..042620f9a0 100644 --- a/tests/test_tutorial/test_arguments/test_envvar/test_tutorial003.py +++ b/tests/test_tutorial/test_arguments/test_envvar/test_tutorial003.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "env var: AWESOME_NAME" not in result.output assert "default: World" in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial001.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial001.py index fabe28be28..e2e24fe647 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial001.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial001.py @@ -26,9 +26,9 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] NAME" in result.output + assert "[OPTIONS] {name}" in result.output assert "Arguments" in result.output - assert "NAME" in result.output + assert "{name}" in result.output assert "The name of the user to greet" in result.output assert "[required]" in result.output @@ -37,9 +37,9 @@ def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): monkeypatch.setattr(typer.core, "HAS_RICH", False) result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] NAME" in result.output + assert "[OPTIONS] {name}" in result.output assert "Arguments" in result.output - assert "NAME" in result.output + assert "{name}" in result.output assert "The name of the user to greet" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial002.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial002.py index 91eb30e375..f3fd10a9d1 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial002.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial002.py @@ -25,10 +25,10 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] NAME" in result.output + assert "[OPTIONS] {name}" in result.output assert "Say hi to NAME very gently, like Dirk." in result.output assert "Arguments" in result.output - assert "NAME" in result.output + assert "{name}" in result.output assert "The name of the user to greet" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial003.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial003.py index 20bd6b76ea..dc7ef0234c 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial003.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial003.py @@ -25,10 +25,10 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Say hi to NAME very gently, like Dirk." in result.output assert "Arguments" in result.output - assert "NAME" in result.output + assert "[name]" in result.output assert "Who to greet" in result.output assert "[default: World]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial004.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial004.py index 7a20f48979..451e3272ad 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial004.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial004.py @@ -25,10 +25,10 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Say hi to NAME very gently, like Dirk." in result.output assert "Arguments" in result.output - assert "NAME" in result.output + assert "[name]" in result.output assert "Who to greet" in result.output assert "[default: World]" not in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial005.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial005.py index 8f6d356d81..4cb17b4e86 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial005.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial005.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "Usage: main [OPTIONS] [NAME]" in result.output + assert "Usage: main [OPTIONS] [name]" in result.output assert "Arguments" in result.output assert "Who to greet" in result.output assert "[default: (Deadpoolio the amazing's name)]" in result.output diff --git a/tests/test_tutorial/test_arguments/test_help/test_tutorial008.py b/tests/test_tutorial/test_arguments/test_help/test_tutorial008.py index f21c883434..004dbbcfe1 100644 --- a/tests/test_tutorial/test_arguments/test_help/test_tutorial008.py +++ b/tests/test_tutorial/test_arguments/test_help/test_tutorial008.py @@ -26,7 +26,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Say hi to NAME very gently, like Dirk." in result.output assert "Arguments" not in result.output assert "[default: World]" not in result.output @@ -36,7 +36,7 @@ def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): monkeypatch.setattr(typer.core, "HAS_RICH", False) result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output assert "Say hi to NAME very gently, like Dirk." in result.output assert "Arguments" not in result.output assert "[default: World]" not in result.output diff --git a/tests/test_tutorial/test_arguments/test_optional/test_tutorial000.py b/tests/test_tutorial/test_arguments/test_optional/test_tutorial000.py index 8f0bd34c7f..76d86ed82e 100644 --- a/tests/test_tutorial/test_arguments/test_optional/test_tutorial000.py +++ b/tests/test_tutorial/test_arguments/test_optional/test_tutorial000.py @@ -39,7 +39,7 @@ def test_cli(app: typer.Typer): def test_cli_missing_argument(app: typer.Typer): result = runner.invoke(app) assert result.exit_code == 2 - assert "Missing argument 'NAME'" in result.output + assert "Missing argument 'name'" in result.output def test_script(mod: ModuleType): diff --git a/tests/test_tutorial/test_arguments/test_optional/test_tutorial001.py b/tests/test_tutorial/test_arguments/test_optional/test_tutorial001.py index 7e26b9eb86..60d4b5d30d 100644 --- a/tests/test_tutorial/test_arguments/test_optional/test_tutorial001.py +++ b/tests/test_tutorial/test_arguments/test_optional/test_tutorial001.py @@ -26,7 +26,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_call_no_arg(mod: ModuleType): result = runner.invoke(mod.app) assert result.exit_code != 0 - assert "Missing argument 'NAME'." in result.output + assert "Missing argument 'name'." in result.output def test_call_no_arg_standalone(mod: ModuleType): @@ -40,7 +40,7 @@ def test_call_no_arg_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): monkeypatch.setattr(typer.core, "HAS_RICH", False) result = runner.invoke(mod.app) assert result.exit_code != 0 - assert "Error: Missing argument 'NAME'" in result.output + assert "Error: Missing argument 'name'" in result.output def test_call_arg(mod: ModuleType): diff --git a/tests/test_tutorial/test_arguments/test_optional/test_tutorial002.py b/tests/test_tutorial/test_arguments/test_optional/test_tutorial002.py index 3e3fdba384..b1b1c9b07b 100644 --- a/tests/test_tutorial/test_arguments/test_optional/test_tutorial002.py +++ b/tests/test_tutorial/test_arguments/test_optional/test_tutorial002.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAME]" in result.output + assert "[OPTIONS] [name]" in result.output def test_call_no_arg(mod: ModuleType): diff --git a/tests/test_tutorial/test_arguments/test_optional/test_tutorial003.py b/tests/test_tutorial/test_arguments/test_optional/test_tutorial003.py index 60addad04d..274616e04c 100644 --- a/tests/test_tutorial/test_arguments/test_optional/test_tutorial003.py +++ b/tests/test_tutorial/test_arguments/test_optional/test_tutorial003.py @@ -15,7 +15,7 @@ def test_call_no_arg(): result = runner.invoke(app) assert result.exit_code != 0 - assert "Missing argument 'NAME'." in result.output + assert "Missing argument 'name'." in result.output def test_call_no_arg_standalone(): @@ -29,7 +29,7 @@ def test_call_no_arg_no_rich(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(typer.core, "HAS_RICH", False) result = runner.invoke(app) assert result.exit_code != 0 - assert "Error: Missing argument 'NAME'" in result.output + assert "Error: Missing argument 'name'" in result.output def test_call_arg(): diff --git a/tests/test_tutorial/test_commands/test_arguments/test_tutorial001.py b/tests/test_tutorial/test_commands/test_arguments/test_tutorial001.py index 4efc1fe9ef..58ac250108 100644 --- a/tests/test_tutorial/test_commands/test_arguments/test_tutorial001.py +++ b/tests/test_tutorial/test_commands/test_arguments/test_tutorial001.py @@ -13,13 +13,13 @@ def test_help_create(): result = runner.invoke(app, ["create", "--help"]) assert result.exit_code == 0 - assert "create [OPTIONS] USERNAME" in result.output + assert "create [OPTIONS] {username}" in result.output def test_help_delete(): result = runner.invoke(app, ["delete", "--help"]) assert result.exit_code == 0 - assert "delete [OPTIONS] USERNAME" in result.output + assert "delete [OPTIONS] {username}" in result.output def test_create(): diff --git a/tests/test_tutorial/test_commands/test_help/test_tutorial001.py b/tests/test_tutorial/test_commands/test_help/test_tutorial001.py index 9993765b37..751022c673 100644 --- a/tests/test_tutorial/test_commands/test_help/test_tutorial001.py +++ b/tests/test_tutorial/test_commands/test_help/test_tutorial001.py @@ -39,14 +39,14 @@ def test_help(mod: ModuleType): def test_help_create(mod: ModuleType): result = runner.invoke(mod.app, ["create", "--help"]) assert result.exit_code == 0 - assert "create [OPTIONS] USERNAME" in result.output + assert "create [OPTIONS] {username}" in result.output assert "Create a new user with USERNAME." in result.output def test_help_delete(mod: ModuleType): result = runner.invoke(mod.app, ["delete", "--help"]) assert result.exit_code == 0 - assert "delete [OPTIONS] USERNAME" in result.output + assert "delete [OPTIONS] {username}" in result.output assert "Delete a user with USERNAME." in result.output assert "--force" in result.output assert "--no-force" in result.output diff --git a/tests/test_tutorial/test_commands/test_help/test_tutorial007.py b/tests/test_tutorial/test_commands/test_help/test_tutorial007.py index 76c7cb2a4f..7b8fe367ed 100644 --- a/tests/test_tutorial/test_commands/test_help/test_tutorial007.py +++ b/tests/test_tutorial/test_commands/test_help/test_tutorial007.py @@ -36,7 +36,7 @@ def test_main_help(mod: ModuleType): def test_create_help(mod: ModuleType): result = runner.invoke(mod.app, ["create", "--help"]) assert result.exit_code == 0 - assert "create [OPTIONS] USERNAME [LASTNAME]" in result.output + assert "create [OPTIONS] {username} [lastname]" in result.output assert "username" in result.output assert "The username to create" in result.output assert "Secondary Arguments" in result.output diff --git a/tests/test_tutorial/test_first_steps/test_tutorial002.py b/tests/test_tutorial/test_first_steps/test_tutorial002.py index 952f983963..baceee2bc2 100644 --- a/tests/test_tutorial/test_first_steps/test_tutorial002.py +++ b/tests/test_tutorial/test_first_steps/test_tutorial002.py @@ -15,7 +15,7 @@ def test_1(): result = runner.invoke(app, []) assert result.exit_code != 0 - assert "Missing argument 'NAME'" in result.output + assert "Missing argument 'name'" in result.output def test_2(): diff --git a/tests/test_tutorial/test_first_steps/test_tutorial003.py b/tests/test_tutorial/test_first_steps/test_tutorial003.py index e92f23521e..5c5ddb39cf 100644 --- a/tests/test_tutorial/test_first_steps/test_tutorial003.py +++ b/tests/test_tutorial/test_first_steps/test_tutorial003.py @@ -15,7 +15,7 @@ def test_1(): result = runner.invoke(app, ["Camila"]) assert result.exit_code != 0 - assert "Missing argument 'LASTNAME'" in result.output + assert "Missing argument 'lastname'" in result.output def test_2(): diff --git a/tests/test_tutorial/test_first_steps/test_tutorial004.py b/tests/test_tutorial/test_first_steps/test_tutorial004.py index 6ac326f074..0232605f97 100644 --- a/tests/test_tutorial/test_first_steps/test_tutorial004.py +++ b/tests/test_tutorial/test_first_steps/test_tutorial004.py @@ -16,9 +16,9 @@ def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 assert "Arguments" in result.output - assert "NAME" in result.output + assert "{name}" in result.output assert "[required]" in result.output - assert "LASTNAME" in result.output + assert "{lastname}" in result.output assert "[required]" in result.output assert "--formal" in result.output assert "--no-formal" in result.output diff --git a/tests/test_tutorial/test_first_steps/test_tutorial005.py b/tests/test_tutorial/test_first_steps/test_tutorial005.py index 87ce389410..48c80573f1 100644 --- a/tests/test_tutorial/test_first_steps/test_tutorial005.py +++ b/tests/test_tutorial/test_first_steps/test_tutorial005.py @@ -16,10 +16,10 @@ def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 assert "Arguments" in result.output - assert "NAME" in result.output + assert "{name}" in result.output assert "[required]" in result.output assert "--lastname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "--formal" in result.output assert "--no-formal" in result.output diff --git a/tests/test_tutorial/test_multiple_values/test_arguments_with_multiple_values/test_tutorial002.py b/tests/test_tutorial/test_multiple_values/test_arguments_with_multiple_values/test_tutorial002.py index 0fd8ed8f33..c31fabce69 100644 --- a/tests/test_tutorial/test_multiple_values/test_arguments_with_multiple_values/test_tutorial002.py +++ b/tests/test_tutorial/test_multiple_values/test_arguments_with_multiple_values/test_tutorial002.py @@ -27,7 +27,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "[OPTIONS] [NAMES]..." in result.output + assert "[OPTIONS] [names]..." in result.output assert "Arguments" in result.output assert "[default: Harry, Hermione, Ron]" in result.output diff --git a/tests/test_tutorial/test_options/test_help/test_tutorial001.py b/tests/test_tutorial/test_options/test_help/test_tutorial001.py index 067bdffe51..6432aa035e 100644 --- a/tests/test_tutorial/test_options/test_help/test_tutorial001.py +++ b/tests/test_tutorial/test_options/test_help/test_tutorial001.py @@ -22,7 +22,8 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: return mod -def test_help(mod: ModuleType): +def test_help(mod: ModuleType, monkeypatch): + monkeypatch.setenv("COLUMNS", "200") result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "Say hi to NAME, optionally with a --lastname." in result.output diff --git a/tests/test_tutorial/test_options/test_help/test_tutorial003.py b/tests/test_tutorial/test_options/test_help/test_tutorial003.py index de99469dfa..4bb1cc10a9 100644 --- a/tests/test_tutorial/test_options/test_help/test_tutorial003.py +++ b/tests/test_tutorial/test_options/test_help/test_tutorial003.py @@ -32,7 +32,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--fullname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[default: Wade Wilson]" not in result.output diff --git a/tests/test_tutorial/test_options/test_help/test_tutorial004.py b/tests/test_tutorial/test_options/test_help/test_tutorial004.py index 22f902197b..5936ebca97 100644 --- a/tests/test_tutorial/test_options/test_help/test_tutorial004.py +++ b/tests/test_tutorial/test_options/test_help/test_tutorial004.py @@ -34,7 +34,7 @@ def test_help(monkeypatch, mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--fullname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[default: (Deadpoolio the amazing's name)]" in result.output diff --git a/tests/test_tutorial/test_options/test_name/test_tutorial001.py b/tests/test_tutorial/test_options/test_name/test_tutorial001.py index 1c4f2a7cf0..d375a92bfd 100644 --- a/tests/test_tutorial/test_options/test_name/test_tutorial001.py +++ b/tests/test_tutorial/test_options/test_name/test_tutorial001.py @@ -26,7 +26,7 @@ def test_option_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--name" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "--user-name" not in result.output diff --git a/tests/test_tutorial/test_options/test_name/test_tutorial002.py b/tests/test_tutorial/test_options/test_name/test_tutorial002.py index ac9b3db6f1..9e0dcb4892 100644 --- a/tests/test_tutorial/test_options/test_name/test_tutorial002.py +++ b/tests/test_tutorial/test_options/test_name/test_tutorial002.py @@ -27,7 +27,7 @@ def test_option_help(mod: ModuleType): assert result.exit_code == 0 assert "-n" in result.output assert "--name" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "--user-name" not in result.output diff --git a/tests/test_tutorial/test_options/test_name/test_tutorial003.py b/tests/test_tutorial/test_options/test_name/test_tutorial003.py index 0503b410af..e20ca92a63 100644 --- a/tests/test_tutorial/test_options/test_name/test_tutorial003.py +++ b/tests/test_tutorial/test_options/test_name/test_tutorial003.py @@ -26,7 +26,7 @@ def test_option_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "-n" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "--user-name" not in result.output assert "--name" not in result.output diff --git a/tests/test_tutorial/test_options/test_name/test_tutorial004.py b/tests/test_tutorial/test_options/test_name/test_tutorial004.py index f200021042..862890b9c7 100644 --- a/tests/test_tutorial/test_options/test_name/test_tutorial004.py +++ b/tests/test_tutorial/test_options/test_name/test_tutorial004.py @@ -27,7 +27,7 @@ def test_option_help(mod: ModuleType): assert result.exit_code == 0 assert "-n" in result.output assert "--user-name" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "--name" not in result.output diff --git a/tests/test_tutorial/test_options/test_name/test_tutorial005.py b/tests/test_tutorial/test_options/test_name/test_tutorial005.py index 492a0fc4ef..520073b98e 100644 --- a/tests/test_tutorial/test_options/test_name/test_tutorial005.py +++ b/tests/test_tutorial/test_options/test_name/test_tutorial005.py @@ -27,7 +27,7 @@ def test_option_help(mod: ModuleType): assert result.exit_code == 0 assert "-n" in result.output assert "--name" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "-f" in result.output assert "--formal" in result.output diff --git a/tests/test_tutorial/test_options/test_password/test_tutorial001.py b/tests/test_tutorial/test_options/test_password/test_tutorial001.py index 4aca1c0bd8..8374a497a3 100644 --- a/tests/test_tutorial/test_options/test_password/test_tutorial001.py +++ b/tests/test_tutorial/test_options/test_password/test_tutorial001.py @@ -44,7 +44,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 output_without_double_spaces = strip_double_spaces(result.output) - assert "--email TEXT [required]" in output_without_double_spaces + assert "--email [required]" in output_without_double_spaces def test_script(mod: ModuleType): diff --git a/tests/test_tutorial/test_options/test_password/test_tutorial002.py b/tests/test_tutorial/test_options/test_password/test_tutorial002.py index 08d43ff87e..4b63dbc20c 100644 --- a/tests/test_tutorial/test_options/test_password/test_tutorial002.py +++ b/tests/test_tutorial/test_options/test_password/test_tutorial002.py @@ -48,7 +48,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 output_without_double_spaces = strip_double_spaces(result.output) - assert "--password TEXT [required]" in output_without_double_spaces + assert "--password [required]" in output_without_double_spaces def test_script(mod: ModuleType): diff --git a/tests/test_tutorial/test_options/test_prompt/test_tutorial001.py b/tests/test_tutorial/test_options/test_prompt/test_tutorial001.py index 1236f33e57..c0d1e5770c 100644 --- a/tests/test_tutorial/test_options/test_prompt/test_tutorial001.py +++ b/tests/test_tutorial/test_options/test_prompt/test_tutorial001.py @@ -39,7 +39,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--lastname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_options/test_prompt/test_tutorial002.py b/tests/test_tutorial/test_options/test_prompt/test_tutorial002.py index 0947a5e77e..42def22bd6 100644 --- a/tests/test_tutorial/test_options/test_prompt/test_tutorial002.py +++ b/tests/test_tutorial/test_options/test_prompt/test_tutorial002.py @@ -39,7 +39,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--lastname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_options/test_prompt/test_tutorial003.py b/tests/test_tutorial/test_options/test_prompt/test_tutorial003.py index 6016dc4acf..daefde09fa 100644 --- a/tests/test_tutorial/test_options/test_prompt/test_tutorial003.py +++ b/tests/test_tutorial/test_options/test_prompt/test_tutorial003.py @@ -48,7 +48,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--project-name" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_options/test_required/test_tutorial001_tutorial002.py b/tests/test_tutorial/test_options/test_required/test_tutorial001_tutorial002.py index 72ad6e04cd..5003c80ae3 100644 --- a/tests/test_tutorial/test_options/test_required/test_tutorial001_tutorial002.py +++ b/tests/test_tutorial/test_options/test_required/test_tutorial001_tutorial002.py @@ -41,7 +41,7 @@ def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--lastname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[required]" in result.output @@ -50,7 +50,7 @@ def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--lastname" in result.output - assert "TEXT" in result.output + assert "" in result.output assert "[required]" in result.output diff --git a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py index 32c07214c0..6941a35c37 100644 --- a/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_bool/test_tutorial001.py @@ -4,7 +4,6 @@ from types import ModuleType import pytest -import typer from typer.testing import CliRunner runner = CliRunner() @@ -23,12 +22,6 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: return mod -def test_type_repr(mod: ModuleType): - command = typer.main.get_command(mod.app) - force_param = next(param for param in command.params if param.name == "force") - assert repr(force_param.type) == "BOOL" - - def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 diff --git a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py index eea63c5a8b..e4f2130342 100644 --- a/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_datetime/test_tutorial001.py @@ -2,7 +2,6 @@ import sys from datetime import datetime -import typer from typer.testing import CliRunner from docs_src.parameter_types.datetime import tutorial001_py310 as mod @@ -11,16 +10,10 @@ app = mod.app -def test_type_repr(): - command = typer.main.get_command(app) - birth_param = next(param for param in command.params if param.name == "birth") - assert repr(birth_param.type) == "DateTime" - - def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 - assert "[%Y-%m-%d|%Y-%m-%dT%H:%M:%S|%Y-%m-%d %H:%M:%S]" in result.output + assert "<%Y-%m-%d>" in result.output def test_main(): @@ -42,14 +35,8 @@ def test_main_datetime_object(): def test_invalid(): result = runner.invoke(app, ["july-19-1989"]) assert result.exit_code != 0 - assert ( - "Invalid value for 'BIRTH:[%Y-%m-%d|%Y-%m-%dT%H:%M:%S|%Y-%m-%d %H:%M:%S]':" - in result.output - ) - assert "'july-19-1989' does not match the formats" in result.output - assert "%Y-%m-%d" in result.output - assert "%Y-%m-%dT%H:%M:%S" in result.output - assert "%Y-%m-%d %H:%M:%S" in result.output + assert "Invalid value for 'birth'" in result.output + assert "should be a valid datetime" in result.output def test_script(): diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py index 6ec636d7ec..a27d32f1bb 100644 --- a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py @@ -13,7 +13,7 @@ def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 assert "--network" in result.output - assert "[simple|conv|lstm]" in result.output + assert "" in result.output assert "default: simple" in result.output diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial003.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial003.py index d6c0e532c9..7214a522df 100644 --- a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial003.py +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial003.py @@ -23,10 +23,11 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): + mod.app.rich_markup_mode = None result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--groceries" in result.output - assert "[Eggs|Bacon|Cheese]" in result.output + assert "" in result.output assert "default: Eggs, Cheese" in result.output diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial004.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial004.py index 84f0eb3b16..0791fab333 100644 --- a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial004.py +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial004.py @@ -25,7 +25,7 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 - assert "--network [simple|conv|lstm]" in result.output.replace(" ", "") + assert "--network " in result.output.replace(" ", "") def test_main(mod): diff --git a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py index ee0daa9f06..9832e47045 100644 --- a/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_index/test_tutorial001.py @@ -1,7 +1,6 @@ import subprocess import sys -import typer from typer.testing import CliRunner from docs_src.parameter_types.index import tutorial001_py310 as mod @@ -10,23 +9,13 @@ app = mod.app -def test_type_repr(): - command = typer.main.get_command(app) - age_param = next(param for param in command.params if param.name == "age") - height_meters_param = next( - param for param in command.params if param.name == "height_meters" - ) - assert repr(age_param.type) == "INT" - assert repr(height_meters_param.type) == "FLOAT" - - def test_help(): result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 assert "--age" in result.output - assert "INTEGER" in result.output + assert "" in result.output assert "--height-meters" in result.output - assert "FLOAT" in result.output + assert "" in result.output def test_params(): @@ -44,7 +33,7 @@ def test_invalid(): result = runner.invoke(app, ["Camila", "--age", "15.3"]) assert result.exit_code != 0 assert "Invalid value for '--age'" in result.output - assert "'15.3' is not a valid integer" in result.output + assert "Input should be a valid integer" in result.output def test_script(): diff --git a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py index ffe3ad09a0..3f6239592d 100644 --- a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial001.py @@ -24,26 +24,13 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: return mod -def test_type_repr(mod: ModuleType): - command = typer.main.get_command(mod.app) - - id_param = next(param for param in command.params if param.name == "id") - assert repr(id_param.type) == "" - - age_param = next(param for param in command.params if param.name == "age") - assert repr(age_param.type) == "=18>" - - score_param = next(param for param in command.params if param.name == "score") - assert repr(score_param.type) == "" - - def test_help(mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--age" in result.output - assert "INTEGER RANGE" in result.output + assert "int range" in result.output assert "--score" in result.output - assert "FLOAT RANGE" in result.output + assert "float range" in result.output def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): @@ -51,9 +38,9 @@ def test_help_no_rich(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): result = runner.invoke(mod.app, ["--help"]) assert result.exit_code == 0 assert "--age" in result.output - assert "INTEGER RANGE" in result.output + assert "int range" in result.output assert "--score" in result.output - assert "FLOAT RANGE" in result.output + assert "float range" in result.output def test_params(mod: ModuleType): @@ -67,16 +54,15 @@ def test_params(mod: ModuleType): def test_invalid_id(mod: ModuleType): result = runner.invoke(mod.app, ["1002"]) assert result.exit_code != 0 - assert ( - "Invalid value for 'ID': 1002 is not in the range 0<=x<=1000." in result.output - ) + assert "Invalid value for 'ID'" in result.output + assert "should be less than or equal to 1000" in result.output def test_invalid_age(mod: ModuleType): result = runner.invoke(mod.app, ["5", "--age", "15"]) assert result.exit_code != 0 assert "Invalid value for '--age'" in result.output - assert "15 is not in the range x>=18" in result.output + assert "should be greater than or equal to 18" in result.output def test_invalid_score(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): @@ -84,7 +70,7 @@ def test_invalid_score(monkeypatch: pytest.MonkeyPatch, mod: ModuleType): result = runner.invoke(mod.app, ["5", "--age", "20", "--score", "100.5"]) assert result.exit_code != 0 assert "Invalid value for '--score'" in result.output - assert "100.5 is not in the range x<=100." in result.output + assert "should be less than or equal to 100" in result.output def test_negative_score(mod: ModuleType): diff --git a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial002.py b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial002.py index 48162c763e..acdc0994fe 100644 --- a/tests/test_tutorial/test_parameter_types/test_number/test_tutorial002.py +++ b/tests/test_tutorial/test_parameter_types/test_number/test_tutorial002.py @@ -25,9 +25,8 @@ def get_mod(request: pytest.FixtureRequest) -> ModuleType: def test_invalid_id(mod: ModuleType): result = runner.invoke(mod.app, ["1002"]) assert result.exit_code != 0 - assert ( - "Invalid value for 'ID': 1002 is not in the range 0<=x<=1000" in result.output - ) + assert "Invalid value for 'ID'" in result.output + assert "should be less than or equal to 1000" in result.output def test_clamped(mod: ModuleType): diff --git a/tests/test_tutorial/test_parameter_types/test_path/test_tutorial002.py b/tests/test_tutorial/test_parameter_types/test_path/test_tutorial002.py index ff248d5822..9c9cde7501 100644 --- a/tests/test_tutorial/test_parameter_types/test_path/test_tutorial002.py +++ b/tests/test_tutorial/test_parameter_types/test_path/test_tutorial002.py @@ -30,7 +30,6 @@ def test_not_exists(tmpdir, mod: ModuleType): result = runner.invoke(mod.app, ["--config", f"{config_file}"]) assert result.exit_code != 0 assert "Invalid value for '--config'" in result.output - assert "File" in result.output assert "does not exist" in result.output @@ -47,7 +46,7 @@ def test_dir(mod: ModuleType): result = runner.invoke(mod.app, ["--config", "./"]) assert result.exit_code != 0 assert "Invalid value for '--config'" in result.output - assert "File './' is a directory." in result.output + assert "file './' is a directory." in result.output def test_script(mod: ModuleType): diff --git a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py index 7b79e81405..11678e1b7b 100644 --- a/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_uuid/test_tutorial001.py @@ -2,7 +2,6 @@ import sys import uuid -import typer from typer.testing import CliRunner from docs_src.parameter_types.uuid import tutorial001_py310 as mod @@ -11,12 +10,6 @@ app = mod.app -def test_type_repr(): - command = typer.main.get_command(app) - user_id_param = next(param for param in command.params if param.name == "user_id") - assert repr(user_id_param.type) == "UUID" - - def test_main(): result = runner.invoke(app, ["d48edaa6-871a-4082-a196-4daab372d4a1"]) assert result.exit_code == 0 @@ -35,10 +28,8 @@ def test_main_with_uuid_object(): def test_invalid_uuid(): result = runner.invoke(app, ["7479706572-72756c6573"]) assert result.exit_code != 0 - assert ( - "Invalid value for 'USER_ID': '7479706572-72756c6573' is not a valid UUID" - in result.output - ) + assert "Invalid value for 'user_id'" in result.output + assert "should be a valid UUID" in result.output def test_script(): diff --git a/tests/test_tutorial/test_typer_app/test_tutorial001.py b/tests/test_tutorial/test_typer_app/test_tutorial001.py index a951fe1881..175841b304 100644 --- a/tests/test_tutorial/test_typer_app/test_tutorial001.py +++ b/tests/test_tutorial/test_typer_app/test_tutorial001.py @@ -13,7 +13,7 @@ def test_no_arg(): result = runner.invoke(app) assert result.exit_code != 0 - assert "Missing argument 'NAME'." in result.output + assert "Missing argument 'name'." in result.output def test_arg(): diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 26709db0e2..b1f990dbeb 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,11 +1,13 @@ import os +import sys +from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any +from typing import Annotated, Any, Literal, get_args, get_origin import pytest import typer -from typer import _click, models +from typer import _click, param_types from typer.testing import CliRunner from tests.utils import needs_linux, needs_windows @@ -136,8 +138,9 @@ def tuple_recursive_conversion(container: type_annotation): assert result.exit_code == 0 -def test_tuple_wrong_arity(): +def test_tuple_wrong_arity(monkeypatch): app = typer.Typer() + monkeypatch.setenv("COLUMNS", "200") @app.command() def tuple_arity(value: tuple[str, str] = typer.Option(...)): @@ -161,8 +164,9 @@ def custom_parser( assert result.exit_code == 0 -def test_custom_parse_value_error(): +def test_custom_parse_value_error(monkeypatch): app = typer.Typer() + monkeypatch.setenv("COLUMNS", "200") @app.command() def custom_parser( @@ -173,25 +177,15 @@ def custom_parser( result = runner.invoke(app, ["not-a-hex"]) assert result.exit_code == 2 assert "Invalid value" in result.output + assert "invalid literal for int()" in result.output -def test_custom_click_type(): - class BaseNumberParamType(_click.types.ParamType): - name = "base_integer" - - def convert( - self, - value: Any, - param: _click.Parameter | None, - ctx: _click.Context | None, - ) -> Any: - return int(value, 0) - +def test_custom_parser_hex(): app = typer.Typer() @app.command() - def custom_click_type( - hex_value: int = typer.Argument(None, click_type=BaseNumberParamType()), + def custom_parser_hex( + hex_value: int = typer.Argument(None, parser=lambda x: int(x, 0)), ): assert hex_value == 0x56 @@ -199,21 +193,38 @@ def custom_click_type( assert result.exit_code == 0 -def test_int_range_open_bound_clamp(): +@pytest.mark.parametrize( + ("cli_value", "expected"), + [ + ("true", True), + ("false", False), + ("yes", True), + ("no", False), + ("1", True), + ("0", False), + ("on", True), + ("off", False), + ("t", True), + ("f", False), + ("y", True), + ("n", False), + ("", False), + (" true ", True), + (" FALSE ", False), + ("TRUE", True), + ("No", False), + ], +) +def test_bool_convert_valid(cli_value: str, expected: bool) -> None: app = typer.Typer() @app.command() - def custom_click_type( - value: int = typer.Argument( - ..., - click_type=_click.types.IntRange(min=1, min_open=True, clamp=True), - ), - ): + def main(value: bool): print(value) - result = runner.invoke(app, ["1"]) + result = runner.invoke(app, [cli_value]) assert result.exit_code == 0 - assert "2" in result.output + assert str(expected) in result.output def test_bool_convert_invalid(): @@ -225,9 +236,7 @@ def main(value: bool): result = runner.invoke(app, ["maybe"]) assert result.exit_code == 2 - assert "is not a valid boolean" in result.output - assert "yes" in result.output - assert "false" in result.output + assert "Input should be a valid boolean" in result.output @pytest.mark.parametrize( @@ -252,12 +261,8 @@ def test_string_param_type_converts_bytes( def show(name: str = typer.Option(...)): print(name) - command = typer.main.get_command(app) - name_param = next(param for param in command.params if param.name == "name") - assert repr(name_param.type) == "STRING" - - monkeypatch.setattr(_click.types, "_get_argv_encoding", lambda: arg_enc) - monkeypatch.setattr(_click.types.sys, "getfilesystemencoding", lambda: system_enc) + monkeypatch.setattr(_click._compat, "_get_argv_encoding", lambda: arg_enc) + monkeypatch.setattr(sys, "getfilesystemencoding", lambda: system_enc) result = runner.invoke(app, [], default_map={"name": raw_value}) assert result.exit_code == 0 @@ -270,7 +275,7 @@ def test_path_coerced(path_type) -> None: app = typer.Typer() @app.command() - def show(path: Any = typer.Option(..., path_type=path_type)): + def show(path: Path = typer.Option(..., path_type=path_type)): print(path) result = runner.invoke(app, ["--path", "dir/my_awesome_file.txt"]) @@ -309,7 +314,7 @@ def fake_access(path: str, mode: int) -> bool: return False return original_access(path, mode) # pragma: no cover - monkeypatch.setattr(models.os, "access", fake_access) + monkeypatch.setattr(param_types.os, "access", fake_access) path = tmp_path / "some_path" if create_file: @@ -322,50 +327,120 @@ def fake_access(path: str, mode: int) -> bool: assert expected_error in result.output -def test_convert_type(): - from typer._click.types import convert_type +@pytest.mark.parametrize( + ("default", "expected_annotation"), + [ + (42, int), + (0.5, float), + ("morty", str), + (False, bool), + ("False", str), + ((1, "x"), tuple[int, str]), + ], +) +def test_default_infers_param_type( + default: Any, + expected_annotation: Any, +) -> None: + app = typer.Typer() + seen: dict[str, Any] = {} - # str - assert convert_type(str) is _click.types.STRING - assert convert_type(None) is _click.types.STRING - assert convert_type(None, default=["a"]) is _click.types.STRING + @app.command() + def cmd(val=default): + seen["val"] = val + + param = next(p for p in typer.main.get_command(app).params if p.name == "val") + assert param.runtime_param is not None + if get_origin(expected_annotation) is tuple: + assert get_origin(param.runtime_param.annotation) is tuple + assert get_args(param.runtime_param.annotation) == get_args(expected_annotation) + else: + assert param.runtime_param.annotation is expected_annotation - # tuples - tuple_type = convert_type((str, int)) - assert isinstance(tuple_type, _click.types.Tuple) - assert [type(item) for item in tuple_type.types] == [ - type(_click.types.STRING), - type(_click.types.INT), - ] + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert seen["val"] == default + if expected_annotation in (int, float, bool, str): + assert type(seen["val"]) is expected_annotation + elif get_origin(expected_annotation) is tuple: + assert isinstance(seen["val"], tuple) - guessed_tuple = convert_type(None, default=[(1, "x")]) - assert isinstance(guessed_tuple, _click.types.Tuple) - assert [type(item) for item in guessed_tuple.types] == [ - type(_click.types.INT), - type(_click.types.STRING), - ] - # numbers - assert convert_type(int) is _click.types.INT - assert convert_type(float) is _click.types.FLOAT - assert convert_type(bool) is _click.types.BOOL +@pytest.mark.parametrize( + ("parameter", "expected_metavar"), + [ + pytest.param(Annotated[str, typer.Option(...)], ""), + pytest.param(Annotated[str, typer.Argument(...)], ""), + pytest.param(Annotated[int, typer.Option(...)], ""), + pytest.param(Annotated[int, typer.Argument(...)], ""), + pytest.param(Annotated[float, typer.Option(...)], ""), + pytest.param( + Annotated[float, typer.Option(..., min=0.666, max=3.42)], "" + ), + pytest.param(Annotated[bytes, typer.Option(...)], ""), + pytest.param(Annotated[list[str], typer.Option(...)], ""), + pytest.param(Annotated[tuple[str, int], typer.Option(...)], "..."), + pytest.param(Annotated[tuple[Path, str], typer.Option(...)], "..."), + pytest.param(Annotated[str, typer.Option(..., resolve_path=True)], ""), + pytest.param(Annotated[Path, typer.Option(...)], ""), + pytest.param(Annotated[Path, typer.Option(..., dir_okay=False)], ""), + pytest.param(Annotated[Path, typer.Option(..., file_okay=False)], ""), + pytest.param(Annotated[SomeEnum, typer.Option(...)], ""), + pytest.param(Annotated[SomeEnum, typer.Argument()], ""), + pytest.param( + Annotated[SomeEnum, typer.Option(..., show_choices=False)], "" + ), + pytest.param( + Annotated[list[SomeEnum], typer.Option(...)], "" + ), + pytest.param( + Annotated[list[SomeEnum], typer.Option(..., show_choices=False)], + "", + ), + pytest.param(Annotated[Literal["x", "y"], typer.Option(...)], ""), + pytest.param(Annotated[typer.FileText, typer.Option(...)], ""), + pytest.param(Annotated[datetime, typer.Option(...)], "<%Y-%m-%d>"), + pytest.param( + Annotated[datetime, typer.Option(..., formats=["%Y-%m-%d", "%d/%m/%Y"])], + "<%Y-%m-%d|%d/%m/%Y>", + ), + ], +) +def test_param_type_help_metavar(parameter: Any, expected_metavar: str) -> None: + app = typer.Typer() - param_type = _click.types.IntRange(min=0, max=10) - assert convert_type(param_type) is param_type + @app.command() + # TODO: type-specific default + def with_default(value: parameter = "my_default"): + pass # pragma: no cover - guessed_int = convert_type(None, default=42) - assert guessed_int is _click.types.INT + @app.command() + def without_default(value: parameter): + pass # pragma: no cover - # custom type - class CustomType: - pass + result = runner.invoke(app, ["with-default", "--help"]) + assert result.exit_code == 0 + assert expected_metavar in result.output - guessed_unknown = convert_type(None, default=CustomType()) - assert guessed_unknown is _click.types.STRING + result = runner.invoke(app, ["without-default", "--help"]) + assert result.exit_code == 0 + assert expected_metavar in result.output + + +def test_int_rejects_float_default() -> None: + app = typer.Typer() - func_type = convert_type(CustomType) - assert isinstance(func_type, _click.types.FuncParamType) - assert func_type.name == "CustomType" + @app.command() + def main(age: int = typer.Option(15.3)): + typer.echo(age) + + result = runner.invoke(app, ["--age", 42]) + assert "42" in result.stdout + + # Pydantic validation rejects floats as int instead of converting int(15.3) to 15 + result = runner.invoke(app) + assert result.exit_code != 0 + assert "Input should be a valid integer" in result.stderr @pytest.mark.parametrize( @@ -382,19 +457,21 @@ def test_argv_encoding( stdin_encoding: str | None, filesystem_encoding: str, ) -> None: - sys = _click._compat.sys + app = typer.Typer() + + @app.command() + def show(name: str = typer.Option(...)): + print(name) + if platform_case == "windows": import locale monkeypatch.setattr(locale, "getpreferredencoding", lambda: "latin-1") else: - - class FakeStdin: - def __init__(self, encoding: str | None) -> None: - self.encoding = encoding - - monkeypatch.setattr(sys, "stdin", FakeStdin(stdin_encoding)) + argv_encoding = stdin_encoding or filesystem_encoding + monkeypatch.setattr(_click._compat, "_get_argv_encoding", lambda: argv_encoding) monkeypatch.setattr(sys, "getfilesystemencoding", lambda: filesystem_encoding) - converted = _click.types.STRING.convert(b"\xff", None, None) - assert converted == "ÿ" + result = runner.invoke(app, [], default_map={"name": b"\xff"}) + assert result.exit_code == 0 + assert "ÿ" in result.output diff --git a/tests/test_types.py b/tests/test_types.py index db6dae08da..d60907a79f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,8 +1,6 @@ from enum import Enum -import pytest import typer -from typer import _click from typer.testing import CliRunner app = typer.Typer(context_settings={"token_normalize_func": str.lower}) @@ -83,21 +81,14 @@ def test_enum_choice() -> None: assert "Hello Rick!" in result.output -def test_enum_choice_repr() -> None: - root_command = typer.main.get_command(app) - command = root_command.commands["hello-option"] - name_param = next(param for param in command.params if param.name == "name") - assert repr(name_param.type).startswith("Choice([") - - def test_enum_choice_help() -> None: result = runner.invoke(app, ["hello-argument", "--help"]) assert result.exit_code == 0 - assert "{rick|morty}" in result.output + assert "" in result.output result = runner.invoke(app, ["hello-option", "--help"]) assert result.exit_code == 0 - assert "[rick|morty]" in result.output + assert "" in result.output result = runner.invoke(app, ["hello-no-choices", "--help"]) assert result.exit_code == 0 @@ -158,8 +149,3 @@ def test_list_pair() -> None: assert result.exit_code == 0 assert "items=['a', 'b', 'c']" in result.output assert "pair=('x', 'y')" in result.output - - -def test_float_range_open_bounds_with_clamp_not_allowed(): - with pytest.raises(TypeError, match="Clamping is not supported for open bounds."): - _click.types.FloatRange(min=0.0, min_open=True, clamp=True) diff --git a/typer/_click/core.py b/typer/_click/core.py index 580b558b9f..afff00f5fe 100644 --- a/typer/_click/core.py +++ b/typer/_click/core.py @@ -8,6 +8,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Literal, NoReturn, TypeVar, @@ -16,12 +17,10 @@ overload, ) -from . import types from .exceptions import ( Abort, BadParameter, Exit, - MissingParameter, NoArgsIsHelpError, UsageError, ) @@ -32,7 +31,7 @@ from .utils import echo, make_default_short_help if TYPE_CHECKING: - from ..core import TyperOption + from ..core import TyperOption, TyperParameter from .shell_completion import CompletionItem F = TypeVar("F", bound="Callable[..., Any]") @@ -60,7 +59,7 @@ def _complete_visible_commands( @contextmanager def augment_usage_errors( - ctx: "Context", param: Union["Parameter", None] = None + ctx: "Context", param: Union["TyperParameter", None] = None ) -> Iterator[None]: """Context manager that attaches extra information to exceptions.""" try: @@ -819,7 +818,6 @@ class Parameter(ABC): def __init__( self, param_decls: Sequence[str] | None = None, - type: types.ParamType | Any | None = None, required: bool = False, default: Any | Callable[[], Any] | None = None, callback: Callable[[Context, "Parameter", Any], Any] | None = None, @@ -840,15 +838,11 @@ def __init__( self.name, self.opts, self.secondary_opts = self._parse_decls( param_decls or (), expose_value ) - self.type: types.ParamType = types.convert_type(type, default) # Default nargs to what the type tells us if we have that # information available. if nargs is None: - if self.type.is_composite: - nargs = self.type.arity - else: - nargs = 1 + nargs = 1 self.required = required self.callback = callback @@ -870,28 +864,6 @@ def _parse_decls( ) -> tuple[str | None, list[str], list[str]]: pass # pragma: no cover - @property - def human_readable_name(self) -> str: - """Returns the human readable name of this parameter. This is the - same as the name for options, but the metavar for arguments. - """ - assert self.name is not None, "self.name should be set" - return self.name - - def make_metavar(self, ctx: Context) -> str: - if self.metavar is not None: - return self.metavar - - metavar = self.type.get_metavar(param=self, ctx=ctx) - - if metavar is None: - metavar = self.type.name.upper() - - if self.nargs != 1: - metavar += "..." - - return metavar - @overload def get_default(self, ctx: Context, call: Literal[True] = True) -> Any | None: ... @@ -938,65 +910,14 @@ def consume_value( return value, source - def type_cast_value(self, ctx: Context, value: Any) -> Any: - """Convert and validate a value against the parameter's - `type`, `multiple`, and `nargs`. - """ - if value is None: - return () if self.multiple or self.nargs == -1 else None - - def check_iter(value: Any) -> Iterator[Any]: - if isinstance(value, str): - raise BadParameter("Value must be an iterable.", ctx=ctx, param=self) - else: - return iter(value) - - # Define the conversion function based on nargs and type. - if self.nargs == 1 or self.type.is_composite: - - def convert(value: Any) -> Any: - return self.type(value, param=self, ctx=ctx) - - elif self.nargs == -1: - - def convert(value: Any) -> Any: # tuple[t.Any, ...] - return tuple(self.type(x, self, ctx) for x in check_iter(value)) - - # TODO: evaluate whether we need to keep this in Typer - else: # nargs > 1 - - def convert(value: Any) -> Any: # tuple[t.Any, ...] - value = tuple(check_iter(value)) - - if len(value) != self.nargs: - raise BadParameter( - f"Takes {self.nargs} values but {len(value)} given.", - ctx=ctx, - param=self, - ) - - return tuple(self.type(x, self, ctx) for x in value) - - if self.multiple: - return tuple(convert(x) for x in check_iter(value)) - - return convert(value) - @abstractmethod def value_is_missing(self, value: Any) -> bool: pass # pragma: no cover + @abstractmethod def process_value(self, ctx: Context, value: Any) -> Any: - """Process the value of this parameter""" - value = self.type_cast_value(ctx, value) - - if self.required and self.value_is_missing(value): - raise MissingParameter(ctx=ctx, param=self) - - if self.callback is not None: - value = self.callback(ctx, self, value) - - return value + """Process the value of this parameter.""" + pass # pragma: no cover def resolve_envvar_value(self, ctx: Context) -> str | None: """Returns the value found in the environment variable(s) attached to this @@ -1035,17 +956,28 @@ def resolve_envvar_value(self, ctx: Context) -> str | None: return None + envvar_list_splitter: ClassVar[str | None] = None + + def split_envvar_value(self, rv: str) -> Sequence[str]: + """Given a value from an environment variable this splits it up + into small chunks depending on the defined envvar list splitter. + + If the splitter is set to `None`, which means that whitespace splits, + then leading and trailing whitespace is ignored. Otherwise, leading + and trailing splitters usually lead to empty items being included. + """ + return (rv or "").split(self.envvar_list_splitter) + def value_from_envvar(self, ctx: Context) -> str | Sequence[str] | None: - """Process the raw environment variable string for this parameter. + """Process the value from the environment variable. Returns the string as-is or splits it into a sequence of strings if the - parameter is expecting multiple values (i.e. its `nargs` property is set - to a value other than ``1``). + parameter is expecting multiple values. """ rv: Any | None = self.resolve_envvar_value(ctx) - if rv is not None and self.nargs != 1: - rv = self.type.split_envvar_value(rv) + if rv is not None and (self.nargs != 1 or self.multiple): + rv = self.split_envvar_value(rv) return rv @@ -1061,7 +993,9 @@ def handle_parse_result( the value has been explicitly set by the user (and as such, is not coming from a default). """ - with augment_usage_errors(ctx, param=self): + from ..core import TyperParameter + + with augment_usage_errors(ctx, param=cast(TyperParameter, self)): value, source = self.consume_value(ctx, opts) ctx.set_parameter_source(self.name, source) # type: ignore @@ -1086,17 +1020,9 @@ def get_help_record(self, ctx: Context) -> tuple[str, str] | None: def get_usage_pieces(self, ctx: Context) -> list[str]: return [] - def get_error_hint(self, ctx: Context) -> str: - """Get a stringified version of the param for use in error messages to - indicate which param caused the error. - """ - hint_list = self.opts or [self.human_readable_name] - return " / ".join(f"'{x}'" for x in hint_list) - def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem"]: """Return a list of completions for the incomplete value. If a ``shell_complete`` function was given during init, it is used. - Otherwise, the `type` `ParamType.shell_complete` function is used. """ if self._custom_shell_complete is not None: results = self._custom_shell_complete(ctx, self, incomplete) @@ -1108,4 +1034,9 @@ def shell_complete(self, ctx: Context, incomplete: str) -> list["CompletionItem" return cast("list[CompletionItem]", results) - return self.type.shell_complete(ctx, self, incomplete) + # All Parameter objects will in fact be TyperParameter objects + # This will be cleaned up in later iterations + from ..core import TyperParameter + + param = cast(TyperParameter, self) + return param.shell_complete(ctx, incomplete) diff --git a/typer/_click/decorators.py b/typer/_click/decorators.py index 28ad656a8c..af3fbded3d 100644 --- a/typer/_click/decorators.py +++ b/typer/_click/decorators.py @@ -40,6 +40,7 @@ def decorator(f: Command) -> Command: def help_option(param_decls: list[str]) -> Callable[[Command], Command]: """Help option which prints the help page and exits the program.""" + from ..coercion import bool_flag_runtime_param, bool_flag_type_descriptor def show_help(ctx: Context, param: Parameter, value: bool) -> None: """Callback that print the help page on ```` and exits.""" @@ -57,4 +58,6 @@ def show_help(ctx: Context, param: Parameter, value: bool) -> None: help="Show this message and exit.", callback=show_help, required=False, + runtime_param=bool_flag_runtime_param(), + type_descriptor=bool_flag_type_descriptor(), ) diff --git a/typer/_click/exceptions.py b/typer/_click/exceptions.py index 75ba2296bb..37fc57b3ca 100644 --- a/typer/_click/exceptions.py +++ b/typer/_click/exceptions.py @@ -6,7 +6,8 @@ from .utils import echo, format_filename if TYPE_CHECKING: - from .core import Command, Context, Parameter + from ..core import TyperParameter + from .core import Command, Context def _join_param_hints(param_hint: Sequence[str] | str | None) -> str | None: @@ -90,7 +91,7 @@ def __init__( self, message: str, ctx: Union["Context", None] = None, - param: Union["Parameter", None] = None, + param: Union["TyperParameter", None] = None, param_hint: Sequence[str] | str | None = None, ) -> None: super().__init__(message, ctx) @@ -101,7 +102,7 @@ def format_message(self) -> str: if self.param_hint is not None: param_hint = self.param_hint elif self.param is not None: - param_hint = self.param.get_error_hint(self.ctx) # type: ignore + param_hint = self.param.get_error_hint() else: return f"Invalid value: {self.message}" @@ -118,7 +119,7 @@ def __init__( self, message: str | None = None, ctx: Union["Context", None] = None, - param: Union["Parameter", None] = None, + param: Union["TyperParameter", None] = None, param_hint: Sequence[str] | str | None = None, param_type: str | None = None, ) -> None: @@ -129,7 +130,7 @@ def format_message(self) -> str: if self.param_hint is not None: param_hint: Sequence[str] | str | None = self.param_hint elif self.param is not None: - param_hint = self.param.get_error_hint(self.ctx) # type: ignore + param_hint = self.param.get_error_hint() else: param_hint = None @@ -142,28 +143,20 @@ def format_message(self) -> str: msg = self.message if self.param is not None: - msg_extra = self.param.type.get_missing_message( - param=self.param, ctx=self.ctx - ) + from ..core import TyperParameter + + assert isinstance(self.param, TyperParameter) + msg_extra = self.param.get_missing_message(ctx=self.ctx) if msg_extra: if msg: msg += f". {msg_extra}" else: msg = msg_extra - msg = f" {msg}" if msg else "" - - # Translate param_type for known types. - if param_type == "argument": - missing = "Missing argument" - elif param_type == "option": - missing = "Missing option" - elif param_type == "parameter": - missing = "Missing parameter" - else: - missing = f"Missing {param_type}" + if msg: + return f"Missing {param_type}{param_hint}. {msg}" - return f"{missing}{param_hint}.{msg}" + return f"Missing {param_type}{param_hint}." def __str__(self) -> str: if not self.message: diff --git a/typer/_click/termui.py b/typer/_click/termui.py index 0a8c82574d..1e858976ec 100644 --- a/typer/_click/termui.py +++ b/typer/_click/termui.py @@ -5,7 +5,6 @@ from .exceptions import Abort, UsageError from .globals import resolve_color_default -from .types import ParamType, convert_type from .utils import LazyFile, echo if TYPE_CHECKING: @@ -51,14 +50,18 @@ def _build_prompt( show_default: bool = False, default: Any | None = None, show_choices: bool = True, - type: ParamType | None = None, + annotation: Any | None = None, ) -> str: # prevent circular imports - from .._types import TyperChoice + from ..models import OptionInfo + from ..param_types import choice_as_str, choice_coercion_annotation prompt = text - if type is not None and show_choices and isinstance(type, TyperChoice): - prompt += f" ({', '.join(map(str, type.choices))})" + if show_choices and annotation is not None: + choice = choice_coercion_annotation(annotation, OptionInfo()) + if choice is not None: + choices, _ = choice + prompt += f" ({', '.join(map(choice_as_str, choices))})" if default is not None and show_default: prompt = f"{prompt} [{_format_default(default)}]" return f"{prompt}{suffix}" @@ -76,7 +79,7 @@ def prompt( default: Any | None = None, hide_input: bool = False, confirmation_prompt: bool | str = False, - type: ParamType | Any | None = None, + type: Any | None = None, value_proc: Callable[[str], Any] | None = None, prompt_suffix: str = ": ", show_default: bool = True, @@ -107,18 +110,26 @@ def prompt_func(text: str) -> str: echo(None, err=err) raise Abort() from None + from ..param_types import annotation_from_prompt + + annotation = annotation_from_prompt(type, default) if value_proc is None: - value_proc = convert_type(type, default) + from ..coercion import prompt_value_proc + from ..param_types import annotation_from_prompt + + value_proc = prompt_value_proc(type, default) prompt = _build_prompt( - text, prompt_suffix, show_default, default, show_choices, type + text, prompt_suffix, show_default, default, show_choices, annotation ) if confirmation_prompt: if confirmation_prompt is True: confirmation_prompt = "Repeat for confirmation" - confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix) + confirmation_prompt = _build_prompt( + confirmation_prompt, prompt_suffix, annotation=annotation + ) while True: while True: diff --git a/typer/_click/types.py b/typer/_click/types.py deleted file mode 100644 index 5ccf15fe1b..0000000000 --- a/typer/_click/types.py +++ /dev/null @@ -1,695 +0,0 @@ -import os -import sys -from collections.abc import Callable, Sequence -from datetime import datetime -from typing import ( - IO, - TYPE_CHECKING, - Any, - ClassVar, - Literal, - NoReturn, - TypedDict, - TypeGuard, - TypeVar, - Union, - cast, -) - -from ._compat import _get_argv_encoding, open_stream -from .exceptions import BadParameter -from .utils import LazyFile, format_filename, safecall - -if TYPE_CHECKING: - from .core import Context, Parameter - from .shell_completion import CompletionItem - -ParamTypeValue = TypeVar("ParamTypeValue") - - -class ParamType: - """Represents the type of a parameter. Validates and converts values - from the command line or Python into the correct type. - - To implement a custom type, subclass and implement at least the - following: - - - The `name` class attribute must be set. - - Calling an instance of the type with ``None`` must return - ``None``. This is already implemented by default. - - `convert` must convert string values to the correct type. - - `convert` must accept values that are already the correct - type. - - It must be able to convert a value if the ``ctx`` and ``param`` - arguments are ``None``. This can occur when converting prompt - input. - """ - - is_composite: ClassVar[bool] = False - arity: ClassVar[int] = 1 - name: str - - # if a list of this type is expected and the value is pulled from a - # string environment variable, this is what splits it up. `None` - # means any whitespace. For all parameters the general rule is that - # whitespace splits them up. The exception are paths and files which - # are split by ``os.path.pathsep`` by default (":" on Unix and ";" on - # Windows). - envvar_list_splitter: ClassVar[str | None] = None - - def __call__( - self, - value: Any, - param: Union["Parameter", None] = None, - ctx: Union["Context", None] = None, - ) -> Any: - if value is not None: - return self.convert(value, param, ctx) - - def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None: - """Returns the metavar default for this param if it provides one.""" - pass # pragma: no cover - - def get_missing_message( - self, param: "Parameter", ctx: Union["Context", None] - ) -> str | None: - """Optionally might return extra information about a missing - parameter. - """ - pass # pragma: no cover - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - pass # pragma: no cover - - def split_envvar_value(self, rv: str) -> Sequence[str]: - """Given a value from an environment variable this splits it up - into small chunks depending on the defined envvar list splitter. - - If the splitter is set to `None`, which means that whitespace splits, - then leading and trailing whitespace is ignored. Otherwise, leading - and trailing splitters usually lead to empty items being included. - """ - return (rv or "").split(self.envvar_list_splitter) - - def fail( - self, - message: str, - param: Union["Parameter", None] = None, - ctx: Union["Context", None] = None, - ) -> NoReturn: - """Helper method to fail with an invalid value message.""" - raise BadParameter(message, ctx=ctx, param=param) - - def shell_complete( - self, ctx: "Context", param: "Parameter", incomplete: str - ) -> list["CompletionItem"]: - """Return a list of `CompletionItem` objects for the - incomplete value. Most types do not provide completions, but - some do, and this allows custom types to provide custom - completions as well. - """ - return [] - - -class CompositeParamType(ParamType): - is_composite = True - - @property - def arity(self) -> int: # type: ignore - raise NotImplementedError() # pragma: no cover - - -class FuncParamType(ParamType): - def __init__(self, func: Callable[[Any], Any]) -> None: - self.name: str = getattr(func, "__name__", "function") - self.func = func - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - try: - return self.func(value) - except ValueError: - try: - value = str(value) - except UnicodeError: # pragma: no cover - assert isinstance(value, bytes) - value = value.decode("utf-8", "replace") - - self.fail(value, param, ctx) - - -class StringParamType(ParamType): - name = "text" - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - if isinstance(value, bytes): - enc = _get_argv_encoding() - try: - value = value.decode(enc) - except UnicodeError: - fs_enc = sys.getfilesystemencoding() - if fs_enc != enc: - try: - value = value.decode(fs_enc) - except UnicodeError: - value = value.decode("utf-8", "replace") - else: - value = value.decode("utf-8", "replace") - return value - return str(value) - - def __repr__(self) -> str: - return "STRING" - - -class DateTime(ParamType): - """The DateTime type converts date strings into `datetime` objects. - - The format strings which are checked are configurable, but default to some - common (non-timezone aware) ISO 8601 formats. - - When specifying *DateTime* formats, you should only pass a list or a tuple. - Other iterables, like generators, may lead to surprising results. - - The format strings are processed using ``datetime.strptime``, and this - consequently defines the format strings which are allowed. - - Parsing is tried using each format, in order, and the first format which - parses successfully is used. - """ - - name = "datetime" - - def __init__(self, formats: Sequence[str] | None = None): - self.formats: Sequence[str] = formats or [ - "%Y-%m-%d", - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%d %H:%M:%S", - ] - - def get_metavar(self, param: "Parameter", ctx: "Context") -> str | None: - return f"[{'|'.join(self.formats)}]" - - def _try_to_convert_date(self, value: Any, format: str) -> datetime | None: - try: - return datetime.strptime(value, format) - except ValueError: - return None - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - if isinstance(value, datetime): - return value - - for format in self.formats: - converted = self._try_to_convert_date(value, format) - - if converted is not None: - return converted - - formats_str = ", ".join(map(repr, self.formats)) - self.fail( - f"{value!r} does not match the formats {formats_str}.", - param, - ctx, - ) - - def __repr__(self) -> str: - return "DateTime" - - -class _NumberParamTypeBase(ParamType): - _number_class: ClassVar[type[Any]] - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - try: - return self._number_class(value) - except ValueError: - self.fail( - f"{value!r} is not a valid {self.name}.", - param, - ctx, - ) - - -class _NumberRangeBase(_NumberParamTypeBase): - def __init__( - self, - min: float | None = None, - max: float | None = None, - min_open: bool = False, - max_open: bool = False, - clamp: bool = False, - ) -> None: - self.min = min - self.max = max - self.min_open = min_open - self.max_open = max_open - self.clamp = clamp - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - import operator - - rv = super().convert(value, param, ctx) - lt_min: bool = self.min is not None and ( - operator.le if self.min_open else operator.lt - )(rv, self.min) - gt_max: bool = self.max is not None and ( - operator.ge if self.max_open else operator.gt - )(rv, self.max) - - if self.clamp: - if lt_min: - return self._clamp(self.min, 1, self.min_open) # type: ignore[arg-type] - - if gt_max: - return self._clamp(self.max, -1, self.max_open) # type: ignore[arg-type] - - if lt_min or gt_max: - self.fail( - f"{rv} is not in the range {self._describe_range()}.", - param, - ctx, - ) - - return rv - - def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float: - """Find the valid value to clamp to bound in the given - direction. - """ - raise NotImplementedError # pragma: no cover - - def _describe_range(self) -> str: - """Describe the range for use in help text.""" - if self.min is None: - op = "<" if self.max_open else "<=" - return f"x{op}{self.max}" - - if self.max is None: - op = ">" if self.min_open else ">=" - return f"x{op}{self.min}" - - lop = "<" if self.min_open else "<=" - rop = "<" if self.max_open else "<=" - return f"{self.min}{lop}x{rop}{self.max}" - - def __repr__(self) -> str: - clamp = " clamped" if self.clamp else "" - return f"<{type(self).__name__} {self._describe_range()}{clamp}>" - - -class IntParamType(_NumberParamTypeBase): - name = "integer" - _number_class = int - - def __repr__(self) -> str: - return "INT" - - -class IntRange(_NumberRangeBase, IntParamType): - """Restrict an `INT` value to a range of accepted values. See - - If ``min`` or ``max`` are not passed, any value is accepted in that - direction. If ``min_open`` or ``max_open`` are enabled, the - corresponding boundary is not included in the range. - - If ``clamp`` is enabled, a value outside the range is clamped to the - boundary instead of failing. - """ - - name = "integer range" - - def _clamp( # type: ignore - self, bound: int, dir: Literal[1, -1], open: bool - ) -> int: - if not open: - return bound - - return bound + dir - - -class FloatParamType(_NumberParamTypeBase): - name = "float" - _number_class = float - - def __repr__(self) -> str: - return "FLOAT" - - -class FloatRange(_NumberRangeBase, FloatParamType): - """Restrict a `FLOAT` value to a range of accepted - values. See `ranges`. - - If ``min`` or ``max`` are not passed, any value is accepted in that - direction. If ``min_open`` or ``max_open`` are enabled, the - corresponding boundary is not included in the range. - - If ``clamp`` is enabled, a value outside the range is clamped to the - boundary instead of failing. This is not supported if either - boundary is marked ``open``. - """ - - name = "float range" - - def __init__( - self, - min: float | None = None, - max: float | None = None, - min_open: bool = False, - max_open: bool = False, - clamp: bool = False, - ) -> None: - super().__init__( - min=min, max=max, min_open=min_open, max_open=max_open, clamp=clamp - ) - - if (min_open or max_open) and clamp: - raise TypeError("Clamping is not supported for open bounds.") - - def _clamp(self, bound: float, dir: Literal[1, -1], open: bool) -> float: - if not open: - return bound - - # Could use math.nextafter here, but clamping an - # open float range doesn't seem to be particularly useful. It's - # left up to the user to write a callback to do it if needed. - raise RuntimeError( - "Clamping is not supported for open bounds." - ) # pragma: no cover - - -class BoolParamType(ParamType): - name = "boolean" - - bool_states: dict[str, bool] = { - "1": True, - "0": False, - "yes": True, - "no": False, - "true": True, - "false": False, - "on": True, - "off": False, - "t": True, - "f": False, - "y": True, - "n": False, - # Absence of value is considered False. - "": False, - } - """A mapping of string values to boolean states. - - Mapping is inspired by `configparser.ConfigParser.BOOLEAN_STATES` - and extends it. - """ - - @staticmethod - def str_to_bool(value: str | bool) -> bool | None: - """Convert a string to a boolean value. - - If the value is already a boolean, it is returned as-is. If the value is a - string, it is stripped of whitespaces and lower-cased, then checked against - the known boolean states pre-defined in the `BoolParamType.bool_states` mapping - above. - - Returns `None` if the value does not match any known boolean state. - """ - if isinstance(value, bool): - return value - return BoolParamType.bool_states.get(value.strip().lower()) - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> bool: - normalized = self.str_to_bool(value) - if normalized is None: - states = ", ".join(sorted(self.bool_states)) - self.fail( - f"{value!r} is not a valid boolean. Recognized values: {states}", - param, - ctx, - ) - return normalized - - def __repr__(self) -> str: - return "BOOL" - - -class UUIDParameterType(ParamType): - name = "uuid" - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - import uuid - - if isinstance(value, uuid.UUID): - return value - - value = value.strip() - - try: - return uuid.UUID(value) - except ValueError: - self.fail(f"{value!r} is not a valid UUID.", param, ctx) - - def __repr__(self) -> str: - return "UUID" - - -class File(ParamType): - """Declares a parameter to be a file for reading or writing. The file - is automatically closed once the context tears down (after the command - finished working). - - Files can be opened for reading or writing. The special value ``-`` - indicates stdin or stdout depending on the mode. - - By default, the file is opened for reading text data, but it can also be - opened in binary mode or for writing. The encoding parameter can be used - to force a specific encoding. - - The `lazy` flag controls if the file should be opened immediately or upon - first IO. The default is to be non-lazy for standard input and output - streams as well as files opened for reading, `lazy` otherwise. When opening a - file lazily for reading, it is still opened temporarily for validation, but - will not be held open until first IO. lazy is mainly useful when opening - for writing to avoid creating the file until it is needed. - - Files can also be opened atomically in which case all writes go into a - separate file in the same folder and upon completion the file will - be moved over to the original location. This is useful if a file - regularly read by other users is modified. - """ - - name = "filename" - envvar_list_splitter: ClassVar[str] = os.path.pathsep - - def __init__( - self, - mode: str = "r", - encoding: str | None = None, - errors: str | None = "strict", - lazy: bool | None = None, - atomic: bool = False, - ) -> None: - self.mode = mode - self.encoding = encoding - self.errors = errors - self.lazy = lazy - self.atomic = atomic - - def resolve_lazy_flag(self, value: str | os.PathLike[str]) -> bool: - if self.lazy is not None: - return self.lazy - if os.fspath(value) == "-": - return False - elif "w" in self.mode: - return True - return False - - def convert( - self, - value: str | os.PathLike[str] | IO[Any], - param: Union["Parameter", None], - ctx: Union["Context", None], - ) -> IO[Any]: - if _is_file_like(value): - return value - - value = cast("str | os.PathLike[str]", value) - - try: - lazy = self.resolve_lazy_flag(value) - - if lazy: - lf = LazyFile( - value, self.mode, self.encoding, self.errors, atomic=self.atomic - ) - - if ctx is not None: - ctx.call_on_close(lf.close_intelligently) - - return cast("IO[Any]", lf) - - f, should_close = open_stream( - value, self.mode, self.encoding, self.errors, atomic=self.atomic - ) - - # If a context is provided, we automatically close the file - # at the end of the context execution (or flush out). If a - # context does not exist, it's the caller's responsibility to - # properly close the file. This for instance happens when the - # type is used with prompts. - if ctx is not None: - if should_close: - ctx.call_on_close(safecall(f.close)) - else: - ctx.call_on_close(safecall(f.flush)) - - return f - except OSError as e: # pragma: no cover - self.fail(f"'{format_filename(value)}': {e.strerror}", param, ctx) - - def shell_complete( - self, ctx: "Context", param: "Parameter", incomplete: str - ) -> list["CompletionItem"]: - """Return a special completion marker that tells the completion - system to use the shell to provide file path completions. - """ - from .shell_completion import CompletionItem - - return [CompletionItem(incomplete, type="file")] - - -def _is_file_like(value: Any) -> TypeGuard[IO[Any]]: - return hasattr(value, "read") or hasattr(value, "write") - - -class Tuple(CompositeParamType): - """The default behavior of Click is to apply a type on a value directly. - This works well in most cases, except for when `nargs` is set to a fixed - count and different types should be used for different items. In this - case the `Tuple` type can be used. This type can only be used - if `nargs` is set to a fixed number. - - For more information see `tuple-type`. - - This can be selected by using a Python tuple literal as a type. - """ - - def __init__(self, types: Sequence[type[Any] | ParamType]) -> None: - self.types: Sequence[ParamType] = [convert_type(ty) for ty in types] - - @property - def name(self) -> str: # type: ignore[override] - return f"<{' '.join(ty.name for ty in self.types)}>" - - @property - def arity(self) -> int: # type: ignore - return len(self.types) - - def convert( - self, value: Any, param: Union["Parameter", None], ctx: Union["Context", None] - ) -> Any: - len_type = len(self.types) - len_value = len(value) - - if len_value != len_type: - self.fail( - f"{len_type} values are required, but {len_value} given.", - param=param, - ctx=ctx, - ) - - return tuple( - ty(x, param, ctx) for ty, x in zip(self.types, value, strict=False) - ) - - -def convert_type(ty: Any | None, default: Any | None = None) -> ParamType: - """Find the most appropriate `ParamType` for the given Python - type. If the type isn't provided, it can be inferred from a default - value. - """ - guessed_type = False - - if ty is None and default is not None: - if isinstance(default, (tuple, list)): - # If the default is empty, ty will remain None and will - # return STRING. - if default: - item = default[0] - - # A tuple of tuples needs to detect the inner types. - # Can't call convert recursively because that would - # incorrectly unwind the tuple to a single type. - if isinstance(item, (tuple, list)): - ty = tuple(map(type, item)) - else: - ty = type(item) - else: - ty = type(default) - - guessed_type = True - - if isinstance(ty, tuple): - return Tuple(ty) - - if isinstance(ty, ParamType): - return ty - - if ty is str or ty is None: - return STRING - - if ty is int: - return INT - - if ty is float: - return FLOAT - - if ty is bool: - return BOOL - - if guessed_type: - return STRING - - return FuncParamType(ty) - - -# A unicode string parameter type which is the implicit default. This -# can also be selected by using ``str`` as type. -STRING = StringParamType() - -# An integer parameter. This can also be selected by using ``int`` as -# type. -INT = IntParamType() - -# A floating point value parameter. This can also be selected by using -# ``float`` as type. -FLOAT = FloatParamType() - -# A boolean parameter. This is the default for boolean flags. This can -# also be selected by using ``bool`` as a type. -BOOL = BoolParamType() - -# A UUID parameter. -UUID = UUIDParameterType() - - -class OptionHelpExtra(TypedDict, total=False): - envvars: tuple[str, ...] - default: str - range: str - required: str diff --git a/typer/_types.py b/typer/_types.py deleted file mode 100644 index 09b38afb3f..0000000000 --- a/typer/_types.py +++ /dev/null @@ -1,120 +0,0 @@ -from collections.abc import Iterable, Mapping, Sequence -from enum import Enum -from typing import Any, Generic, TypeVar - -from . import _click -from ._click import types -from ._click.shell_completion import CompletionItem - -ParamTypeValue = TypeVar("ParamTypeValue") - - -class TyperChoice(types.ParamType, Generic[ParamTypeValue]): - # Code adapted from Click 8.3.1, with Typer using enum values in normalize_choice - name = "choice" - - def __init__( - self, choices: Iterable[ParamTypeValue], case_sensitive: bool = True - ) -> None: - self.choices: Sequence[ParamTypeValue] = tuple(choices) - self.case_sensitive = case_sensitive - - def _normalized_mapping( - self, ctx: _click.Context | None = None - ) -> Mapping[ParamTypeValue, str]: - """ - Returns mapping where keys are the original choices and the values are - the normalized values that are accepted via the command line. - """ - return { - choice: self.normalize_choice( - choice=choice, - ctx=ctx, - ) - for choice in self.choices - } - - def normalize_choice( - self, choice: ParamTypeValue, ctx: _click.Context | None - ) -> str: - normed_value = str(choice.value) if isinstance(choice, Enum) else str(choice) - - if ctx is not None and ctx.token_normalize_func is not None: - normed_value = ctx.token_normalize_func(normed_value) - - if not self.case_sensitive: - normed_value = normed_value.casefold() - - return normed_value - - def get_metavar(self, param: _click.Parameter, ctx: _click.Context) -> str | None: - if param.param_type_name == "option" and not param.show_choices: # type: ignore - choice_metavars = [ - types.convert_type(type(choice)).name.upper() for choice in self.choices - ] - choices_str = "|".join([*dict.fromkeys(choice_metavars)]) - else: - choices_str = "|".join( - [str(i) for i in self._normalized_mapping(ctx=ctx).values()] - ) - - # Use curly braces to indicate a required argument. - if param.required and param.param_type_name == "argument": - return f"{{{choices_str}}}" - - # Use square braces to indicate an option or optional argument. - return f"[{choices_str}]" - - def get_missing_message( - self, param: _click.Parameter, ctx: _click.Context | None - ) -> str: - """Message shown when no choice is passed.""" - choices = ",\n\t".join(self._normalized_mapping(ctx=ctx).values()) - return f"Choose from:\n\t{choices}" - - def convert( - self, value: Any, param: _click.Parameter | None, ctx: _click.Context | None - ) -> ParamTypeValue: - """ - For a given value from the parser, normalize it and find its - matching normalized value in the list of choices. Then return the - matched "original" choice. - """ - normed_value = self.normalize_choice(choice=value, ctx=ctx) - normalized_mapping = self._normalized_mapping(ctx=ctx) - - try: - return next( - original - for original, normalized in normalized_mapping.items() - if normalized == normed_value - ) - except StopIteration: - self.fail( - self.get_invalid_choice_message(value=value, ctx=ctx), - param=param, - ctx=ctx, - ) - - def get_invalid_choice_message(self, value: Any, ctx: _click.Context | None) -> str: - """Get the error message when the given choice is invalid.""" - choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values())) - return f"{value!r} is not one of {choices_str}." - - def __repr__(self) -> str: - return f"Choice({list(self.choices)})" - - def shell_complete( - self, ctx: _click.Context, param: _click.Parameter, incomplete: str - ) -> list[CompletionItem]: - """Complete choices that start with the incomplete value.""" - - str_choices = map(str, self.choices) - - if self.case_sensitive: - matched = (c for c in str_choices if c.startswith(incomplete)) - else: - incomplete = incomplete.lower() - matched = (c for c in str_choices if c.lower().startswith(incomplete)) - - return [CompletionItem(c) for c in matched] diff --git a/typer/_typing.py b/typer/_typing.py index 218f674c22..b45965af90 100644 --- a/typer/_typing.py +++ b/typer/_typing.py @@ -1,8 +1,7 @@ -# Copied from pydantic 1.9.2 (the latest version to support python 3.6.) -# https://github.com/pydantic/pydantic/blob/v1.9.2/pydantic/typing.py -# Reduced drastically to only include Typer-specific 3.9+ functionality +# Adapted from pydantic 1.9.2 # mypy: ignore-errors +import numbers import types from collections.abc import Callable from typing import ( @@ -26,6 +25,7 @@ def is_union(tp: type[Any] | None) -> bool: "is_callable_type", "is_literal_type", "all_literal_values", + "is_number_type", "is_union", "Annotated", "Literal", @@ -52,6 +52,15 @@ def is_callable_type(type_: type[Any]) -> bool: return type_ is Callable or get_origin(type_) is Callable +def is_number_type(type_: Any) -> bool: + return ( + isinstance(type_, type) + and type_ is not bool + and type_ is not complex + and issubclass(type_, numbers.Number) + ) + + def is_literal_type(type_: type[Any]) -> bool: return get_origin(type_) is Literal diff --git a/typer/adapters.py b/typer/adapters.py new file mode 100644 index 0000000000..14b63c8b48 --- /dev/null +++ b/typer/adapters.py @@ -0,0 +1,261 @@ +import sys +from collections.abc import Sequence +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin + +from pydantic import AfterValidator, BeforeValidator, Field, TypeAdapter, ValidationInfo +from pydantic.errors import PydanticSchemaGenerationError + +from ._click import _compat +from ._typing import is_literal_type, is_number_type, literal_values +from .models import ParameterInfo +from .param_types import coerce_cli_choice, coerce_cli_path, lenient_issubclass + +if TYPE_CHECKING: + from ._click import Context + from .core import TyperParameter + +_CTX_KEY = "ctx" +_PARAM_KEY = "param" + + +def validation_context( + ctx: "Context", + param: "TyperParameter", +) -> dict[str, Any]: + return {_CTX_KEY: ctx, _PARAM_KEY: param} + + +def validation_ctx_param( + info: ValidationInfo, +) -> tuple["Context | None", "TyperParameter | None"]: + context = info.context + if not context: + return None, None + return context.get(_CTX_KEY), context.get(_PARAM_KEY) + + +def try_build_adapter( + annotation: Any, + parameter_info: ParameterInfo, +) -> TypeAdapter[Any] | None: + """Build a TypeAdapter when Pydantic can schema-generate the annotation.""" + try: + return build_adapter(annotation, parameter_info) + except PydanticSchemaGenerationError: + return None + + +def build_adapter( + annotation: Any, + parameter_info: ParameterInfo, +) -> TypeAdapter[Any]: + """Build a Pydantic TypeAdapter for a parameter annotation and metadata. + Check for list/tuple and call this function recursively. + Otherwise, delegate to build_leaf_adapter. + """ + origin = get_origin(annotation) + if origin is list: + args = get_args(annotation) + if len(args) != 1: + raise ValueError(f"Expected one list item type, got: {args!r}") + list_type = args[0] + adapter = build_adapter(list_type, parameter_info) + + def parse_list(value: Any, info: ValidationInfo) -> list[Any]: + if not isinstance(value, (list, tuple)): + value = [value] + context = info.context + return [ + None if item is None else adapter.validate_python(item, context=context) + for item in value + ] + + return TypeAdapter(Annotated[list[Any], BeforeValidator(parse_list)]) + + if origin is tuple: + types = get_args(annotation) + adapters = [build_adapter(t, parameter_info) for t in types] + + def parse_tuple(value: Any, info: ValidationInfo) -> tuple[Any, ...]: + if not isinstance(value, (list, tuple)): + raise ValueError("value is not a valid tuple") + if len(value) != len(adapters): + raise ValueError( + f"{len(adapters)} values are required, but {len(value)} given." + ) + context = info.context + return tuple( + None if item is None else adapter.validate_python(item, context=context) + for adapter, item in zip(adapters, value, strict=False) + ) + + return TypeAdapter(Annotated[tuple[Any, ...], BeforeValidator(parse_tuple)]) + + return build_leaf_adapter(annotation, parameter_info=parameter_info) + + +def build_leaf_adapter( + annotation: Any, + *, + parameter_info: ParameterInfo, +) -> TypeAdapter[Any]: + """Build a Pydantic TypeAdapter for a leaf CLI annotation and constraints.""" + if parameter_info.parser is not None: + parser = parameter_info.parser + + # We need this because Pydantic would otherwise reject a callable class + def parse(value: Any) -> Any: + return parser(value) + + return TypeAdapter(Annotated[Any, BeforeValidator(parse)]) + + if lenient_issubclass(annotation, Enum): + case_sensitive = parameter_info.case_sensitive + return _build_choice_adapter( + list(annotation), + case_sensitive=case_sensitive, + ) + if is_literal_type(annotation): + case_sensitive = parameter_info.case_sensitive + return _build_choice_adapter( + literal_values(annotation), + case_sensitive=case_sensitive, + ) + if annotation is Path: + return build_path_adapter(annotation, parameter_info) + + if annotation is datetime: + return _build_datetime_adapter(parameter_info.formats) + + if is_number_type(annotation): + return _build_number_adapter( + annotation, + min=parameter_info.min, + max=parameter_info.max, + clamp=parameter_info.clamp, + ) + + if annotation is bool: + return TypeAdapter(Annotated[bool, BeforeValidator(_parse_cli_bool)]) + + if annotation is str: + return TypeAdapter(Annotated[str, BeforeValidator(_parse_cli_str)]) + + return TypeAdapter(annotation) + + +# DATE # +def _build_datetime_adapter(formats: Sequence[str] | None) -> TypeAdapter[datetime]: + if formats is None: + return TypeAdapter(datetime) + + def parse_datetime(value: Any) -> datetime: + if isinstance(value, datetime): + return value + for format in formats: + try: + return datetime.strptime(value, format) + except ValueError: + continue + formats_str = ", ".join(map(repr, formats)) + raise ValueError(f"{value!r} does not match the formats {formats_str}.") + + return TypeAdapter(Annotated[datetime, BeforeValidator(parse_datetime)]) + + +# STRING / BYTES # +def _parse_cli_str(value: Any) -> str: + """Coerce a CLI value to str""" + if isinstance(value, bytes): + enc = _compat._get_argv_encoding() + try: + return value.decode(enc) + except UnicodeError: + fs_enc = sys.getfilesystemencoding() + if fs_enc != enc: + try: + return value.decode(fs_enc) + except UnicodeError: + return value.decode("utf-8", "replace") + return value.decode("utf-8", "replace") + return str(value) + + +# BOOL # +def _parse_cli_bool(value: Any) -> Any: + if not isinstance(value, str): + return value + + stripped = value.strip() + if stripped == "": + return False + return stripped + + +# NUMBER # +def _build_number_adapter( + number_class: type[Any], *, min: float | None, max: float | None, clamp: bool | None +) -> TypeAdapter[Any]: + if clamp: + + def clamp_number(value: Any) -> Any: + if min is not None and value < min: + return number_class(min) + if max is not None and value > max: + return number_class(max) + return value + + # Use AfterValidator so it runs after coercion + return TypeAdapter(Annotated[number_class, AfterValidator(clamp_number)]) # ty: ignore[invalid-type-form] + else: + field_kwargs: dict[str, Any] = {} + if min is not None: + field_kwargs["ge"] = min + if max is not None: + field_kwargs["le"] = max + if field_kwargs: + return TypeAdapter(Annotated[number_class, Field(**field_kwargs)]) # ty: ignore[invalid-type-form] + return TypeAdapter(number_class) + + +# CHOICE # +def _build_choice_adapter( + choices: Sequence[Any], + *, + case_sensitive: bool, +) -> TypeAdapter[Any]: + def parse_choice(value: Any, info: ValidationInfo) -> Any: + ctx, _ = validation_ctx_param(info) + return coerce_cli_choice( + value, + choices=choices, + case_sensitive=case_sensitive, + ctx=ctx, + ) + + return TypeAdapter(Annotated[Any, BeforeValidator(parse_choice)]) + + +# PATH # +def build_path_adapter( + annotation: Any, + parameter_info: ParameterInfo, +) -> TypeAdapter[Any]: + path_type = parameter_info.path_type + if path_type is None and lenient_issubclass(annotation, Path): + path_type = annotation + + def parse_path(value: Any, info: ValidationInfo) -> Any: + ctx, param = validation_ctx_param(info) + return coerce_cli_path( + value, + parameter_info, + path_type=path_type, + param=param, + ctx=ctx, + ) + + return TypeAdapter(Annotated[Any, BeforeValidator(parse_path)]) diff --git a/typer/coercion.py b/typer/coercion.py new file mode 100644 index 0000000000..043370c105 --- /dev/null +++ b/typer/coercion.py @@ -0,0 +1,280 @@ +import os +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import IO, TYPE_CHECKING, Any + +from pydantic import TypeAdapter, ValidationError + +from . import adapters +from ._click import Context +from ._click.exceptions import BadParameter, UsageError +from ._typing import get_args, get_origin, is_number_type +from .adapters import validation_context +from .display import get_error_msg +from .models import OptionInfo, ParameterInfo +from .param_types import ( + ParameterAnnotation, + _open_cli_file, + annotation_from_prompt, + choice_coercion_annotation, + file_coercion_annotation, + is_file_annotation, + lenient_issubclass, + path_type_name, + resolve_file_mode, +) + +if TYPE_CHECKING: + from .core import TyperParameter + + +@dataclass(frozen=True) +class TypeDescriptor: + """Resolved CLI type: metadata, coercion adapter, and deduced flags.""" + + annotation: ParameterAnnotation + parameter_info: ParameterInfo + adapter: TypeAdapter[Any] | None + file_annotation: Any | None + + @property + def is_list(self) -> bool: + return lenient_issubclass(get_origin(self.annotation), list) + + @property + def is_tuple(self) -> bool: + return lenient_issubclass(get_origin(self.annotation), tuple) + + @property + def is_datetime(self) -> bool: + return self.annotation is datetime + + @property + def is_ranged(self) -> bool: + if self.is_list or self.is_tuple: + return False + return is_number_type(self.annotation) and ( + self.parameter_info.min is not None or self.parameter_info.max is not None + ) + + @property + def is_path(self) -> bool: + return self.annotation is Path + + @property + def is_choice(self) -> bool: + return self.choices is not None + + @property + def is_file(self) -> bool: + return self.file_annotation is not None + + @property + def datetime_formats(self) -> tuple[str, ...]: + formats = self.parameter_info.formats + if formats is not None: + return tuple(formats) + return ("%Y-%m-%d",) + + @property + def path_type(self) -> str: + return path_type_name(self.parameter_info) + + @property + def choices(self) -> tuple[Any, ...] | None: + if self.is_list: + args = get_args(self.annotation) + if len(args) == 1: + choice = choice_coercion_annotation(args[0], self.parameter_info) + if choice is not None: + return choice[0] + choice = choice_coercion_annotation(self.annotation, self.parameter_info) + if choice is not None: + return choice[0] + return None + + @property + def case_sensitive(self) -> bool: + return self.parameter_info.case_sensitive + + @property + def ranged_type_name(self) -> str: + if isinstance(self.annotation, type): + return self.annotation.__name__ + return "number" + + @property + def tuple_arity(self) -> int | None: + if not self.is_tuple: + return None + return len(get_args(self.annotation)) + + @property + def envvar_list_splitter(self) -> str | None: + if self.is_file: + return os.path.pathsep + if self.is_path: + return os.path.pathsep + if self.is_list: + args = get_args(self.annotation) + if len(args) == 1 and (is_file_annotation(args[0]) or args[0] is Path): + return os.path.pathsep + return None + + +def resolve_type_descriptor( + annotation: ParameterAnnotation, + parameter_info: ParameterInfo, +) -> TypeDescriptor: + """Resolve Pydantic adapter for one parameter annotation.""" + file_annotation = file_coercion_annotation(annotation) + adapter = None + if file_annotation is None: + adapter = adapters.try_build_adapter(annotation, parameter_info) + return TypeDescriptor( + annotation=annotation, + parameter_info=parameter_info, + adapter=adapter, + file_annotation=file_annotation, + ) + + +@dataclass(frozen=True) +class RuntimeParam(ABC): + """Runtime coercion contract for one command parameter.""" + + parameter_info: ParameterInfo + annotation: ParameterAnnotation + + def coerce(self, value: Any, param: "TyperParameter", ctx: Context) -> Any: + is_multi_value = param.multiple or param.nargs == -1 + if value is None: + if is_multi_value: + return () + return None + if is_multi_value and isinstance(value, str): + raise BadParameter("Value must be an iterable.", ctx=ctx, param=param) + return self._coerce_value(value, param=param, ctx=ctx) + + @abstractmethod + def _coerce_value(self, value: Any, param: "TyperParameter", ctx: Context) -> Any: + pass + + +@dataclass(frozen=True) +class AdapterRuntimeParam(RuntimeParam): + """Coercion via a Pydantic TypeAdapter.""" + + adapter: TypeAdapter[Any] + + def _coerce_value(self, value: Any, param: "TyperParameter", ctx: Context) -> Any: + try: + return self.adapter.validate_python( + value, + context=validation_context(ctx, param), + ) + except ValidationError as exc: + raise BadParameter(get_error_msg(exc), ctx=ctx, param=param) from exc + except ValueError as exc: + raise BadParameter(str(exc), ctx=ctx, param=param) from exc + + +@dataclass(frozen=True) +class FileRuntimeParam(RuntimeParam): + """Coercion by opening CLI file paths into IO streams.""" + + file_annotation: Any + + def _coerce_value(self, value: Any, param: "TyperParameter", ctx: Context) -> Any: + mode = resolve_file_mode(self.parameter_info, self.file_annotation) + + def open_one(item: Any) -> IO[Any]: + return _open_cli_file( + item, + self.parameter_info, + mode=mode, + param=param, + ctx=ctx, + ) + + if isinstance(value, (list, tuple)): + return type(value)(open_one(item) for item in value) + return open_one(value) + + +@dataclass(frozen=True) +class PassThroughRuntimeParam(RuntimeParam): + """Coercion for annotations that cannot use a Pydantic TypeAdapter.""" + + def _coerce_value(self, value: Any, param: "TyperParameter", ctx: Context) -> Any: + annotation = self.annotation + if isinstance(annotation, type): + if isinstance(value, annotation): + return value + label = getattr(annotation, "__name__", repr(annotation)) + raise BadParameter( + f"Value {value!r} is not a valid {label}.", + ctx=ctx, + param=param, + ) + + +def build_runtime_param(descriptor: TypeDescriptor) -> RuntimeParam: + """Build runtime coercion from a resolved type descriptor.""" + args = { + "annotation": descriptor.annotation, + "parameter_info": descriptor.parameter_info, + } + if descriptor.file_annotation is not None: + return FileRuntimeParam(**args, file_annotation=descriptor.file_annotation) + if descriptor.adapter is not None: + return AdapterRuntimeParam(**args, adapter=descriptor.adapter) + return PassThroughRuntimeParam(**args) + + +def bool_flag_type_descriptor() -> TypeDescriptor: + """Resolved type for a standalone boolean flag option.""" + return resolve_type_descriptor( + annotation=bool, + parameter_info=OptionInfo(), + ) + + +def bool_flag_runtime_param() -> RuntimeParam: + """Build runtime coercion for a standalone boolean flag option.""" + return build_runtime_param(bool_flag_type_descriptor()) + + +def prompt_value_proc( + param_type: Any | None = None, + default: Any | None = None, +) -> Callable[[Any], Any]: + """Coerce interactive prompt input via the runtime adapter layer.""" + annotation = annotation_from_prompt(param_type, default) + + parameter_info = OptionInfo() + adapter = adapters.try_build_adapter(annotation, parameter_info) + + if adapter is not None: + + def coerce(value: Any) -> Any: + try: + return adapter.validate_python(value) + except ValidationError as exc: + raise UsageError(get_error_msg(exc)) from exc + except ValueError as exc: + raise UsageError(str(exc)) from exc + + return coerce + + def coerce_pass_through(value: Any) -> Any: + if isinstance(annotation, type): + if isinstance(value, annotation): + return value + label = getattr(annotation, "__name__", repr(annotation)) + raise UsageError(f"Value {value!r} is not a valid {label}.") + + return coerce_pass_through diff --git a/typer/completion.py b/typer/completion.py index f63692ddf3..c6b92ad00d 100644 --- a/typer/completion.py +++ b/typer/completion.py @@ -109,8 +109,6 @@ def shell_complete( complete_var: str, instruction: str, ) -> int: - from . import _click - if "_" not in instruction: _click.echo("Invalid completion instruction.", err=True) return 1 diff --git a/typer/core.py b/typer/core.py index 6868ab4355..8e379f8d01 100644 --- a/typer/core.py +++ b/typer/core.py @@ -11,13 +11,17 @@ TextIO, Union, cast, + get_args, + get_origin, ) -from . import _click -from ._click import types +from . import _click, param_types from ._click.parser import _OptionParser from ._click.shell_completion import CompletionItem from ._typing import Literal +from .coercion import RuntimeParam, TypeDescriptor +from .display import describe_number_range +from .param_types import choice_as_str, normalize_choice_value from .utils import parse_boolean_env_var MarkupMode = Literal["markdown", "rich", None] @@ -81,6 +85,151 @@ def compat_autocompletion( self._custom_shell_complete = compat_autocompletion +class TyperParameter(_click.core.Parameter): + """Typer parameter with runtime coercion.""" + + runtime_param: RuntimeParam + type_descriptor: TypeDescriptor + show_choices: bool + + def process_value(self, ctx: _click.Context, value: Any) -> Any: + value = self.runtime_param.coerce(value, param=self, ctx=ctx) + if self.required and self.value_is_missing(value): + raise _click.exceptions.MissingParameter(ctx=ctx, param=self) + if self.callback is not None: + value = self.callback(ctx, self, value) + return value + + def value_is_missing(self, value: Any) -> bool: + if value is None: + return True + if (self.nargs != 1 or self.multiple) and value == (): + return True + return False + + def get_missing_message(self, ctx: _click.Context | None) -> str | None: + desc = self.type_descriptor + if desc.is_choice and desc.choices is not None: + normalized = [ + normalize_choice_value(choice, desc.case_sensitive, ctx) + for choice in desc.choices + ] + choices_str = ",\n\t".join(normalized) + return f"Choose from:\n\t{choices_str}" + return "" + + def value_from_envvar(self, ctx: _click.Context) -> str | Sequence[str] | None: + rv: Any | None = self.resolve_envvar_value(ctx) + if rv is not None and (self.nargs != 1 or self.multiple): + splitter = self.type_descriptor.envvar_list_splitter + if splitter is not None: + rv = (rv or "").split(splitter) + else: + rv = self.split_envvar_value(rv) + return rv + + def shell_complete( + self, ctx: _click.Context, incomplete: str + ) -> list[CompletionItem]: + # custom + if self._custom_shell_complete is not None: + results = self._custom_shell_complete(ctx, self, incomplete) + if results and isinstance(results[0], str): + results = [CompletionItem(c) for c in results] + return cast(list[CompletionItem], results) + # choice + desc = self.type_descriptor + if desc.is_choice and desc.choices is not None: + str_choices = map(choice_as_str, desc.choices) + if desc.case_sensitive: + matched = (c for c in str_choices if c.startswith(incomplete)) + else: + incomplete = incomplete.lower() + matched = (c for c in str_choices if c.lower().startswith(incomplete)) + return [CompletionItem(c) for c in matched] + # file + if desc.is_file: + return [CompletionItem(incomplete, type="file")] + # fall-back, specifically also required for path's + return [] + + @property + def display_name_raw(self) -> str: + if self.metavar is not None: + return self.metavar + assert self.name is not None + return self.name + + def get_error_hint(self) -> str: + return f"'{self.display_name_raw}'" + + def display_name_type(self, ctx: _click.Context) -> str | None: + return self.metavar + + def display_type(self, ctx: _click.Context) -> str: + """Formatted type string for help, e.g. ````""" + desc = self.type_descriptor + if desc.is_choice: + if not self.show_choices: + type_names = [self._bare_type(type(c)) for c in desc.choices or ()] + label = "|".join([*dict.fromkeys(type_names)]) + else: + normalized_mapping = { + c: param_types.normalize_choice_value(c, desc.case_sensitive, ctx) + for c in desc.choices or () + } + label = "|".join(normalized_mapping.values()) + if desc.is_list: + label = f"list[{label}]" + elif desc.is_list: + label = self.bare_type() + elif desc.is_tuple: + labels = [self._bare_type(arg) for arg in get_args(desc.annotation)] + label = ",".join(labels) + elif desc.is_datetime: + label = "|".join(desc.datetime_formats) + elif desc.is_ranged: + label = f"{desc.ranged_type_name} range" + elif desc.is_path: + label = desc.path_type + else: + label = self.bare_type() + return f"<{label}>" + + def display_type_rich(self, ctx: _click.Context) -> str | None: + """Type string for the Rich help types column.""" + display = self.display_type(ctx) + if display == "BOOL": + return None + return display + + def bare_type(self) -> str: + annotation = self.runtime_param.annotation + return self._bare_type(annotation) + + def _bare_type(self, annotation: type) -> str: + display_type = str(annotation) + origin = get_origin(annotation) + if annotation is None: + display_type = "str" + elif origin is list: + args = get_args(annotation) + if len(args) == 1: + element_label = self._bare_type(args[0]) + display_type = f"list[{element_label}]" + else: + display_type = "list" + elif origin is tuple: + labels = [self._bare_type(arg) for arg in get_args(annotation)] + display_type = ",".join(labels) + elif isinstance(annotation, type): + display_type = annotation.__name__ + return display_type + + def get_number_range_help_str(self) -> str | None: + return None + + def _get_default_string( obj: Union["TyperArgument", "TyperOption"], *, @@ -242,7 +391,7 @@ def _main( sys.exit(1) -class TyperArgument(_click.core.Parameter): +class TyperArgument(TyperParameter): param_type_name = "argument" def __init__( @@ -250,7 +399,8 @@ def __init__( *, # Parameter param_decls: list[str], - type: Any | None = None, + runtime_param: RuntimeParam, + type_descriptor: TypeDescriptor, required: bool = False, default: Any | None = None, callback: Callable[..., Any] | None = None, @@ -273,6 +423,9 @@ def __init__( show_envvar: bool = True, help: str | None = None, hidden: bool = False, + # Numbers + min: int | float | None = None, + max: int | float | None = None, # Rich settings rich_help_panel: str | None = None, ): @@ -281,11 +434,14 @@ def __init__( self.show_choices = show_choices self.show_envvar = show_envvar self.hidden = hidden + self.min = min + self.max = max self.rich_help_panel = rich_help_panel + self.runtime_param = runtime_param + self.type_descriptor = type_descriptor super().__init__( param_decls=param_decls, - type=type, required=required, default=default, callback=callback, @@ -298,13 +454,6 @@ def __init__( ) _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion) - @property - def human_readable_name(self) -> str: - if self.metavar is not None: - return self.metavar - assert self.name is not None, "self.name or self.metavar should be set" - return self.name.upper() - def _get_default_string( self, *, @@ -319,17 +468,39 @@ def _get_default_string( default_value=default_value, ) + def display_name(self) -> str: + """Argument display name for help listings (no type suffix).""" + if not self.required: + return f"[{self.display_name_raw}]" + return self.display_name_raw + + def rich_display_name(self) -> str: + """Argument display name for the Rich help name column.""" + name = self.display_name() + if self.metavar is None and self.nargs != 1: + name += "..." + return name + + def usage_display_name(self) -> str: + """Argument name for the usage line only.""" + name = self.display_name_raw + if self.required: + name = f"{{{name}}}" + else: + name = f"[{name}]" + if self.nargs != 1: + name += "..." + return name + def _extract_default_help_str( self, *, ctx: _click.Context ) -> Any | Callable[[], Any] | None: return _extract_default_help_str(self, ctx=ctx) def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: - # Modified version of _click.core.Option.get_help_record() - # to support Arguments if self.hidden: return None - name = self.make_metavar(ctx=ctx) + name = self.rich_display_name() help = self.help or "" extra = [] if self.show_envvar: @@ -363,6 +534,9 @@ def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: # Typer override end if default_string: extra.append(_("default: {default}").format(default=default_string)) + range_str = self.get_number_range_help_str() + if range_str: + extra.append(range_str) if self.required: extra.append(_("required")) if extra: @@ -380,27 +554,15 @@ def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: help = f"{help} {extra_str}" if help else f"{extra_str}" return name, help - def make_metavar(self, ctx: _click.Context) -> str: - # Modified version of _click.core.Argument.make_metavar() - # to include Argument name - if self.metavar is not None: - var = self.metavar - if not self.required and not var.startswith("["): - var = f"[{var}]" - return var - var = (self.name or "").upper() - if not self.required: - var = f"[{var}]" - type_var = self.type.get_metavar(self, ctx=ctx) + def display_name_type(self, ctx: _click.Context) -> str: + var = self.display_name() + type_var = self.display_type(ctx) if type_var: var += f":{type_var}" - if self.nargs != 1: + if self.metavar is None and self.nargs != 1: var += "..." return var - def value_is_missing(self, value: Any) -> bool: - return _value_is_missing(self, value) - def _parse_decls( self, decls: Sequence[str], expose_value: bool ) -> tuple[str | None, list[str], list[str]]: @@ -410,7 +572,7 @@ def _parse_decls( raise TypeError("Argument is marked as exposed, but does not have a name.") if len(decls) == 1: name = arg = decls[0] - name = name.replace("-", "_").lower() + name = name.replace("-", "_") else: raise TypeError( "Arguments take exactly one parameter declaration, got" @@ -419,16 +581,22 @@ def _parse_decls( return name, [arg], [] def get_usage_pieces(self, ctx: _click.Context) -> list[str]: - return [self.make_metavar(ctx)] - - def get_error_hint(self, ctx: _click.Context) -> str: - return f"'{self.make_metavar(ctx)}'" + return [self.usage_display_name()] def add_to_parser(self, parser: _OptionParser, ctx: _click.Context) -> None: parser.add_argument(dest=self.name, nargs=self.nargs, obj=self) + def display_type_rich(self, ctx: _click.Context) -> str | None: + suffix = self.display_type(ctx) + if suffix is not None: + return suffix + return self.display_type(ctx) + + def get_number_range_help_str(self) -> str | None: + return describe_number_range(self.min, self.max) + -class TyperOption(_click.Parameter): +class TyperOption(TyperParameter): param_type_name = "option" _depr_flag_value: bool | None @@ -438,7 +606,8 @@ def __init__( *, # Parameter param_decls: list[str], - type: types.ParamType | Any | None = None, + runtime_param: RuntimeParam, + type_descriptor: TypeDescriptor, required: bool = False, default: Any | None = None, callback: Callable[..., Any] | None = None, @@ -469,15 +638,22 @@ def __init__( hidden: bool = False, show_choices: bool = True, show_envvar: bool = False, + # Numbers + min: int | float | None = None, + max: int | float | None = None, # Rich settings rich_help_panel: str | None = None, ): if help: help = inspect.cleandoc(help) + self.min = min + self.max = max + self.runtime_param = runtime_param + self.type_descriptor = type_descriptor + super().__init__( param_decls, - type=type, multiple=multiple, required=required, default=default, @@ -507,23 +683,18 @@ def __init__( self.hidden = hidden # TODO: revisit all of this flag stuff - if is_flag and type is None: - self.type: types.ParamType = types.BoolParamType() - self.is_flag: bool = bool(is_flag) - self.is_bool_flag: bool = bool( - is_flag and isinstance(self.type, types.BoolParamType) - ) + self.is_bool_flag: bool = bool(is_flag and not count) if self.is_flag: self._depr_flag_value = True else: self._depr_flag_value = None - # Counting. TODO: test or remove? Not currently in coverage. + # Counting self.count = count - if count and type is None: - self.type = types.IntRange(min=0) + if count and self.min is None: + self.min = 0 self.allow_from_autoenv = allow_from_autoenv self.help = help @@ -534,8 +705,9 @@ def __init__( _typer_param_setup_autocompletion_compat(self, autocompletion=autocompletion) self.rich_help_panel = rich_help_panel - def get_error_hint(self, ctx: _click.Context) -> str: - result = super().get_error_hint(ctx) + def get_error_hint(self) -> str: + hint_list = self.opts or [self.display_name_raw] + result = " / ".join(f"'{x}'" for x in hint_list) if self.show_envvar and self.envvar is not None: result += f" (env var: '{self.envvar}')" return result @@ -575,7 +747,7 @@ def _parse_decls( if name is None and possible_names: possible_names.sort(key=lambda x: -len(x[0])) # group long options first - name = possible_names[0][1].replace("-", "_").lower() + name = possible_names[0][1].replace("-", "_") if not name.isidentifier(): name = None @@ -658,7 +830,7 @@ def prompt_for_value(self, ctx: _click.Context) -> Any: # Use ``None`` to inform the prompt() function to reiterate until a valid # value is provided by the user if we have no default. default=default, - type=self.type, + type=self.type_descriptor.annotation, hide_input=self.hide_input, show_choices=self.show_choices, confirmation_prompt=self.confirmation_prompt, @@ -666,19 +838,6 @@ def prompt_for_value(self, ctx: _click.Context) -> Any: **prompt_kwargs, ) - def value_from_envvar(self, ctx: _click.Context) -> Any: - # TODO: clean up - rv = self.resolve_envvar_value(ctx) - - # Absent environment variable or an empty string is interpreted as unset. - if rv is None: - return None - - if self.nargs != 1 or self.multiple: - return self.type.split_envvar_value(rv) - - return rv - def resolve_envvar_value(self, ctx: _click.Context) -> str | None: rv = super().resolve_envvar_value(ctx) @@ -742,13 +901,19 @@ def _extract_default_help_str( ) -> Any | Callable[[], Any] | None: return _extract_default_help_str(self, ctx=ctx) - def make_metavar(self, ctx: _click.Context) -> str: - return super().make_metavar(ctx=ctx) + def value_label(self, ctx: _click.Context) -> str | None: + if self.metavar is not None: + return self.metavar + + value_display = self.display_type(ctx) + if self.nargs != 1 and value_display is not None: + return str(value_display) + "..." + return value_display + + def display_name_type(self, ctx: _click.Context) -> str | None: + return self.value_label(ctx) def get_help_record(self, ctx: _click.Context) -> tuple[str, str] | None: - # Duplicate all of Click's logic only to modify a single line, to allow boolean - # flags with only names for False values as it's currently supported by Typer - # Ref: https://typer.tiangolo.com/tutorial/parameter-types/bool/#only-names-for-false if self.hidden: return None @@ -763,7 +928,7 @@ def _write_opts(opts: Sequence[str]) -> str: any_prefix_is_slash = True if not self.is_flag and not self.count: - rv += f" {self.make_metavar(ctx=ctx)}" + rv += f" {self.display_name_type(ctx=ctx)}" return rv @@ -794,32 +959,24 @@ def _write_opts(opts: Sequence[str]) -> str: ) extra.append(_("env var: {var}").format(var=var_str)) - # Typer override: - # Extracted to _extract_default() to allow re-using it in rich_utils default_value = self._extract_default_help_str(ctx=ctx) - # Typer override end show_default_is_str = isinstance(self.show_default, str) if show_default_is_str or ( default_value is not None and (self.show_default or ctx.show_default) ): - # Typer override: - # Extracted to _get_default_string() to allow re-using it in rich_utils default_string = self._get_default_string( ctx=ctx, show_default_is_str=show_default_is_str, default_value=default_value, ) - # Typer override end if default_string: extra.append(_("default: {default}").format(default=default_string)) - if isinstance(self.type, types._NumberRangeBase): - range_str = self.type._describe_range() - - if range_str: - extra.append(range_str) + range_str = self.get_number_range_help_str() + if range_str: + extra.append(range_str) if self.required: extra.append(_("required")) @@ -840,18 +997,13 @@ def _write_opts(opts: Sequence[str]) -> str: return ("; " if any_prefix_is_slash else " / ").join(rv), help - def value_is_missing(self, value: Any) -> bool: - return _value_is_missing(self, value) - - -def _value_is_missing(param: _click.Parameter, value: Any) -> bool: - if value is None: - return True + def display_type_rich(self, ctx: _click.Context) -> str | None: + return self.value_label(ctx) - if (param.nargs != 1 or param.multiple) and value == (): - return True # pragma: no cover - - return False + def get_number_range_help_str(self) -> str | None: + if self.count and self.min == 0 and self.max is None: + return None + return describe_number_range(self.min, self.max) def _typer_format_options( diff --git a/typer/display.py b/typer/display.py new file mode 100644 index 0000000000..b536a330c1 --- /dev/null +++ b/typer/display.py @@ -0,0 +1,22 @@ +from pydantic import ValidationError + + +def describe_number_range( + min: int | float | None, + max: int | float | None, +) -> str | None: + if min is None and max is None: + return None + if min is None: + return f"x<={max}" + if max is None: + return f"x>={min}" + return f"{min}<=x<={max}" + + +def get_error_msg(exc: ValidationError) -> str: + """Get a string representation of the (first) validation error.""" + errors = exc.errors() + if errors: + return errors[0]["msg"] + return str(exc) diff --git a/typer/main.py b/typer/main.py index a825c1b14e..52e8e3ff43 100644 --- a/typer/main.py +++ b/typer/main.py @@ -6,22 +6,17 @@ import sys import traceback from collections.abc import Callable, Sequence -from datetime import datetime -from enum import Enum from functools import update_wrapper -from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType from typing import Annotated, Any -from uuid import UUID from annotated_doc import Doc -from typer._types import TyperChoice from . import _click -from ._click import types from ._click.globals import get_current_context -from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values +from ._typing import get_args, get_origin +from .coercion import build_runtime_param, resolve_type_descriptor from .completion import get_completion_inspect_parameters from .core import ( DEFAULT_MARKUP_MODE, @@ -33,25 +28,19 @@ TyperOption, ) from .models import ( - AnyType, ArgumentInfo, CommandFunctionType, CommandInfo, Default, DefaultPlaceholder, DeveloperExceptionConfig, - FileBinaryRead, - FileBinaryWrite, - FileText, - FileTextWrite, - NoneType, OptionInfo, ParameterInfo, ParamMeta, Required, TyperInfo, - TyperPath, ) +from .param_types import lenient_issubclass, parse_param_annotation from .utils import get_params_from_function _original_except_hook = sys.excepthook @@ -111,8 +100,8 @@ def except_hook( def get_install_completion_arguments() -> tuple[_click.Parameter, _click.Parameter]: install_param, show_param = get_completion_inspect_parameters() - click_install_param, _ = get_click_param(install_param) - click_show_param, _ = get_click_param(show_param) + click_install_param = get_param(install_param) + click_show_param = get_param(show_param) return click_install_param, click_show_param @@ -1319,11 +1308,9 @@ def get_group_from_info( for sub_command_name, sub_command in sub_group.commands.items(): commands[sub_command_name] = sub_command solved_info = solve_typer_info_defaults(group_info) - ( - params, - convertors, - context_param_name, - ) = get_params_convertors_ctx_param_name_from_function(solved_info.callback) + params, context_param_name = get_params_ctx_param_name_from_function( + solved_info.callback + ) cls = solved_info.cls or TyperGroup assert issubclass(cls, TyperGroup), f"{cls} should be a subclass of {TyperGroup}" group = cls( @@ -1337,7 +1324,6 @@ def get_group_from_info( callback=get_callback( callback=solved_info.callback, params=params, - convertors=convertors, context_param_name=context_param_name, pretty_exceptions_short=pretty_exceptions_short, ), @@ -1361,11 +1347,26 @@ def get_command_name(name: str) -> str: return name.lower().replace("_", "-") -def get_params_convertors_ctx_param_name_from_function( +def get_option_flag_name(name: str) -> str: + return name.replace("_", "-") + + +def get_default_option_flag_name(name: str, metavar: str | None) -> str: + flag_name = get_option_flag_name(name) + if metavar is None: + return flag_name + if ( + get_option_flag_name(metavar).replace("-", "_").casefold() + == flag_name.replace("-", "_").casefold() + ): + return get_option_flag_name(metavar) + return flag_name + + +def get_params_ctx_param_name_from_function( callback: Callable[..., Any] | None, -) -> tuple[list[TyperArgument | TyperOption], dict[str, Any], str | None]: +) -> tuple[list[TyperArgument | TyperOption], str | None]: params = [] - convertors = {} context_param_name = None if callback: parameters = get_params_from_function(callback) @@ -1373,11 +1374,8 @@ def get_params_convertors_ctx_param_name_from_function( if lenient_issubclass(param.annotation, _click.Context): context_param_name = param_name continue - click_param, convertor = get_click_param(param) - if convertor: - convertors[param_name] = convertor - params.append(click_param) - return params, convertors, context_param_name + params.append(get_param(param)) + return params, context_param_name def get_command_from_info( @@ -1393,11 +1391,9 @@ def get_command_from_info( use_help = inspect.getdoc(command_info.callback) else: use_help = inspect.cleandoc(use_help) - ( - params, - convertors, - context_param_name, - ) = get_params_convertors_ctx_param_name_from_function(command_info.callback) + params, context_param_name = get_params_ctx_param_name_from_function( + command_info.callback + ) cls = command_info.cls or TyperCommand command = cls( name=name, @@ -1405,7 +1401,6 @@ def get_command_from_info( callback=get_callback( callback=command_info.callback, params=params, - convertors=convertors, context_param_name=context_param_name, pretty_exceptions_short=pretty_exceptions_short, ), @@ -1425,89 +1420,42 @@ def get_command_from_info( return command -def determine_type_convertor(type_: Any) -> Callable[[Any], Any] | None: - convertor: Callable[[Any], Any] | None = None - if lenient_issubclass(type_, Path): - convertor = param_path_convertor - if lenient_issubclass(type_, Enum): - convertor = generate_enum_convertor(type_) - return convertor - - -def param_path_convertor(value: str | None = None) -> Path | None: - if value is not None: - # allow returning any subclass of Path created by an annotated parser without converting - # it back to a Path - return value if isinstance(value, Path) else Path(value) - return None - - -def generate_enum_convertor(enum: type[Enum]) -> Callable[[Any], Any]: - val_map = {str(val.value): val for val in enum} - - def convertor(value: Any) -> Any: - if value is not None: - val = str(value) - if val in val_map: - key = val_map[val] - return enum(key) - - return convertor - - -def generate_list_convertor( - convertor: Callable[[Any], Any] | None, default_value: Any | None -) -> Callable[[Sequence[Any] | None], list[Any] | None]: - def internal_convertor(value: Sequence[Any] | None) -> list[Any] | None: - if (value is None) or (default_value is None and len(value) == 0): - return None - return [convertor(v) if convertor else v for v in value] - - return internal_convertor - - -def generate_tuple_convertor( - types: Sequence[Any], -) -> Callable[[tuple[Any, ...] | None], tuple[Any, ...] | None]: - convertors = [determine_type_convertor(type_) for type_ in types] - - def internal_convertor( - param_args: tuple[Any, ...] | None, - ) -> tuple[Any, ...] | None: - if param_args is None: - return None - return tuple( - convertor(arg) if convertor else arg - for (convertor, arg) in zip(convertors, param_args, strict=False) - ) - - return internal_convertor +def _normalize_collection_value(param: _click.Parameter, value: Any) -> Any: + if value is None: + return None + is_multi = getattr(param, "multiple", False) or getattr(param, "nargs", 1) == -1 + if not is_multi: + return value + if param.default is None and len(value) == 0: + return None + return list(value) def get_callback( *, callback: Callable[..., Any] | None = None, params: Sequence[_click.Parameter] = [], - convertors: dict[str, Callable[[str], Any]] | None = None, context_param_name: str | None = None, pretty_exceptions_short: bool, ) -> Callable[..., Any] | None: - use_convertors = convertors or {} if not callback: return None parameters = get_params_from_function(callback) use_params: dict[str, Any] = {} for param_name in parameters: use_params[param_name] = None + params_by_name: dict[str, _click.Parameter] = {} for param in params: if param.name: use_params[param.name] = param.default + params_by_name[param.name] = param def wrapper(**kwargs: Any) -> Any: _rich_traceback_guard = pretty_exceptions_short # noqa: F841 for k, v in kwargs.items(): - if k in use_convertors: - use_params[k] = use_convertors[k](v) + matched_param = params_by_name.get(k) + if matched_param is not None: + use_params[k] = _normalize_collection_value(matched_param, v) else: use_params[k] = v if context_param_name: @@ -1518,115 +1466,7 @@ def wrapper(**kwargs: Any) -> Any: return wrapper -def get_click_type( - *, annotation: Any, parameter_info: ParameterInfo -) -> types.ParamType: - if parameter_info.click_type is not None: - return parameter_info.click_type - - elif parameter_info.parser is not None: - return types.FuncParamType(parameter_info.parser) - - elif annotation is str: - return types.STRING - elif annotation is int: - if parameter_info.min is not None or parameter_info.max is not None: - min_ = None - max_ = None - if parameter_info.min is not None: - min_ = int(parameter_info.min) - if parameter_info.max is not None: - max_ = int(parameter_info.max) - return types.IntRange(min=min_, max=max_, clamp=parameter_info.clamp) - else: - return types.INT - elif annotation is float: - if parameter_info.min is not None or parameter_info.max is not None: - return types.FloatRange( - min=parameter_info.min, - max=parameter_info.max, - clamp=parameter_info.clamp, - ) - else: - return types.FLOAT - elif annotation is bool: - return types.BOOL - elif annotation == UUID: - return types.UUID - elif annotation == datetime: - return types.DateTime(formats=parameter_info.formats) - elif ( - annotation == Path - or parameter_info.allow_dash - or parameter_info.path_type - or parameter_info.resolve_path - ): - return TyperPath( - exists=parameter_info.exists, - file_okay=parameter_info.file_okay, - dir_okay=parameter_info.dir_okay, - writable=parameter_info.writable, - readable=parameter_info.readable, - resolve_path=parameter_info.resolve_path, - allow_dash=parameter_info.allow_dash, - path_type=parameter_info.path_type, - ) - elif lenient_issubclass(annotation, FileTextWrite): - return types.File( - mode=parameter_info.mode or "w", - encoding=parameter_info.encoding, - errors=parameter_info.errors, - lazy=parameter_info.lazy, - atomic=parameter_info.atomic, - ) - elif lenient_issubclass(annotation, FileText): - return types.File( - mode=parameter_info.mode or "r", - encoding=parameter_info.encoding, - errors=parameter_info.errors, - lazy=parameter_info.lazy, - atomic=parameter_info.atomic, - ) - elif lenient_issubclass(annotation, FileBinaryRead): - return types.File( - mode=parameter_info.mode or "rb", - encoding=parameter_info.encoding, - errors=parameter_info.errors, - lazy=parameter_info.lazy, - atomic=parameter_info.atomic, - ) - elif lenient_issubclass(annotation, FileBinaryWrite): - return types.File( - mode=parameter_info.mode or "wb", - encoding=parameter_info.encoding, - errors=parameter_info.errors, - lazy=parameter_info.lazy, - atomic=parameter_info.atomic, - ) - elif lenient_issubclass(annotation, Enum): - return TyperChoice( - [item.value for item in annotation], - case_sensitive=parameter_info.case_sensitive, - ) - elif is_literal_type(annotation): - return TyperChoice( - literal_values(annotation), - case_sensitive=parameter_info.case_sensitive, - ) - raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover - - -def lenient_issubclass(cls: Any, class_or_tuple: AnyType | tuple[AnyType, ...]) -> bool: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) - - -def get_click_param( - param: ParamMeta, -) -> tuple[TyperArgument | TyperOption, Any]: - # First, find out what will be: - # * ParamInfo (ArgumentInfo or OptionInfo) - # * default_value - # * required +def get_param(param: ParamMeta) -> TyperArgument | TyperOption: default_value = None required = False if isinstance(param.default, ParameterInfo): @@ -1641,65 +1481,32 @@ def get_click_param( else: default_value = param.default parameter_info = OptionInfo() - annotation: Any - if param.annotation is not param.empty: - annotation = param.annotation - else: - annotation = str - main_type = annotation - is_list = False - is_tuple = False - parameter_type: Any = None + + annotation = parse_param_annotation(param, default_value) + annotation_args = get_args(annotation) + is_list = lenient_issubclass(get_origin(annotation), list) is_flag = None - origin = get_origin(main_type) - - if origin is not None: - # Handle SomeType | None and Optional[SomeType] - if is_union(origin): - types = [] - for type_ in get_args(main_type): - if type_ is NoneType: - continue - types.append(type_) - assert len(types) == 1, "Typer Currently doesn't support Union types" - main_type = types[0] - origin = get_origin(main_type) - # Handle Tuples and Lists - if lenient_issubclass(origin, list): - main_type = get_args(main_type)[0] - assert not get_origin(main_type), ( - "List types with complex sub-types are not currently supported" - ) - is_list = True - elif lenient_issubclass(origin, tuple): - types = [] - for type_ in get_args(main_type): - assert not get_origin(type_), ( - "Tuple types with complex sub-types are not currently supported" - ) - types.append( - get_click_type(annotation=type_, parameter_info=parameter_info) - ) - parameter_type = tuple(types) - is_tuple = True - if parameter_type is None: - parameter_type = get_click_type( - annotation=main_type, parameter_info=parameter_info - ) - convertor = determine_type_convertor(main_type) - if is_list: - convertor = generate_list_convertor( - convertor=convertor, default_value=default_value - ) - if is_tuple: - convertor = generate_tuple_convertor(get_args(main_type)) if isinstance(parameter_info, OptionInfo): - if main_type is bool: + if annotation is bool: is_flag = True - # Click doesn't accept a flag of type bool, only None, and then it sets it - # to bool internally - parameter_type = None - default_option_name = get_command_name(param.name) + elif ( + is_list + and annotation_args == (bool,) + and parameter_info.param_decls + and any("/" in decl for decl in parameter_info.param_decls) + ): + is_flag = True + descriptor = resolve_type_descriptor( + annotation=annotation, + parameter_info=parameter_info, + ) + runtime_param = build_runtime_param(descriptor) + tuple_nargs = descriptor.tuple_arity + + if isinstance(parameter_info, OptionInfo): + default_option_name = get_default_option_flag_name( + param.name, parameter_info.metavar + ) if is_flag: default_option_declaration = ( f"--{default_option_name}/--no-{default_option_name}" @@ -1711,107 +1518,105 @@ def get_click_param( param_decls.extend(parameter_info.param_decls) else: param_decls.append(default_option_declaration) - return ( - TyperOption( - # Option - param_decls=param_decls, - show_default=parameter_info.show_default, - prompt=parameter_info.prompt, - confirmation_prompt=parameter_info.confirmation_prompt, - prompt_required=parameter_info.prompt_required, - hide_input=parameter_info.hide_input, - is_flag=is_flag, - multiple=is_list, - count=parameter_info.count, - allow_from_autoenv=parameter_info.allow_from_autoenv, - type=parameter_type, - help=parameter_info.help, - hidden=parameter_info.hidden, - show_choices=parameter_info.show_choices, - show_envvar=parameter_info.show_envvar, - # Parameter - required=required, - default=default_value, - callback=get_param_callback( - callback=parameter_info.callback, convertor=convertor - ), - metavar=parameter_info.metavar, - expose_value=parameter_info.expose_value, - is_eager=parameter_info.is_eager, - envvar=parameter_info.envvar, - shell_complete=parameter_info.shell_complete, - autocompletion=get_param_completion(parameter_info.autocompletion), - # Rich settings - rich_help_panel=parameter_info.rich_help_panel, - ), - convertor, + return TyperOption( + # Option + param_decls=param_decls, + show_default=parameter_info.show_default, + prompt=parameter_info.prompt, + confirmation_prompt=parameter_info.confirmation_prompt, + prompt_required=parameter_info.prompt_required, + hide_input=parameter_info.hide_input, + is_flag=is_flag, + multiple=is_list, + count=parameter_info.count, + allow_from_autoenv=parameter_info.allow_from_autoenv, + help=parameter_info.help, + hidden=parameter_info.hidden, + show_choices=parameter_info.show_choices, + show_envvar=parameter_info.show_envvar, + # Parameter + required=required, + default=default_value, + callback=get_param_callback(callback=parameter_info.callback), + metavar=parameter_info.metavar, + expose_value=parameter_info.expose_value, + is_eager=parameter_info.is_eager, + envvar=parameter_info.envvar, + shell_complete=parameter_info.shell_complete, + autocompletion=get_param_completion(parameter_info.autocompletion), + min=parameter_info.min, + max=parameter_info.max, + nargs=tuple_nargs, + # Rich settings + rich_help_panel=parameter_info.rich_help_panel, + runtime_param=runtime_param, + type_descriptor=descriptor, ) elif isinstance(parameter_info, ArgumentInfo): param_decls = [param.name] nargs = None if is_list: nargs = -1 - return ( - TyperArgument( - # Argument - param_decls=param_decls, - type=parameter_type, - required=required, - nargs=nargs, - # TyperArgument - show_default=parameter_info.show_default, - show_choices=parameter_info.show_choices, - show_envvar=parameter_info.show_envvar, - help=parameter_info.help, - hidden=parameter_info.hidden, - # Parameter - default=default_value, - callback=get_param_callback( - callback=parameter_info.callback, convertor=convertor - ), - metavar=parameter_info.metavar, - expose_value=parameter_info.expose_value, - is_eager=parameter_info.is_eager, - envvar=parameter_info.envvar, - shell_complete=parameter_info.shell_complete, - autocompletion=get_param_completion(parameter_info.autocompletion), - # Rich settings - rich_help_panel=parameter_info.rich_help_panel, - ), - convertor, + elif tuple_nargs is not None: + nargs = tuple_nargs + return TyperArgument( + # Argument + param_decls=param_decls, + required=required, + nargs=nargs, + # TyperArgument + show_default=parameter_info.show_default, + show_choices=parameter_info.show_choices, + show_envvar=parameter_info.show_envvar, + help=parameter_info.help, + hidden=parameter_info.hidden, + # Parameter + default=default_value, + callback=get_param_callback(callback=parameter_info.callback), + metavar=parameter_info.metavar, + expose_value=parameter_info.expose_value, + is_eager=parameter_info.is_eager, + envvar=parameter_info.envvar, + shell_complete=parameter_info.shell_complete, + autocompletion=get_param_completion(parameter_info.autocompletion), + min=parameter_info.min, + max=parameter_info.max, + # Rich settings + rich_help_panel=parameter_info.rich_help_panel, + runtime_param=runtime_param, + type_descriptor=descriptor, ) - raise AssertionError("A _click.Parameter should be returned") # pragma: no cover + raise AssertionError("A Parameter should be returned") # pragma: no cover def get_param_callback( *, callback: Callable[..., Any] | None = None, - convertor: Callable[..., Any] | None = None, ) -> Callable[..., Any] | None: if not callback: return None parameters = get_params_from_function(callback) ctx_name = None - click_param_name = None + param_arg_name = None value_name = None untyped_names: list[str] = [] for param_name, param_sig in parameters.items(): if lenient_issubclass(param_sig.annotation, _click.Context): ctx_name = param_name elif lenient_issubclass(param_sig.annotation, _click.Parameter): - click_param_name = param_name + param_arg_name = param_name else: untyped_names.append(param_name) # Extract value param name first if untyped_names: value_name = untyped_names.pop() - # If context and Click param were not typed (old/Click callback style) extract them + # If context and parameter were not typed, extract them by position. if untyped_names: if ctx_name is None: ctx_name = untyped_names.pop(0) - if click_param_name is None: + if param_arg_name is None: if untyped_names: - click_param_name = untyped_names.pop(0) + param_arg_name = untyped_names.pop(0) if untyped_names: raise _click.ClickException( "Too many CLI parameter callback function parameters" @@ -1821,14 +1626,10 @@ def wrapper(ctx: _click.Context, param: _click.Parameter, value: Any) -> Any: use_params: dict[str, Any] = {} if ctx_name: use_params[ctx_name] = ctx - if click_param_name: - use_params[click_param_name] = param + if param_arg_name: + use_params[param_arg_name] = param if value_name: - if convertor: - use_value = convertor(value) - else: - use_value = value - use_params[value_name] = use_value + use_params[value_name] = _normalize_collection_value(param, value) return callback(**use_params) update_wrapper(wrapper, callback) diff --git a/typer/models.py b/typer/models.py index 00385c38ce..6d61443c7f 100644 --- a/typer/models.py +++ b/typer/models.py @@ -1,19 +1,14 @@ import inspect import io -import os -import stat from collections.abc import Callable, Sequence from typing import ( TYPE_CHECKING, Any, - ClassVar, Optional, TypeVar, - cast, ) from . import _click -from ._click import types from ._click.shell_completion import CompletionItem if TYPE_CHECKING: # pragma: no cover @@ -299,7 +294,6 @@ def __init__( default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: types.ParamType | None = None, # TyperArgument show_default: bool | str = True, show_choices: bool = True, @@ -332,13 +326,6 @@ def __init__( # Rich settings rich_help_panel: str | None = None, ): - # Check if user has provided multiple custom parsers - if parser and click_type: - raise ValueError( - "Multiple custom type parsers provided. " - "`parser` and `click_type` may not both be provided." - ) - self.default = default self.param_decls = param_decls self.callback = callback @@ -351,7 +338,6 @@ def __init__( self.default_factory = default_factory # Custom type self.parser = parser - self.click_type = click_type # TyperArgument self.show_default = show_default self.show_choices = show_choices @@ -408,7 +394,6 @@ def __init__( default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: types.ParamType | None = None, # Option show_default: bool | str = True, prompt: bool | str = False, @@ -463,7 +448,6 @@ def __init__( default_factory=default_factory, # Custom type parser=parser, - click_type=click_type, # TyperArgument show_default=show_default, show_choices=show_choices, @@ -536,7 +520,6 @@ def __init__( default_factory: Callable[[], Any] | None = None, # Custom type parser: Callable[[str], Any] | None = None, - click_type: types.ParamType | None = None, # TyperArgument show_default: bool | str = True, show_choices: bool = True, @@ -582,7 +565,6 @@ def __init__( default_factory=default_factory, # Custom type parser=parser, - click_type=click_type, # TyperArgument show_default=show_default, show_choices=show_choices, @@ -643,101 +625,3 @@ def __init__( self.pretty_exceptions_enable = pretty_exceptions_enable self.pretty_exceptions_show_locals = pretty_exceptions_show_locals self.pretty_exceptions_short = pretty_exceptions_short - - -class TyperPath(types.ParamType): - # Based originally on code from Click 8.3.1 - # Partly rewritten and added an override for shell_complete - - envvar_list_splitter: ClassVar[str] = os.path.pathsep - - def __init__( - self, - exists: bool = False, - file_okay: bool = True, - dir_okay: bool = True, - writable: bool = False, - readable: bool = True, - resolve_path: bool = False, - allow_dash: bool = False, - path_type: type[Any] | None = None, - ): - self.exists = exists - self.file_okay = file_okay - self.dir_okay = dir_okay - self.readable = readable - self.writable = writable - self.resolve_path = resolve_path - self.allow_dash = allow_dash - self.type = path_type - - if self.file_okay and not self.dir_okay: - self.name = "file" - elif self.dir_okay and not self.file_okay: - self.name = "directory" - else: - self.name = "path" - - def coerce_path_result( - self, value: str | os.PathLike[str] - ) -> str | bytes | os.PathLike[str]: - if self.type is not None and not isinstance(value, self.type): - if ( - self.type is str - ): # pragma: no cover # TODO: perhaps this branch can't be hit and should be removed - return os.fsdecode(value) - elif self.type is bytes: - return os.fsencode(value) - else: - return cast("os.PathLike[str]", self.type(value)) - - return value - - def convert( # ty: ignore[invalid-method-override] - self, - value: str | os.PathLike[str], - param: _click.Parameter | None, - ctx: Context | None, # type: ignore[override] - ) -> str | bytes | os.PathLike[str]: - rv = value - - is_dash = self.file_okay and self.allow_dash and rv in (b"-", "-") - - if not is_dash: - if self.resolve_path: - rv = os.path.realpath(rv) - - try: - st = os.stat(rv) - except OSError: - if not self.exists: - return self.coerce_path_result(rv) - self.fail( - f"{self.name.title()} {_click.utils.format_filename(value)!r} does not exist.", - param, - ctx, - ) - - name = self.name.title() - loc = repr(_click.utils.format_filename(value)) - if not self.file_okay and stat.S_ISREG(st.st_mode): - self.fail(f"{name} {loc} is a file.", param, ctx) - - if not self.dir_okay and stat.S_ISDIR(st.st_mode): - self.fail(f"{name} {loc} is a directory.", param, ctx) - - if self.readable and not os.access(rv, os.R_OK): - self.fail(f"{name} {loc} is not readable.", param, ctx) - - if self.writable and not os.access(rv, os.W_OK): - self.fail(f"{name} {loc} is not writable.", param, ctx) - - return self.coerce_path_result(rv) - - def shell_complete( - self, ctx: _click.Context, param: _click.Parameter, incomplete: str - ) -> list[CompletionItem]: - """Return an empty list so that the autocompletion functionality - will work properly from the commandline. - """ - return [] diff --git a/typer/param_types.py b/typer/param_types.py new file mode 100644 index 0000000000..b7e7f51eb1 --- /dev/null +++ b/typer/param_types.py @@ -0,0 +1,328 @@ +import os +import stat +from collections.abc import Sequence +from enum import Enum +from pathlib import Path +from typing import ( + IO, + TYPE_CHECKING, + Any, + TypeAlias, + cast, +) + +from pydantic import TypeAdapter, ValidationError + +from ._click import Context +from ._click._compat import open_stream +from ._click.exceptions import BadParameter +from ._click.utils import LazyFile, format_filename, safecall +from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values +from .display import get_error_msg +from .models import ( + AnyType, + FileBinaryRead, + FileBinaryWrite, + FileText, + FileTextWrite, + NoneType, + ParameterInfo, + ParamMeta, +) + +if TYPE_CHECKING: + from .core import TyperParameter + +ParameterAnnotation: TypeAlias = Any + + +def lenient_issubclass(cls: Any, class_or_tuple: AnyType | tuple[AnyType, ...]) -> bool: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + + +def infer_annotation_from_default(default: Any | None) -> ParameterAnnotation: + """Infer a normalized annotation from a default value.""" + if default is None: + return str + if isinstance(default, tuple) and len(default) > 0: + if not isinstance(default[0], (tuple, list)): + return tuple.__class_getitem__(tuple(map(type, default))) + if isinstance(default, (tuple, list)): + if not default: + return str + item = default[0] + if isinstance(item, (tuple, list)): + return tuple.__class_getitem__(tuple(map(type, item))) + return type(item) + return type(default) + + +def annotation_from_prompt(t: Any | None, default: Any | None) -> ParameterAnnotation: + if t is not None: + return t + return infer_annotation_from_default(default) + + +def parse_param_annotation( + param: ParamMeta, default: Any | None +) -> ParameterAnnotation: + """Parse the annotation for a callback parameter.""" + if param.annotation is not param.empty: + main_type = param.annotation + origin = get_origin(main_type) + + if origin is not None: + if is_union(origin): + types = [] + for type_ in get_args(main_type): + if type_ is NoneType: + continue + types.append(type_) + assert len(types) == 1, "Typer currently doesn't support Union types" + main_type = types[0] + origin = get_origin(main_type) + + if lenient_issubclass(origin, list): + element_type = get_args(main_type)[0] + assert not get_origin(element_type), ( + "List types with complex sub-types are not currently supported" + ) + return main_type + if lenient_issubclass(origin, tuple): + type_args = get_args(main_type) + for type_ in type_args: + assert not get_origin(type_), ( + "Tuple types with complex sub-types are not currently supported" + ) + return main_type + return main_type + return main_type + return infer_annotation_from_default(default) + + +# ENUM # +def normalize_choice_value( + choice: Any, + case_sensitive: bool, + ctx: Context | None, +) -> str: + normed_value = str(choice.value) if isinstance(choice, Enum) else str(choice) + if ctx is not None and ctx.token_normalize_func is not None: + normed_value = ctx.token_normalize_func(normed_value) + if not case_sensitive: + normed_value = normed_value.casefold() + return normed_value + + +def coerce_cli_choice( + value: Any, + *, + choices: Sequence[Any], + case_sensitive: bool, + ctx: Context | None = None, +) -> Any: + if any(isinstance(choice, Enum) and value is choice for choice in choices): + return value + normalized_mapping = { + c: normalize_choice_value(c, case_sensitive, ctx) for c in choices + } + normed_value = normalize_choice_value(value, case_sensitive, ctx) + for original, normalized in normalized_mapping.items(): + if normalized == normed_value: + return original + choices_str = ", ".join(map(repr, normalized_mapping.values())) + raise ValueError(f"{value!r} is not one of {choices_str}.") + + +def choice_coercion_annotation( + annotation: Any, + parameter_info: ParameterInfo, +) -> tuple[tuple[Any, ...], bool] | None: + if lenient_issubclass(annotation, Enum): + return tuple(annotation), parameter_info.case_sensitive + if is_literal_type(annotation): + return literal_values(annotation), parameter_info.case_sensitive + return None + + +def choice_as_str(choice: Any) -> str: + if isinstance(choice, Enum): + return str(choice.value) + return str(choice) + + +# PATH # +def path_type_name(parameter_info: ParameterInfo) -> str: + if parameter_info.file_okay and not parameter_info.dir_okay: + return "file" + if parameter_info.dir_okay and not parameter_info.file_okay: + return "dir" + return "path" + + +def _coerce_path_result( + value: str | os.PathLike[str], + path_type: type[Any] | None, +) -> str | bytes | os.PathLike[str]: + if path_type is not None and not isinstance(value, path_type): + if path_type is bytes: + return os.fsencode(value) + return cast("os.PathLike[str]", path_type(value)) + return value + + +def coerce_cli_path( + value: str | os.PathLike[str], + parameter_info: ParameterInfo, + *, + path_type: type[Any] | None, + param: "TyperParameter | None" = None, + ctx: Context | None = None, +) -> str | bytes | os.PathLike[str] | Path: + if path_type is None or path_type is str or path_type is bytes: + rv: Any = value + elif isinstance(path_type, type) and issubclass(path_type, Path): + if isinstance(value, path_type): + rv = value + elif isinstance(value, (str, os.PathLike)): + try: + rv = TypeAdapter(path_type).validate_python(value) + except ValidationError as exc: + raise BadParameter(get_error_msg(exc), ctx=ctx, param=param) from exc + else: + rv = value + else: + rv = value + + is_dash = ( + parameter_info.file_okay and parameter_info.allow_dash and rv in (b"-", "-") + ) + + if not is_dash: + if parameter_info.resolve_path: + rv = os.path.realpath(rv) + + label = path_type_name(parameter_info) + try: + st = os.stat(rv) + except OSError: + if not parameter_info.exists: + return _coerce_path_result(rv, path_type) + raise BadParameter( + f"{label} {format_filename(value)!r} does not exist.", + ctx=ctx, + param=param, + ) from None + + loc = repr(format_filename(value)) + if not parameter_info.file_okay and stat.S_ISREG(st.st_mode): + raise BadParameter(f"{label} {loc} is a file.", ctx=ctx, param=param) + + if not parameter_info.dir_okay and stat.S_ISDIR(st.st_mode): + raise BadParameter(f"{label} {loc} is a directory.", ctx=ctx, param=param) + + if parameter_info.readable and not os.access(rv, os.R_OK): + raise BadParameter(f"{label} {loc} is not readable.", ctx=ctx, param=param) + + if parameter_info.writable and not os.access(rv, os.W_OK): + raise BadParameter(f"{label} {loc} is not writable.", ctx=ctx, param=param) + + return _coerce_path_result(rv, path_type) + + +# FILE # +CLI_FILE_TYPES = (FileTextWrite, FileText, FileBinaryRead, FileBinaryWrite) + + +def is_file_annotation(annotation: Any) -> bool: + return lenient_issubclass(annotation, CLI_FILE_TYPES) + + +def file_coercion_annotation(annotation: Any) -> Any | None: + """Return the file marker type when this parameter opens files.""" + origin = get_origin(annotation) + if origin is list: + args = get_args(annotation) + if args and all(is_file_annotation(arg) for arg in args): + return args[0] + return None + if origin is tuple: + args = get_args(annotation) + if args and all(is_file_annotation(arg) for arg in args): + return args[0] + return None + if is_file_annotation(annotation): + return annotation + return None + + +def resolve_file_mode(parameter_info: ParameterInfo, annotation: Any) -> str: + if parameter_info.mode is not None: + return parameter_info.mode + if lenient_issubclass(annotation, FileBinaryWrite): + return "wb" + if lenient_issubclass(annotation, FileTextWrite): + return "w" + if lenient_issubclass(annotation, FileBinaryRead): + return "rb" + return "r" + + +def _open_cli_file( + value: str | os.PathLike[str] | IO[Any], + parameter_info: ParameterInfo, + *, + mode: str, + param: "TyperParameter | None" = None, + ctx: Context | None = None, +) -> IO[Any]: + if hasattr(value, "read") or hasattr(value, "write"): + return cast("IO[Any]", value) + + if isinstance(value, str): + path: str | os.PathLike[str] = value + else: + path = value + + try: + lazy = parameter_info.lazy + if lazy is None: + if os.fspath(path) == "-": + lazy = False + elif "w" in mode: + lazy = True + else: + lazy = False + + if lazy: + lf = LazyFile( + path, + mode, + parameter_info.encoding, + parameter_info.errors, + atomic=parameter_info.atomic, + ) + + if ctx is not None: + ctx.call_on_close(lf.close_intelligently) + + return cast("IO[Any]", lf) + + f, should_close = open_stream( + path, + mode, + parameter_info.encoding, + parameter_info.errors, + atomic=parameter_info.atomic, + ) + + if ctx is not None: + if should_close: + ctx.call_on_close(safecall(f.close)) + else: + ctx.call_on_close(safecall(f.flush)) + + return f + except OSError as exc: # pragma: no cover + message = f"'{format_filename(path)}': {exc.strerror}" + raise BadParameter(message, ctx=ctx, param=param) from exc diff --git a/typer/params.py b/typer/params.py index 833461fa78..5aef6da32d 100644 --- a/typer/params.py +++ b/typer/params.py @@ -1,10 +1,9 @@ from collections.abc import Callable -from typing import TYPE_CHECKING, Annotated, Any, overload +from typing import TYPE_CHECKING, Annotated, Any from annotated_doc import Doc from . import _click -from ._click import types from ._click.shell_completion import CompletionItem from .models import ArgumentInfo, OptionInfo @@ -12,136 +11,6 @@ pass -# Overload for Option created with custom type 'parser' -@overload -def Option( - # Parameter - default: Any | None = ..., - *param_decls: str, - callback: Callable[..., Any] | None = None, - metavar: str | None = None, - expose_value: bool = True, - is_eager: bool = False, - envvar: str | list[str] | None = None, - # Note that shell_complete is not fully supported and will be removed in future versions - # TODO: Remove shell_complete in a future version (after 0.16.0) - shell_complete: Callable[ - [_click.Context, _click.Parameter, str], - list["CompletionItem"] | list[str], - ] - | None = None, - autocompletion: Callable[..., Any] | None = None, - default_factory: Callable[[], Any] | None = None, - # Custom type - parser: Callable[[str], Any] | None = None, - # Option - show_default: bool | str = True, - prompt: bool | str = False, - confirmation_prompt: bool = False, - prompt_required: bool = True, - hide_input: bool = False, - # TODO: remove is_flag and flag_value in a future release - is_flag: bool | None = None, - flag_value: Any | None = None, - count: bool = False, - allow_from_autoenv: bool = True, - help: str | None = None, - hidden: bool = False, - show_choices: bool = True, - show_envvar: bool = True, - # Choice - case_sensitive: bool = True, - # Numbers - min: int | float | None = None, - max: int | float | None = None, - clamp: bool = False, - # DateTime - formats: list[str] | None = None, - # File - mode: str | None = None, - encoding: str | None = None, - errors: str | None = "strict", - lazy: bool | None = None, - atomic: bool = False, - # Path - exists: bool = False, - file_okay: bool = True, - dir_okay: bool = True, - writable: bool = False, - readable: bool = True, - resolve_path: bool = False, - allow_dash: bool = False, - path_type: None | type[str] | type[bytes] = None, - # Rich settings - rich_help_panel: str | None = None, -) -> Any: ... - - -# Overload for Option created with custom type 'click_type' -@overload -def Option( - # Parameter - default: Any | None = ..., - *param_decls: str, - callback: Callable[..., Any] | None = None, - metavar: str | None = None, - expose_value: bool = True, - is_eager: bool = False, - envvar: str | list[str] | None = None, - # Note that shell_complete is not fully supported and will be removed in future versions - # TODO: Remove shell_complete in a future version (after 0.16.0) - shell_complete: Callable[ - [_click.Context, _click.Parameter, str], - list["CompletionItem"] | list[str], - ] - | None = None, - autocompletion: Callable[..., Any] | None = None, - default_factory: Callable[[], Any] | None = None, - # Custom type - click_type: types.ParamType | None = None, - # Option - show_default: bool | str = True, - prompt: bool | str = False, - confirmation_prompt: bool = False, - prompt_required: bool = True, - hide_input: bool = False, - # TODO: remove is_flag and flag_value in a future release - is_flag: bool | None = None, - flag_value: Any | None = None, - count: bool = False, - allow_from_autoenv: bool = True, - help: str | None = None, - hidden: bool = False, - show_choices: bool = True, - show_envvar: bool = True, - # Choice - case_sensitive: bool = True, - # Numbers - min: int | float | None = None, - max: int | float | None = None, - clamp: bool = False, - # DateTime - formats: list[str] | None = None, - # File - mode: str | None = None, - encoding: str | None = None, - errors: str | None = "strict", - lazy: bool | None = None, - atomic: bool = False, - # Path - exists: bool = False, - file_okay: bool = True, - dir_okay: bool = True, - writable: bool = False, - readable: bool = True, - resolve_path: bool = False, - allow_dash: bool = False, - path_type: None | type[str] | type[bytes] = None, - # Rich settings - rich_help_panel: str | None = None, -) -> Any: ... - - def Option( # Parameter default: Annotated[ @@ -344,35 +213,6 @@ def main(opt: Annotated[CustomClass, typer.Option(parser=my_parser)] = "Foo"): """ ), ] = None, - click_type: Annotated[ - types.ParamType | None, - Doc( - """ - Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications. - - **Example** - - ```python - class MyClass: - def __init__(self, value: str): - self.value = value - - def __str__(self): - return f"" - - class MyParser(click.ParamType): - name = "MyClass" - - def convert(self, value, param, ctx): - return MyClass(value * 3) - - @app.command() - def main(opt: Annotated[MyClass, typer.Option(click_type=MyParser())] = "Foo"): - print(f"--opt is {opt}") - ``` - """ - ), - ] = None, # Option show_default: Annotated[ bool | str, @@ -959,7 +799,6 @@ def register( default_factory=default_factory, # Custom type parser=parser, - click_type=click_type, # Option show_default=show_default, prompt=prompt, @@ -1002,118 +841,6 @@ def register( ) -# Overload for Argument created with custom type 'parser' -@overload -def Argument( - # Parameter - default: Any | None = ..., - *, - callback: Callable[..., Any] | None = None, - metavar: str | None = None, - expose_value: bool = True, - is_eager: bool = False, - envvar: str | list[str] | None = None, - # Note that shell_complete is not fully supported and will be removed in future versions - # TODO: Remove shell_complete in a future version (after 0.16.0) - shell_complete: Callable[ - [_click.Context, _click.Parameter, str], - list["CompletionItem"] | list[str], - ] - | None = None, - autocompletion: Callable[..., Any] | None = None, - default_factory: Callable[[], Any] | None = None, - # Custom type - parser: Callable[[str], Any] | None = None, - # TyperArgument - show_default: bool | str = True, - show_choices: bool = True, - show_envvar: bool = True, - help: str | None = None, - hidden: bool = False, - # Choice - case_sensitive: bool = True, - # Numbers - min: int | float | None = None, - max: int | float | None = None, - clamp: bool = False, - # DateTime - formats: list[str] | None = None, - # File - mode: str | None = None, - encoding: str | None = None, - errors: str | None = "strict", - lazy: bool | None = None, - atomic: bool = False, - # Path - exists: bool = False, - file_okay: bool = True, - dir_okay: bool = True, - writable: bool = False, - readable: bool = True, - resolve_path: bool = False, - allow_dash: bool = False, - path_type: None | type[str] | type[bytes] = None, - # Rich settings - rich_help_panel: str | None = None, -) -> Any: ... - - -# Overload for Argument created with custom type 'click_type' -@overload -def Argument( - # Parameter - default: Any | None = ..., - *, - callback: Callable[..., Any] | None = None, - metavar: str | None = None, - expose_value: bool = True, - is_eager: bool = False, - envvar: str | list[str] | None = None, - # Note that shell_complete is not fully supported and will be removed in future versions - # TODO: Remove shell_complete in a future version (after 0.16.0) - shell_complete: Callable[ - [_click.Context, _click.Parameter, str], - list["CompletionItem"] | list[str], - ] - | None = None, - autocompletion: Callable[..., Any] | None = None, - default_factory: Callable[[], Any] | None = None, - # Custom type - click_type: types.ParamType | None = None, - # TyperArgument - show_default: bool | str = True, - show_choices: bool = True, - show_envvar: bool = True, - help: str | None = None, - hidden: bool = False, - # Choice - case_sensitive: bool = True, - # Numbers - min: int | float | None = None, - max: int | float | None = None, - clamp: bool = False, - # DateTime - formats: list[str] | None = None, - # File - mode: str | None = None, - encoding: str | None = None, - errors: str | None = "strict", - lazy: bool | None = None, - atomic: bool = False, - # Path - exists: bool = False, - file_okay: bool = True, - dir_okay: bool = True, - writable: bool = False, - readable: bool = True, - resolve_path: bool = False, - allow_dash: bool = False, - path_type: None | type[str] | type[bytes] = None, - # Rich settings - rich_help_panel: str | None = None, -) -> Any: ... - - def Argument( # Parameter default: Annotated[ @@ -1298,35 +1025,6 @@ def main(arg: Annotated[CustomClass, typer.Argument(parser=my_parser): """ ), ] = None, - click_type: Annotated[ - types.ParamType | None, - Doc( - """ - Define this parameter to use a [custom Click type](https://click.palletsprojects.com/en/stable/parameters/#implementing-custom-types) in your Typer applications. - - **Example** - - ```python - class MyClass: - def __init__(self, value: str): - self.value = value - - def __str__(self): - return f"" - - class MyParser(click.ParamType): - name = "MyClass" - - def convert(self, value, param, ctx): - return MyClass(value * 3) - - @app.command() - def main(arg: Annotated[MyClass, typer.Argument(click_type=MyParser())]): - print(f"arg is {arg}") - ``` - """ - ), - ] = None, # TyperArgument show_default: Annotated[ bool | str, @@ -1798,7 +1496,6 @@ def main(name: Annotated[str, typer.Argument()] = "World"): default_factory=default_factory, # Custom type parser=parser, - click_type=click_type, # TyperArgument show_default=show_default, show_choices=show_choices, diff --git a/typer/rich_utils.py b/typer/rich_utils.py index de68f60644..70b67c8bcf 100644 --- a/typer/rich_utils.py +++ b/typer/rich_utils.py @@ -25,7 +25,6 @@ from typer.models import DeveloperExceptionConfig from . import _click -from ._click import types from .core import TyperArgument, TyperGroup, TyperOption # Default styles @@ -365,56 +364,35 @@ def _print_options_panel( secondary_opt_long_strs = [] secondary_opt_short_strs = [] - # check whether argument has a metavar name or type set - metavar_name = None - metavar_type = None - metavar_str = param.make_metavar(ctx=ctx) + # Argument name label and option type display use separate APIs. + display_name: str | None = None if isinstance(param, TyperArgument): - # TODO: revise this legacy behaviour of keeping argument names lowercased for Rich formatting - if param.metavar is None and param.name: - metavar_name = metavar_str.replace(param.name.upper(), param.name) - else: - metavar_name = metavar_str - if isinstance(param, TyperOption): - metavar_type = metavar_str + display_name = param.rich_display_name() for opt_str in param.opts: if "--" in opt_str: opt_long_strs.append(opt_str) - elif metavar_name: - opt_short_strs.append(metavar_name) + elif display_name: + opt_short_strs.append(display_name) else: opt_short_strs.append(opt_str) for opt_str in param.secondary_opts: if "--" in opt_str: secondary_opt_long_strs.append(opt_str) - elif metavar_name: # pragma: no cover - secondary_opt_short_strs.append(metavar_name) + elif display_name: # pragma: no cover + secondary_opt_short_strs.append(display_name) else: secondary_opt_short_strs.append(opt_str) # Column for recording the type types_data = Text(style=STYLE_TYPES, overflow="fold") + display_type_str = param.display_type_rich(ctx=ctx) + if display_type_str is not None: + types_data.append(display_type_str) - # Fetch type - if metavar_type and metavar_type != "BOOLEAN": - types_data.append(metavar_type) - else: - type_str = param.type.name.upper() - if type_str != "BOOLEAN": - types_data.append(type_str) - - # Range - from - # https://github.com/pallets/click/blob/c63c70dabd3f86ca68678b4f00951f78f52d0270/src/click/core.py#L2698-L2706 # noqa: E501 - # skip count with default range type - if ( - isinstance(param.type, types._NumberRangeBase) - and isinstance(param, TyperOption) - and not (param.count and param.type.min == 0 and param.type.max is None) - ): - range_str = param.type._describe_range() - if range_str: - types_data.append(RANGE_STRING.format(range_str)) + range_str = param.get_number_range_help_str() + if range_str: + types_data.append(RANGE_STRING.format(range_str)) # Required asterisk required: str | Text = "" diff --git a/uv.lock b/uv.lock index 6ac1892c6f..91f39d020d 100644 --- a/uv.lock +++ b/uv.lock @@ -1791,6 +1791,7 @@ source = { editable = "." } dependencies = [ { name = "annotated-doc" }, { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "pydantic" }, { name = "rich" }, { name = "shellingham" }, ] @@ -1854,6 +1855,7 @@ tests = [ requires-dist = [ { name = "annotated-doc", specifier = ">=0.0.2" }, { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "pydantic", specifier = ">=2.5.3" }, { name = "rich", specifier = ">=13.8.0" }, { name = "shellingham", specifier = ">=1.3.0" }, ]