Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User customizable name_transform #150

Merged
merged 5 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cyclopts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"UnusedCliTokensError",
"ValidationError",
"convert",
"default_name_transform",
"types",
"validators",
]
Expand All @@ -36,5 +37,6 @@
from cyclopts.group import Group
from cyclopts.parameter import Parameter
from cyclopts.protocols import Dispatcher
from cyclopts.utils import default_name_transform

from . import types, validators
90 changes: 66 additions & 24 deletions cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


from cyclopts.exceptions import CoercionError
from cyclopts.utils import is_union
from cyclopts.utils import default_name_transform, is_union

if TYPE_CHECKING:
from cyclopts.parameter import Parameter
Expand Down Expand Up @@ -83,7 +83,13 @@ def _bytearray(s: str) -> bytearray:
}


def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] = None) -> Tuple:
def _convert_tuple(
type_: Type[Any],
*args: str,
converter: Optional[Callable[[Type, str], Any]],
name_transform: Callable[[str], str],
) -> Tuple:
convert = partial(_convert, converter=converter, name_transform=name_transform)
inner_types = tuple(x for x in get_args(type_) if x is not ...)
inner_token_count, consume_all = token_count(type_)
if consume_all:
Expand All @@ -101,11 +107,10 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] =
raise ValueError("A tuple must have 0 or 1 inner-types.")

if inner_token_count == 1:
out = tuple(_convert(inner_type, x, converter=converter) for x in args)
out = tuple(convert(inner_type, x) for x in args)
else:
out = tuple(
_convert(inner_type, args[i : i + inner_token_count], converter=converter)
for i in range(0, len(args), inner_token_count)
convert(inner_type, args[i : i + inner_token_count]) for i in range(0, len(args), inner_token_count)
)
return out
else:
Expand All @@ -116,35 +121,49 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] =
it = iter(args)
batched = [[next(it) for _ in range(size)] for size in args_per_convert]
batched = [elem[0] if len(elem) == 1 else elem for elem in batched]
out = tuple(_convert(inner_type, arg, converter=converter) for inner_type, arg in zip(inner_types, batched))
out = tuple(convert(inner_type, arg) for inner_type, arg in zip(inner_types, batched))
return out


def _convert(type_, element, converter=None):
pconvert = partial(_convert, converter=converter)
def _convert(
type_,
element,
*,
converter: Optional[Callable[[Type, str], Any]],
name_transform: Callable[[str], str],
):
"""Inner recursive conversion function for public ``convert``.

Parameters
----------
converter: Callable
name_transform: Callable
"""
convert = partial(_convert, converter=converter, name_transform=name_transform)
convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)
origin_type = get_origin(type_)
inner_types = [resolve(x) for x in get_args(type_)]

if type_ in _implicit_iterable_type_mapping:
return pconvert(_implicit_iterable_type_mapping[type_], element)
return convert(_implicit_iterable_type_mapping[type_], element)

if origin_type is collections.abc.Iterable:
assert len(inner_types) == 1
return pconvert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
return convert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
elif is_union(origin_type):
for t in inner_types:
if t is NoneType:
continue
try:
return pconvert(t, element)
return convert(t, element)
except Exception:
pass
else:
raise CoercionError(input_value=element, target_type=type_)
elif origin_type is Literal:
for choice in get_args(type_):
try:
res = pconvert(type(choice), (element))
res = convert(type(choice), (element))
except Exception:
continue
if res == choice:
Expand All @@ -157,18 +176,18 @@ def _convert(type_, element, converter=None):
gen = zip(*[iter(element)] * count)
else:
gen = element
return origin_type(pconvert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall]
return origin_type(convert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall]
elif origin_type is tuple:
if isinstance(element, str):
# E.g. Tuple[str] (Annotation: tuple containing a single string)
return _convert_tuple(type_, element, converter=converter)
return convert_tuple(type_, element, converter=converter)
else:
return _convert_tuple(type_, *element, converter=converter)
return convert_tuple(type_, *element, converter=converter)
elif isclass(type_) and issubclass(type_, Enum):
if converter is None:
element_lower = element.lower().replace("-", "_")
element_transformed = name_transform(element)
for member in type_:
if member.name.lower().strip("_") == element_lower:
if name_transform(member.name) == element_transformed:
return member
raise CoercionError(input_value=element, target_type=type_)
else:
Expand Down Expand Up @@ -240,7 +259,12 @@ def resolve_annotated(type_: Type) -> Type:
return type_


def convert(type_: Type, *args: str, converter: Optional[Callable] = None):
def convert(
type_: Type,
*args: str,
converter: Optional[Callable[[Type, str], Any]] = None,
name_transform: Optional[Callable[[str], str]] = None,
):
"""Coerce variables into a specified type.

Internally used to coercing string CLI tokens into python builtin types.
Expand All @@ -259,8 +283,7 @@ def convert(type_: Type, *args: str, converter: Optional[Callable] = None):
A type hint/annotation to coerce ``*args`` into.
`*args`: str
String tokens to coerce.
converter: Optional[Callable]

converter: Optional[Callable[[Type, str], Any]]
An optional function to convert tokens to the inner-most types.
The converter should have signature:

Expand All @@ -272,12 +295,31 @@ def converter(type_: type, value: str) -> Any:
This allows to use the :func:`convert` function to handle the the difficult task
of traversing lists/tuples/unions/etc, while leaving the final conversion logic to
the caller.
name_transform: Optional[Callable[[str], str]]
Currently only used for ``Enum`` type hints.
A function that transforms enum names and CLI values into a normalized format.

The function should have signature:

.. code-block:: python

def name_transform(s: str) -> str:
...

where the returned value is the name to be used on the CLI.

If ``None``, defaults to ``cyclopts.default_name_transform``.

Returns
-------
Any
Coerced version of input ``*args``.
"""
if name_transform is None:
name_transform = default_name_transform

convert = partial(_convert, converter=converter, name_transform=name_transform)
convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)
type_ = resolve(type_)

if type_ is Any:
Expand All @@ -288,13 +330,13 @@ def converter(type_: type, value: str) -> Any:
origin_type = get_origin_and_validate(type_)

if origin_type is tuple:
return _convert_tuple(type_, *args, converter=converter)
return convert_tuple(type_, *args)
elif (origin_type or type_) in _iterable_types or origin_type is collections.abc.Iterable:
return _convert(type_, args, converter=converter)
return convert(type_, args)
elif len(args) == 1:
return _convert(type_, args[0], converter=converter)
return convert(type_, args[0])
else:
return [_convert(type_, item, converter=converter) for item in args]
return [convert(type_, item) for item in args]


def token_count(type_: Union[Type[Any], inspect.Parameter]) -> Tuple[int, bool]:
Expand Down
25 changes: 19 additions & 6 deletions cyclopts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,13 @@
from cyclopts.parameter import Parameter, validate_command
from cyclopts.protocols import Dispatcher
from cyclopts.resolve import ResolvedCommand
from cyclopts.utils import optional_to_tuple_converter, to_list_converter, to_tuple_converter
from cyclopts.utils import default_name_transform, optional_to_tuple_converter, to_list_converter, to_tuple_converter

with suppress(ImportError):
# By importing, makes things like the arrow-keys work.
import readline # Not available on windows


def _format_name(name: str):
return name.lower().replace("_", "-").strip("-")


class _CannotDeriveCallingModuleNameError(Exception):
pass

Expand Down Expand Up @@ -224,6 +220,12 @@ class App:
converter: Optional[Callable] = field(default=None, kw_only=True)
validator: List[Callable] = field(default=None, converter=to_list_converter, kw_only=True)

_name_transform: Optional[Callable[[str], str]] = field(
default=None,
alias="name_transform",
kw_only=True,
)

######################
# Private Attributes #
######################
Expand Down Expand Up @@ -307,7 +309,7 @@ def name(self) -> Tuple[str, ...]:
name = _get_root_module_name()
return (name,)
else:
return (_format_name(self.default_command.__name__),)
return (self.name_transform(self.default_command.__name__),)

@property
def help(self) -> str:
Expand All @@ -328,6 +330,14 @@ def help(self) -> str:
def help(self, value):
self._help = value

@property
def name_transform(self):
return self._name_transform if self._name_transform else default_name_transform

@name_transform.setter
def name_transform(self, value):
self._name_transform = value

def version_print(self) -> None:
"""Print the application version."""
print(self.version() if callable(self.version) else self.version)
Expand Down Expand Up @@ -458,6 +468,9 @@ def command(
app = App(default_command=obj, **kwargs)
# app.name is handled below

if app._name_transform is None:
app.name_transform = self.name_transform

if name is None:
name = app.name
else:
Expand Down
18 changes: 10 additions & 8 deletions cyclopts/help.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
from enum import Enum
from functools import lru_cache
from functools import lru_cache, partial
from inspect import isclass
from typing import TYPE_CHECKING, List, Literal, Tuple, Type, Union, get_args, get_origin
from typing import TYPE_CHECKING, Callable, List, Literal, Tuple, Type, Union, get_args, get_origin

import docstring_parser
from attrs import define, field, frozen
Expand Down Expand Up @@ -190,20 +190,21 @@ def format_doc(root_app, app: "App", format: str = "restructuredtext"):
raise ValueError(f'Unknown help_format "{format}"')


def _get_choices(type_: Type) -> str:
def _get_choices(type_: Type, name_transform: Callable[[str], str]) -> str:
get_choices = partial(_get_choices, name_transform=name_transform)
choices: str = ""
_origin = get_origin(type_)
if isclass(type_) and issubclass(type_, Enum):
choices = ",".join(x.name.lower().replace("_", "-") for x in type_)
choices = ",".join(name_transform(x.name) for x in type_)
elif _origin is Union:
inner_choices = [_get_choices(inner) for inner in get_args(type_)]
inner_choices = [get_choices(inner) for inner in get_args(type_)]
choices = ",".join(x for x in inner_choices if x)
elif _origin is Literal:
choices = ",".join(str(x) for x in get_args(type_))
elif _origin in (list, set, tuple):
args = get_args(type_)
if len(args) == 1 or (_origin is tuple and len(args) == 2 and args[1] is Ellipsis):
choices = _get_choices(args[0])
choices = get_choices(args[0])
return choices


Expand All @@ -218,6 +219,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter

for iparam, cparam in icparams:
assert cparam.name is not None
assert cparam.name_transform is not None
type_ = get_hint_parameter(iparam)[0]
options = list(cparam.name)
options.extend(cparam.get_negatives(type_, *options))
Expand All @@ -241,7 +243,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter
help_components.append(cparam.help)

if cparam.show_choices:
choices = _get_choices(type_)
choices = _get_choices(type_, cparam.name_transform)
if choices:
help_components.append(rf"[dim]\[choices: {choices}][/dim]")

Expand All @@ -254,7 +256,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter
):
default = ""
if isclass(type_) and issubclass(type_, Enum):
default = iparam.default.name.lower().replace("_", "-")
default = cparam.name_transform(iparam.default.name)
else:
default = iparam.default

Expand Down
15 changes: 13 additions & 2 deletions cyclopts/parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from functools import partial
from typing import Any, Callable, Iterable, Optional, Tuple, Type, Union, cast, get_args, get_origin

import attrs
Expand All @@ -13,7 +14,7 @@
resolve_optional,
)
from cyclopts.group import Group
from cyclopts.utils import optional_to_tuple_converter, record_init, to_tuple_converter
from cyclopts.utils import default_name_transform, optional_to_tuple_converter, record_init, to_tuple_converter


def _double_hyphen_validator(instance, attribute, values):
Expand Down Expand Up @@ -48,7 +49,7 @@ class Parameter:
converter=lambda x: cast(Tuple[str, ...], to_tuple_converter(x)),
)

converter: Callable = field(default=None, converter=attrs.converters.default_if_none(convert))
_converter: Callable = field(default=None, alias="converter")

# This can ONLY ever be a Tuple[Callable, ...]
validator: Union[None, Callable, Iterable[Callable]] = field(
Expand Down Expand Up @@ -98,13 +99,23 @@ class Parameter:

allow_leading_hyphen: bool = field(default=False)

name_transform: Optional[Callable[[str], str]] = field(
default=None,
converter=attrs.converters.default_if_none(default_name_transform),
kw_only=True,
)

# Populated by the record_attrs_init_args decorator.
_provided_args: Tuple[str] = field(default=(), init=False, eq=False)

@property
def show(self):
return self._show if self._show is not None else self.parse

@property
def converter(self):
return self._converter if self._converter else partial(convert, name_transform=self.name_transform)

def get_negatives(self, type_, *names: str) -> Tuple[str, ...]:
type_ = get_origin(type_) or type_

Expand Down
Loading