diff --git a/docs/tutorial/parameter-types/enum.md b/docs/tutorial/parameter-types/enum.md index 767c329b27..b208ed9aa5 100644 --- a/docs/tutorial/parameter-types/enum.md +++ b/docs/tutorial/parameter-types/enum.md @@ -35,6 +35,14 @@ Usage: main.py [OPTIONS] Try "main.py --help" for help. Error: Invalid value for '--network': invalid choice: capsule. (choose from simple, conv, lstm) + +// Note that enums are case sensitive by default +$ python main.py --network CONV + +Usage: main.py [OPTIONS] +Try "main.py --help" for help. + +Error: Invalid value for '--network': invalid choice: CONV. (choose from simple, conv, lstm) ``` diff --git a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py index 584b71a0c7..4f2c246ea0 100644 --- a/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py +++ b/tests/test_tutorial/test_parameter_types/test_enum/test_tutorial001.py @@ -25,7 +25,22 @@ def test_main(): assert "Training neural network of type: conv" in result.output -def test_invalid(): +def test_invalid_case(): + result = runner.invoke(app, ["--network", "CONV"]) + assert result.exit_code != 0 + # TODO: when deprecating Click 7, remove second option + + assert ( + "Invalid value for '--network': 'CONV' is not one of" in result.output + or "Invalid value for '--network': invalid choice: CONV. (choose from" + in result.output + ) + assert "simple" in result.output + assert "conv" in result.output + assert "lstm" in result.output + + +def test_invalid_other(): result = runner.invoke(app, ["--network", "capsule"]) assert result.exit_code != 0 # TODO: when deprecating Click 7, remove second option diff --git a/typer/main.py b/typer/main.py index 9de5f5960d..c739a293ae 100644 --- a/typer/main.py +++ b/typer/main.py @@ -618,13 +618,13 @@ def param_path_convertor(value: Optional[str] = None) -> Optional[Path]: def generate_enum_convertor(enum: Type[Enum]) -> Callable[[Any], Any]: - lower_val_map = {str(val.value).lower(): val for val in enum} + val_map = {str(val.value): val for val in enum} def convertor(value: Any) -> Any: if value is not None: - low = str(value).lower() - if low in lower_val_map: - key = lower_val_map[low] + val = str(value) + if val in val_map: + key = val_map[val] return enum(key) return convertor