From 78509d81afb9b65ac835489037531ccba129d4bb Mon Sep 17 00:00:00 2001 From: Andreas Backx Date: Tue, 3 Dec 2024 23:21:21 +0000 Subject: [PATCH] Expanded choice value support. --- src/click/types.py | 30 ++++++------ tests/test_basic.py | 94 +++++++++++++++++++++++++++---------- tests/test_normalization.py | 37 ++++++++++++--- 3 files changed, 116 insertions(+), 45 deletions(-) diff --git a/src/click/types.py b/src/click/types.py index 302f9eb17..63ddc0b03 100644 --- a/src/click/types.py +++ b/src/click/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections.abc as cabc +import enum import os import stat import sys @@ -249,9 +250,9 @@ class Choice(ParamType, t.Generic[ParamTypeValue]): name = "choice" def __init__( - self, choices: cabc.Sequence[ParamTypeValue], case_sensitive: bool = True + self, choices: cabc.Iterable[ParamTypeValue], case_sensitive: bool = True ) -> None: - self.choices = choices + self.choices: cabc.Sequence[ParamTypeValue] = list(choices) self.case_sensitive = case_sensitive def to_info_dict(self) -> dict[str, t.Any]: @@ -277,7 +278,7 @@ def normalized_mapping( for choice in self.choices } - def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str: + def normalize_choice(self, choice: t.Any, ctx: Context | None) -> str: """ Normalize a choice value. @@ -286,7 +287,7 @@ def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str: .. versionadded:: 8.2.0 """ - normed_value = str(choice) + normed_value = choice.name if isinstance(choice, enum.Enum) else str(choice) if ctx is not None and ctx.token_normalize_func is not None: normed_value = ctx.token_normalize_func(normed_value) @@ -326,27 +327,28 @@ def get_missing_message(self, param: Parameter, ctx: Context | None) -> str: def convert( self, value: t.Any, param: Parameter | None, ctx: Context | None - ) -> t.Any: + ) -> ParamTypeValue: + """ + For a given value from the parser, normalize it and find its + matching normalized value in the list of choices. Then return the + matched "original" choice. + """ normed_value = self.normalize_choice(choice=value, ctx=ctx) normalized_mapping = self.normalized_mapping(ctx=ctx) - original_choice = next( - ( + + try: + return next( original for original, normalized in normalized_mapping.items() if normalized == normed_value - ), - None, - ) - - if not original_choice: + ) + except StopIteration: self.fail( self.get_invalid_choice_message(value=value, ctx=ctx), param=param, ctx=ctx, ) - return original_choice - def get_invalid_choice_message(self, value: t.Any, ctx: Context | None) -> str: """Get the error message when the given choice is invalid. diff --git a/tests/test_basic.py b/tests/test_basic.py index 89e722d4f..b84ae73d6 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import enum import os from itertools import chain @@ -382,45 +385,68 @@ def cli(method): assert "--method [foo|bar|baz]" in result.output -def test_choice_option_normalization(runner): +def test_choice_argument(runner): @click.command() - @click.option( - "--method", - type=click.Choice( - ["SCREAMING_SNAKE_CASE", "snake_case", "PascalCase", "kebab-case"], - case_sensitive=False, - ), - ) + @click.argument("method", type=click.Choice(["foo", "bar", "baz"])) def cli(method): click.echo(method) - result = runner.invoke(cli, ["--method=snake_case"]) - assert not result.exception, result.output - assert result.output == "snake_case\n" - - # Even though it's case sensitive, the choice's original value is preserved - result = runner.invoke(cli, ["--method=pascalcase"]) - assert not result.exception, result.output - assert result.output == "PascalCase\n" + result = runner.invoke(cli, ["foo"]) + assert not result.exception + assert result.output == "foo\n" - result = runner.invoke(cli, ["--method=meh"]) + result = runner.invoke(cli, ["meh"]) assert result.exit_code == 2 assert ( - "Invalid value for '--method': 'meh' is not one of " - "'screaming_snake_case', 'snake_case', 'pascalcase', 'kebab-case'." - ) in result.output + "Invalid value for '{foo|bar|baz}': 'meh' is not one of 'foo'," + " 'bar', 'baz'." in result.output + ) result = runner.invoke(cli, ["--help"]) + assert "{foo|bar|baz}" in result.output + + +def test_choice_argument_enum(runner): + class MyEnum(str, enum.Enum): + FOO = "foo-value" + BAR = "bar-value" + BAZ = "baz-value" + + @click.command() + @click.argument("method", type=click.Choice(MyEnum, case_sensitive=False)) + def cli(method: MyEnum): + assert isinstance(method, MyEnum) + click.echo(method) + + result = runner.invoke(cli, ["foo"]) + assert result.output == "foo-value\n" + assert not result.exception + + result = runner.invoke(cli, ["meh"]) + assert result.exit_code == 2 assert ( - "--method [screaming_snake_case|snake_case|pascalcase|kebab-case]" - in result.output + "Invalid value for '{foo|bar|baz}': 'meh' is not one of 'foo'," + " 'bar', 'baz'." in result.output ) + result = runner.invoke(cli, ["--help"]) + assert "{foo|bar|baz}" in result.output + + +def test_choice_argument_custom_type(runner): + class MyClass: + def __init__(self, value: str) -> None: + self.value = value + + def __str__(self) -> str: + return self.value -def test_choice_argument(runner): @click.command() - @click.argument("method", type=click.Choice(["foo", "bar", "baz"])) - def cli(method): + @click.argument( + "method", type=click.Choice([MyClass("foo"), MyClass("bar"), MyClass("baz")]) + ) + def cli(method: MyClass): + assert isinstance(method, MyClass) click.echo(method) result = runner.invoke(cli, ["foo"]) @@ -438,6 +464,24 @@ def cli(method): assert "{foo|bar|baz}" in result.output +def test_choice_argument_none(runner): + @click.command() + @click.argument( + "method", type=click.Choice(["not-none", None], case_sensitive=False) + ) + def cli(method: str | None): + assert isinstance(method, str) or method is None + click.echo(method) + + result = runner.invoke(cli, ["not-none"]) + assert not result.exception + assert result.output == "not-none\n" + + # None is not yet supported. + result = runner.invoke(cli, ["none"]) + assert result.exception + + def test_datetime_option_default(runner): @click.command() @click.option("--start_date", type=click.DateTime()) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 502e654a3..442b638f4 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -17,12 +17,37 @@ def cli(foo, x): def test_choice_normalization(runner): @click.command(context_settings=CONTEXT_SETTINGS) - @click.option("--choice", type=click.Choice(["Foo", "Bar"])) - def cli(choice): - click.echo(choice) - - result = runner.invoke(cli, ["--CHOICE", "FOO"]) - assert result.output == "Foo\n" + @click.option( + "--method", + type=click.Choice( + ["SCREAMING_SNAKE_CASE", "snake_case", "PascalCase", "kebab-case"], + case_sensitive=False, + ), + ) + def cli(method): + click.echo(method) + + result = runner.invoke(cli, ["--METHOD=snake_case"]) + assert not result.exception, result.output + assert result.output == "snake_case\n" + + # Even though it's case sensitive, the choice's original value is preserved + result = runner.invoke(cli, ["--method=pascalcase"]) + assert not result.exception, result.output + assert result.output == "PascalCase\n" + + result = runner.invoke(cli, ["--method=meh"]) + assert result.exit_code == 2 + assert ( + "Invalid value for '--method': 'meh' is not one of " + "'screaming_snake_case', 'snake_case', 'pascalcase', 'kebab-case'." + ) in result.output + + result = runner.invoke(cli, ["--help"]) + assert ( + "--method [screaming_snake_case|snake_case|pascalcase|kebab-case]" + in result.output + ) def test_command_normalization(runner):