From 1e0334f149d352de93c8fbd7759a6f69cdbbc462 Mon Sep 17 00:00:00 2001 From: gkarg Date: Fri, 30 Jun 2023 17:58:46 +0300 Subject: [PATCH] Initial attempt at pydantic support. --- typer/main.py | 12 ++- typer/utils.py | 266 +++++++++++++++++++++++++++++++++++-------------- 2 files changed, 201 insertions(+), 77 deletions(-) diff --git a/typer/main.py b/typer/main.py index aa39e82849..68bc1ca169 100644 --- a/typer/main.py +++ b/typer/main.py @@ -34,7 +34,7 @@ Required, TyperInfo, ) -from .utils import get_params_from_function +from .utils import get_params_from_function, pydantic, update_pydantic_params try: import rich @@ -680,6 +680,8 @@ def wrapper(**kwargs: Any) -> Any: use_params[k] = v if context_param_name: use_params[context_param_name] = click.get_current_context() + + update_pydantic_params(callback, use_params) return callback(**use_params) # type: ignore update_wrapper(wrapper, callback) @@ -692,6 +694,14 @@ def get_click_type( if parameter_info.click_type is not None: return parameter_info.click_type + elif pydantic and lenient_issubclass(annotation, pydantic.BaseModel): + + class CustomParamType(click.ParamType): + def convert(self, value, param, ctx): + return annotation.parse_raw(value) + + return CustomParamType() + elif parameter_info.parser is not None: return click.types.FuncParamType(parameter_info.parser) diff --git a/typer/utils.py b/typer/utils.py index 44816e2420..20790f4781 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -1,12 +1,30 @@ import inspect from copy import copy -from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints +from inspect import Parameter +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + cast, + get_type_hints, +) from typing_extensions import Annotated from ._typing import get_args, get_origin from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta +try: + import pydantic +except ImportError: + pydantic = None + +PYDANTIC_FIELD_DELIMITER = "__" + def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str: # Render a `ParameterInfo` subclass for use in error messages. @@ -105,83 +123,179 @@ def _split_annotation_from_typer_annotations( ] -def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: +def _lenient_issubclass( + cls: Any, class_or_tuple # : Union[AnyType, Tuple[AnyType, ...]] +) -> bool: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + + +def _process_pydantic_model(model: "pydantic.BaseModel", names: List[str]): + for field in model.__fields__.values(): + if _lenient_issubclass(field.type_, pydantic.BaseModel): + names.append(field.name) + yield from _process_pydantic_model(field.type_, names) + names.pop() + else: + name = PYDANTIC_FIELD_DELIMITER.join(names + [field.name]) + yield name, ParamMeta( + name=name, default=field.default, annotation=field.type_ + ) + + +def _process_param_from_function( + type_hints: Dict[str, Any], param: Parameter, expand_pydantic_fields: bool = True +) -> Tuple[str, ParamMeta]: + annotation, typer_annotations = _split_annotation_from_typer_annotations( + param.annotation, + ) + if len(typer_annotations) > 1: + raise MultipleTyperAnnotationsError(param.name) + + if ( + expand_pydantic_fields + and pydantic + and _lenient_issubclass(annotation, pydantic.BaseModel) + ): + yield from _process_pydantic_model(annotation, [param.name]) + return + + default = param.default + if typer_annotations: + # It's something like `my_param: Annotated[str, Argument()]` + [parameter_info] = typer_annotations + + # Forbid `my_param: Annotated[str, Argument()] = Argument("...")` + if isinstance(param.default, ParameterInfo): + raise MixedAnnotatedAndDefaultStyleError( + argument_name=param.name, + annotated_param_type=type(parameter_info), + default_param_type=type(param.default), + ) + + parameter_info = copy(parameter_info) + + # When used as a default, `Option` takes a default value and option names + # as positional arguments: + # `Option(some_value, "--some-argument", "-s")` + # When used in `Annotated` (ie, what this is handling), `Option` just takes + # option names as positional arguments: + # `Option("--some-argument", "-s")` + # In this case, the `default` attribute of `parameter_info` is actually + # meant to be the first item of `param_decls`. + if isinstance(parameter_info, OptionInfo) and parameter_info.default is not ...: + parameter_info.param_decls = ( + cast(str, parameter_info.default), + *(parameter_info.param_decls or ()), + ) + parameter_info.default = ... + + # Forbid `my_param: Annotated[str, Argument('some-default')]` + if parameter_info.default is not ...: + raise AnnotatedParamWithDefaultValueError( + param_type=type(parameter_info), + argument_name=param.name, + ) + if param.default is not param.empty: + # Put the parameter's default (set by `=`) into `parameter_info`, where + # typer can find it. + parameter_info.default = param.default + + default = parameter_info + elif param.name in type_hints: + # Resolve forward references. + annotation = type_hints[param.name] + + if isinstance(default, ParameterInfo): + parameter_info = copy(default) + # Click supports `default` as either + # - an actual value; or + # - a factory function (returning a default value.) + # The two are not interchangeable for static typing, so typer allows + # specifying `default_factory`. Move the `default_factory` into `default` + # so click can find it. + if parameter_info.default is ... and parameter_info.default_factory: + parameter_info.default = parameter_info.default_factory + elif parameter_info.default_factory: + raise DefaultFactoryAndDefaultValueError( + argument_name=param.name, param_type=type(parameter_info) + ) + default = parameter_info + yield param.name, ParamMeta(name=param.name, default=default, annotation=annotation) + + +def get_params_from_function( + func: Callable[..., Any], expand_pydantic_fields: bool = True +) -> Dict[str, ParamMeta]: signature = inspect.signature(func) type_hints = get_type_hints(func) params = {} for param in signature.parameters.values(): - annotation, typer_annotations = _split_annotation_from_typer_annotations( - param.annotation, - ) - if len(typer_annotations) > 1: - raise MultipleTyperAnnotationsError(param.name) - - default = param.default - if typer_annotations: - # It's something like `my_param: Annotated[str, Argument()]` - [parameter_info] = typer_annotations - - # Forbid `my_param: Annotated[str, Argument()] = Argument("...")` - if isinstance(param.default, ParameterInfo): - raise MixedAnnotatedAndDefaultStyleError( - argument_name=param.name, - annotated_param_type=type(parameter_info), - default_param_type=type(param.default), - ) - - parameter_info = copy(parameter_info) - - # When used as a default, `Option` takes a default value and option names - # as positional arguments: - # `Option(some_value, "--some-argument", "-s")` - # When used in `Annotated` (ie, what this is handling), `Option` just takes - # option names as positional arguments: - # `Option("--some-argument", "-s")` - # In this case, the `default` attribute of `parameter_info` is actually - # meant to be the first item of `param_decls`. - if ( - isinstance(parameter_info, OptionInfo) - and parameter_info.default is not ... - ): - parameter_info.param_decls = ( - cast(str, parameter_info.default), - *(parameter_info.param_decls or ()), - ) - parameter_info.default = ... - - # Forbid `my_param: Annotated[str, Argument('some-default')]` - if parameter_info.default is not ...: - raise AnnotatedParamWithDefaultValueError( - param_type=type(parameter_info), - argument_name=param.name, - ) - if param.default is not param.empty: - # Put the parameter's default (set by `=`) into `parameter_info`, where - # typer can find it. - parameter_info.default = param.default - - default = parameter_info - elif param.name in type_hints: - # Resolve forward references. - annotation = type_hints[param.name] - - if isinstance(default, ParameterInfo): - parameter_info = copy(default) - # Click supports `default` as either - # - an actual value; or - # - a factory function (returning a default value.) - # The two are not interchangeable for static typing, so typer allows - # specifying `default_factory`. Move the `default_factory` into `default` - # so click can find it. - if parameter_info.default is ... and parameter_info.default_factory: - parameter_info.default = parameter_info.default_factory - elif parameter_info.default_factory: - raise DefaultFactoryAndDefaultValueError( - argument_name=param.name, param_type=type(parameter_info) - ) - default = parameter_info - - params[param.name] = ParamMeta( - name=param.name, default=default, annotation=annotation - ) + for name, meta in _process_param_from_function( + type_hints, param, expand_pydantic_fields=expand_pydantic_fields + ): + params[name] = meta return params + + +def _explode_env_vars( + env_nested_delimiter: str, env_vars: Dict[str, Optional[str]] +) -> Dict[str, Any]: + """ + Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. + + This is applied to a single field, hence filtering by env_var prefix. + """ + prefixes = ( + [] + ) # f'{env_name}{env_nested_delimiter}' for env_name in field.field_info.extra['env_names']] + result: Dict[str, Any] = {} + for env_name, env_val in env_vars.items(): + # if not any(env_name.startswith(prefix) for prefix in prefixes): + # continue + # we remove the prefix before splitting in case the prefix has characters in common with the delimiter + env_name_without_prefix = env_name # [self.env_prefix_len :] + _, *keys, last_key = env_name_without_prefix.split(env_nested_delimiter) + env_var = result + for key in keys: + env_var = env_var.setdefault(key, {}) + env_var[last_key] = env_val + + return result + + +def update_pydantic_params( + callback: Optional[Callable[..., Any]], use_params: Dict[str, Any] +) -> None: + """ + Explode and collapse delimited params into pydantic models. + + Example: + --foo--bar 23, --foo--baz 42 + -> {"foo": {"bar": 23, "baz" 42 }} + -> {"foo": Foo(**{"bar": 23, "baz" 42 })} + """ + if not pydantic: + return use_params + + pydantic_parameters = { + param_name: param.annotation + for param_name, param in get_params_from_function( + callback, expand_pydantic_fields=False + ).items() + if _lenient_issubclass(param.annotation, pydantic.BaseModel) + } + + delimited_vars = {} + delete = [] + for k, v in use_params.items(): + if PYDANTIC_FIELD_DELIMITER in k: + if v is not None: + delimited_vars[PYDANTIC_FIELD_DELIMITER + k] = v + delete.append(k) + for delete_one in delete: + use_params.pop(delete_one) + exploded_vars = _explode_env_vars( + env_nested_delimiter=PYDANTIC_FIELD_DELIMITER, env_vars=delimited_vars + ) + for k, v in exploded_vars.items(): + use_params[k] = pydantic_parameters[k](**v)