Skip to content

Commit c902ac3

Browse files
committed
Adds support for unions with no complex types
1 parent b8bac79 commit c902ac3

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

typer/main.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,27 @@ def wrapper(**kwargs: Any) -> Any:
700700
update_wrapper(wrapper, callback)
701701
return wrapper
702702

703+
class UnionParamType(click.ParamType):
704+
@property
705+
def name(self) -> str: # type: ignore
706+
return ' | '.join(_type.name for _type in self._types)
707+
708+
def __init__(self, types: List[click.ParamType]):
709+
super().__init__()
710+
self._types = types
711+
712+
def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> Any:
713+
# *types, last = self._types
714+
error_messages = []
715+
for _type in self._types:
716+
try:
717+
return _type.convert(value, param, ctx)
718+
except click.BadParameter as e:
719+
print(type(e))
720+
error_messages.append(str(e))
721+
# return last.convert(value, param, ctx)
722+
raise self.fail('\n' + '\nbut also\n'.join(error_messages), param, ctx)
723+
703724

704725
def get_click_type(
705726
*, annotation: Any, parameter_info: ParameterInfo
@@ -791,6 +812,9 @@ def get_click_type(
791812
[item.value for item in annotation],
792813
case_sensitive=parameter_info.case_sensitive,
793814
)
815+
elif get_origin(annotation) is not None and is_union(get_origin(annotation)):
816+
types = [get_click_type(annotation=arg, parameter_info=parameter_info) for arg in get_args(annotation)]
817+
return UnionParamType(types)
794818
raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover
795819

796820

@@ -841,9 +865,14 @@ def get_click_param(
841865
if type_ is NoneType:
842866
continue
843867
types.append(type_)
844-
assert len(types) == 1, "Typer Currently doesn't support Union types"
845-
main_type = types[0]
846-
origin = get_origin(main_type)
868+
if len(types) == 1:
869+
main_type, = types
870+
origin = get_origin(main_type)
871+
else:
872+
for type_ in get_args(main_type):
873+
assert not get_origin(type_), (
874+
"Union types with complex sub-types are not currently supported"
875+
)
847876
# Handle Tuples and Lists
848877
if lenient_issubclass(origin, List):
849878
main_type = get_args(main_type)[0]

0 commit comments

Comments
 (0)