Skip to content

Commit

Permalink
Initial attempt at pydantic support.
Browse files Browse the repository at this point in the history
  • Loading branch information
gkarg committed Jun 30, 2023
1 parent f83c66d commit 1e0334f
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 77 deletions.
12 changes: 11 additions & 1 deletion typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .utils import get_params_from_function, pydantic, update_pydantic_params

try:
import rich
Expand Down Expand Up @@ -680,6 +680,8 @@ def wrapper(**kwargs: Any) -> Any:
use_params[k] = v
if context_param_name:
use_params[context_param_name] = click.get_current_context()

update_pydantic_params(callback, use_params)
return callback(**use_params) # type: ignore

update_wrapper(wrapper, callback)
Expand All @@ -692,6 +694,14 @@ def get_click_type(
if parameter_info.click_type is not None:
return parameter_info.click_type

elif pydantic and lenient_issubclass(annotation, pydantic.BaseModel):

class CustomParamType(click.ParamType):
def convert(self, value, param, ctx):
return annotation.parse_raw(value)

return CustomParamType()

elif parameter_info.parser is not None:
return click.types.FuncParamType(parameter_info.parser)

Expand Down
266 changes: 190 additions & 76 deletions typer/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import inspect
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast, get_type_hints
from inspect import Parameter
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
cast,
get_type_hints,
)

from typing_extensions import Annotated

from ._typing import get_args, get_origin
from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta

try:
import pydantic
except ImportError:
pydantic = None

PYDANTIC_FIELD_DELIMITER = "__"


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
# Render a `ParameterInfo` subclass for use in error messages.
Expand Down Expand Up @@ -105,83 +123,179 @@ def _split_annotation_from_typer_annotations(
]


def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
def _lenient_issubclass(
cls: Any, class_or_tuple # : Union[AnyType, Tuple[AnyType, ...]]
) -> bool:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)


def _process_pydantic_model(model: "pydantic.BaseModel", names: List[str]):
for field in model.__fields__.values():
if _lenient_issubclass(field.type_, pydantic.BaseModel):
names.append(field.name)
yield from _process_pydantic_model(field.type_, names)
names.pop()
else:
name = PYDANTIC_FIELD_DELIMITER.join(names + [field.name])
yield name, ParamMeta(
name=name, default=field.default, annotation=field.type_
)


def _process_param_from_function(
type_hints: Dict[str, Any], param: Parameter, expand_pydantic_fields: bool = True
) -> Tuple[str, ParamMeta]:
annotation, typer_annotations = _split_annotation_from_typer_annotations(
param.annotation,
)
if len(typer_annotations) > 1:
raise MultipleTyperAnnotationsError(param.name)

if (
expand_pydantic_fields
and pydantic
and _lenient_issubclass(annotation, pydantic.BaseModel)
):
yield from _process_pydantic_model(annotation, [param.name])
return

default = param.default
if typer_annotations:
# It's something like `my_param: Annotated[str, Argument()]`
[parameter_info] = typer_annotations

# Forbid `my_param: Annotated[str, Argument()] = Argument("...")`
if isinstance(param.default, ParameterInfo):
raise MixedAnnotatedAndDefaultStyleError(
argument_name=param.name,
annotated_param_type=type(parameter_info),
default_param_type=type(param.default),
)

parameter_info = copy(parameter_info)

# When used as a default, `Option` takes a default value and option names
# as positional arguments:
# `Option(some_value, "--some-argument", "-s")`
# When used in `Annotated` (ie, what this is handling), `Option` just takes
# option names as positional arguments:
# `Option("--some-argument", "-s")`
# In this case, the `default` attribute of `parameter_info` is actually
# meant to be the first item of `param_decls`.
if isinstance(parameter_info, OptionInfo) and parameter_info.default is not ...:
parameter_info.param_decls = (
cast(str, parameter_info.default),
*(parameter_info.param_decls or ()),
)
parameter_info.default = ...

# Forbid `my_param: Annotated[str, Argument('some-default')]`
if parameter_info.default is not ...:
raise AnnotatedParamWithDefaultValueError(
param_type=type(parameter_info),
argument_name=param.name,
)
if param.default is not param.empty:
# Put the parameter's default (set by `=`) into `parameter_info`, where
# typer can find it.
parameter_info.default = param.default

default = parameter_info
elif param.name in type_hints:
# Resolve forward references.
annotation = type_hints[param.name]

if isinstance(default, ParameterInfo):
parameter_info = copy(default)
# Click supports `default` as either
# - an actual value; or
# - a factory function (returning a default value.)
# The two are not interchangeable for static typing, so typer allows
# specifying `default_factory`. Move the `default_factory` into `default`
# so click can find it.
if parameter_info.default is ... and parameter_info.default_factory:
parameter_info.default = parameter_info.default_factory
elif parameter_info.default_factory:
raise DefaultFactoryAndDefaultValueError(
argument_name=param.name, param_type=type(parameter_info)
)
default = parameter_info
yield param.name, ParamMeta(name=param.name, default=default, annotation=annotation)


def get_params_from_function(
func: Callable[..., Any], expand_pydantic_fields: bool = True
) -> Dict[str, ParamMeta]:
signature = inspect.signature(func)
type_hints = get_type_hints(func)
params = {}
for param in signature.parameters.values():
annotation, typer_annotations = _split_annotation_from_typer_annotations(
param.annotation,
)
if len(typer_annotations) > 1:
raise MultipleTyperAnnotationsError(param.name)

default = param.default
if typer_annotations:
# It's something like `my_param: Annotated[str, Argument()]`
[parameter_info] = typer_annotations

# Forbid `my_param: Annotated[str, Argument()] = Argument("...")`
if isinstance(param.default, ParameterInfo):
raise MixedAnnotatedAndDefaultStyleError(
argument_name=param.name,
annotated_param_type=type(parameter_info),
default_param_type=type(param.default),
)

parameter_info = copy(parameter_info)

# When used as a default, `Option` takes a default value and option names
# as positional arguments:
# `Option(some_value, "--some-argument", "-s")`
# When used in `Annotated` (ie, what this is handling), `Option` just takes
# option names as positional arguments:
# `Option("--some-argument", "-s")`
# In this case, the `default` attribute of `parameter_info` is actually
# meant to be the first item of `param_decls`.
if (
isinstance(parameter_info, OptionInfo)
and parameter_info.default is not ...
):
parameter_info.param_decls = (
cast(str, parameter_info.default),
*(parameter_info.param_decls or ()),
)
parameter_info.default = ...

# Forbid `my_param: Annotated[str, Argument('some-default')]`
if parameter_info.default is not ...:
raise AnnotatedParamWithDefaultValueError(
param_type=type(parameter_info),
argument_name=param.name,
)
if param.default is not param.empty:
# Put the parameter's default (set by `=`) into `parameter_info`, where
# typer can find it.
parameter_info.default = param.default

default = parameter_info
elif param.name in type_hints:
# Resolve forward references.
annotation = type_hints[param.name]

if isinstance(default, ParameterInfo):
parameter_info = copy(default)
# Click supports `default` as either
# - an actual value; or
# - a factory function (returning a default value.)
# The two are not interchangeable for static typing, so typer allows
# specifying `default_factory`. Move the `default_factory` into `default`
# so click can find it.
if parameter_info.default is ... and parameter_info.default_factory:
parameter_info.default = parameter_info.default_factory
elif parameter_info.default_factory:
raise DefaultFactoryAndDefaultValueError(
argument_name=param.name, param_type=type(parameter_info)
)
default = parameter_info

params[param.name] = ParamMeta(
name=param.name, default=default, annotation=annotation
)
for name, meta in _process_param_from_function(
type_hints, param, expand_pydantic_fields=expand_pydantic_fields
):
params[name] = meta
return params


def _explode_env_vars(
env_nested_delimiter: str, env_vars: Dict[str, Optional[str]]
) -> Dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
"""
prefixes = (
[]
) # f'{env_name}{env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
# if not any(env_name.startswith(prefix) for prefix in prefixes):
# continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name # [self.env_prefix_len :]
_, *keys, last_key = env_name_without_prefix.split(env_nested_delimiter)
env_var = result
for key in keys:
env_var = env_var.setdefault(key, {})
env_var[last_key] = env_val

return result


def update_pydantic_params(
callback: Optional[Callable[..., Any]], use_params: Dict[str, Any]
) -> None:
"""
Explode and collapse delimited params into pydantic models.
Example:
--foo--bar 23, --foo--baz 42
-> {"foo": {"bar": 23, "baz" 42 }}
-> {"foo": Foo(**{"bar": 23, "baz" 42 })}
"""
if not pydantic:
return use_params

pydantic_parameters = {
param_name: param.annotation
for param_name, param in get_params_from_function(
callback, expand_pydantic_fields=False
).items()
if _lenient_issubclass(param.annotation, pydantic.BaseModel)
}

delimited_vars = {}
delete = []
for k, v in use_params.items():
if PYDANTIC_FIELD_DELIMITER in k:
if v is not None:
delimited_vars[PYDANTIC_FIELD_DELIMITER + k] = v
delete.append(k)
for delete_one in delete:
use_params.pop(delete_one)
exploded_vars = _explode_env_vars(
env_nested_delimiter=PYDANTIC_FIELD_DELIMITER, env_vars=delimited_vars
)
for k, v in exploded_vars.items():
use_params[k] = pydantic_parameters[k](**v)

0 comments on commit 1e0334f

Please sign in to comment.