Skip to content

Initial attempt at pydantic support. #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .utils import get_params_from_function, pydantic, update_pydantic_params

try:
import rich
Expand Down Expand Up @@ -680,6 +680,8 @@ def wrapper(**kwargs: Any) -> Any:
use_params[k] = v
if context_param_name:
use_params[context_param_name] = click.get_current_context()

update_pydantic_params(callback, use_params)
return callback(**use_params) # type: ignore

update_wrapper(wrapper, callback)
Expand All @@ -692,6 +694,14 @@ def get_click_type(
if parameter_info.click_type is not None:
return parameter_info.click_type

elif pydantic and lenient_issubclass(annotation, pydantic.BaseModel):

class CustomParamType(click.ParamType):
def convert(self, value, param, ctx):
return annotation.parse_raw(value)

return CustomParamType()

elif parameter_info.parser is not None:
return click.types.FuncParamType(parameter_info.parser)

Expand Down
266 changes: 190 additions & 76 deletions typer/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import inspect
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints
from inspect import Parameter
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
cast,
get_type_hints,
)

from typing_extensions import Annotated

from ._typing import get_args, get_origin
from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta

try:
import pydantic
except ImportError:
pydantic = None

PYDANTIC_FIELD_DELIMITER = "__"


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
# Render a `ParameterInfo` subclass for use in error messages.
Expand Down Expand Up @@ -105,83 +123,179 @@ def _split_annotation_from_typer_annotations(
]


def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
def _lenient_issubclass(
cls: Any, class_or_tuple # : Union[AnyType, Tuple[AnyType, ...]]
) -> bool:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)


def _process_pydantic_model(model: "pydantic.BaseModel", names: List[str]):
for field in model.__fields__.values():
if _lenient_issubclass(field.type_, pydantic.BaseModel):
names.append(field.name)
yield from _process_pydantic_model(field.type_, names)
names.pop()
else:
name = PYDANTIC_FIELD_DELIMITER.join(names + [field.name])
yield name, ParamMeta(
name=name, default=field.default, annotation=field.type_
)


def _process_param_from_function(
type_hints: Dict[str, Any], param: Parameter, expand_pydantic_fields: bool = True
) -> Tuple[str, ParamMeta]:
annotation, typer_annotations = _split_annotation_from_typer_annotations(
param.annotation,
)
if len(typer_annotations) > 1:
raise MultipleTyperAnnotationsError(param.name)

if (
expand_pydantic_fields
and pydantic
and _lenient_issubclass(annotation, pydantic.BaseModel)
):
yield from _process_pydantic_model(annotation, [param.name])
return

default = param.default
if typer_annotations:
# It's something like `my_param: Annotated[str, Argument()]`
[parameter_info] = typer_annotations

# Forbid `my_param: Annotated[str, Argument()] = Argument("...")`
if isinstance(param.default, ParameterInfo):
raise MixedAnnotatedAndDefaultStyleError(
argument_name=param.name,
annotated_param_type=type(parameter_info),
default_param_type=type(param.default),
)

parameter_info = copy(parameter_info)

# When used as a default, `Option` takes a default value and option names
# as positional arguments:
# `Option(some_value, "--some-argument", "-s")`
# When used in `Annotated` (ie, what this is handling), `Option` just takes
# option names as positional arguments:
# `Option("--some-argument", "-s")`
# In this case, the `default` attribute of `parameter_info` is actually
# meant to be the first item of `param_decls`.
if isinstance(parameter_info, OptionInfo) and parameter_info.default is not ...:
parameter_info.param_decls = (
cast(str, parameter_info.default),
*(parameter_info.param_decls or ()),
)
parameter_info.default = ...

# Forbid `my_param: Annotated[str, Argument('some-default')]`
if parameter_info.default is not ...:
raise AnnotatedParamWithDefaultValueError(
param_type=type(parameter_info),
argument_name=param.name,
)
if param.default is not param.empty:
# Put the parameter's default (set by `=`) into `parameter_info`, where
# typer can find it.
parameter_info.default = param.default

default = parameter_info
elif param.name in type_hints:
# Resolve forward references.
annotation = type_hints[param.name]

if isinstance(default, ParameterInfo):
parameter_info = copy(default)
# Click supports `default` as either
# - an actual value; or
# - a factory function (returning a default value.)
# The two are not interchangeable for static typing, so typer allows
# specifying `default_factory`. Move the `default_factory` into `default`
# so click can find it.
if parameter_info.default is ... and parameter_info.default_factory:
parameter_info.default = parameter_info.default_factory
elif parameter_info.default_factory:
raise DefaultFactoryAndDefaultValueError(
argument_name=param.name, param_type=type(parameter_info)
)
default = parameter_info
yield param.name, ParamMeta(name=param.name, default=default, annotation=annotation)


def get_params_from_function(
func: Callable[..., Any], expand_pydantic_fields: bool = True
) -> Dict[str, ParamMeta]:
signature = inspect.signature(func)
type_hints = get_type_hints(func)
params = {}
for param in signature.parameters.values():
annotation, typer_annotations = _split_annotation_from_typer_annotations(
param.annotation,
)
if len(typer_annotations) > 1:
raise MultipleTyperAnnotationsError(param.name)

default = param.default
if typer_annotations:
# It's something like `my_param: Annotated[str, Argument()]`
[parameter_info] = typer_annotations

# Forbid `my_param: Annotated[str, Argument()] = Argument("...")`
if isinstance(param.default, ParameterInfo):
raise MixedAnnotatedAndDefaultStyleError(
argument_name=param.name,
annotated_param_type=type(parameter_info),
default_param_type=type(param.default),
)

parameter_info = copy(parameter_info)

# When used as a default, `Option` takes a default value and option names
# as positional arguments:
# `Option(some_value, "--some-argument", "-s")`
# When used in `Annotated` (ie, what this is handling), `Option` just takes
# option names as positional arguments:
# `Option("--some-argument", "-s")`
# In this case, the `default` attribute of `parameter_info` is actually
# meant to be the first item of `param_decls`.
if (
isinstance(parameter_info, OptionInfo)
and parameter_info.default is not ...
):
parameter_info.param_decls = (
cast(str, parameter_info.default),
*(parameter_info.param_decls or ()),
)
parameter_info.default = ...

# Forbid `my_param: Annotated[str, Argument('some-default')]`
if parameter_info.default is not ...:
raise AnnotatedParamWithDefaultValueError(
param_type=type(parameter_info),
argument_name=param.name,
)
if param.default is not param.empty:
# Put the parameter's default (set by `=`) into `parameter_info`, where
# typer can find it.
parameter_info.default = param.default

default = parameter_info
elif param.name in type_hints:
# Resolve forward references.
annotation = type_hints[param.name]

if isinstance(default, ParameterInfo):
parameter_info = copy(default)
# Click supports `default` as either
# - an actual value; or
# - a factory function (returning a default value.)
# The two are not interchangeable for static typing, so typer allows
# specifying `default_factory`. Move the `default_factory` into `default`
# so click can find it.
if parameter_info.default is ... and parameter_info.default_factory:
parameter_info.default = parameter_info.default_factory
elif parameter_info.default_factory:
raise DefaultFactoryAndDefaultValueError(
argument_name=param.name, param_type=type(parameter_info)
)
default = parameter_info

params[param.name] = ParamMeta(
name=param.name, default=default, annotation=annotation
)
for name, meta in _process_param_from_function(
type_hints, param, expand_pydantic_fields=expand_pydantic_fields
):
params[name] = meta
return params


def _explode_env_vars(
env_nested_delimiter: str, env_vars: Dict[str, Optional[str]]
) -> Dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

This is applied to a single field, hence filtering by env_var prefix.
"""
prefixes = (
[]
) # f'{env_name}{env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
# if not any(env_name.startswith(prefix) for prefix in prefixes):
# continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name # [self.env_prefix_len :]
_, *keys, last_key = env_name_without_prefix.split(env_nested_delimiter)
env_var = result
for key in keys:
env_var = env_var.setdefault(key, {})
env_var[last_key] = env_val

return result


def update_pydantic_params(
callback: Optional[Callable[..., Any]], use_params: Dict[str, Any]
) -> None:
"""
Explode and collapse delimited params into pydantic models.

Example:
--foo--bar 23, --foo--baz 42
-> {"foo": {"bar": 23, "baz" 42 }}
-> {"foo": Foo(**{"bar": 23, "baz" 42 })}
"""
if not pydantic:
return use_params

pydantic_parameters = {
param_name: param.annotation
for param_name, param in get_params_from_function(
callback, expand_pydantic_fields=False
).items()
if _lenient_issubclass(param.annotation, pydantic.BaseModel)
}

delimited_vars = {}
delete = []
for k, v in use_params.items():
if PYDANTIC_FIELD_DELIMITER in k:
if v is not None:
delimited_vars[PYDANTIC_FIELD_DELIMITER + k] = v
delete.append(k)
for delete_one in delete:
use_params.pop(delete_one)
exploded_vars = _explode_env_vars(
env_nested_delimiter=PYDANTIC_FIELD_DELIMITER, env_vars=delimited_vars
)
for k, v in exploded_vars.items():
use_params[k] = pydantic_parameters[k](**v)