Skip to content

Commit 945d2df

Browse files
kumare3wild-endeavorpingsutweapolinario
authored
Improved Type engine for generic types and performance (#2815)
Signed-off-by: Ketan Umare <[email protected]> Signed-off-by: Yee Hing Tong <[email protected]> Signed-off-by: Kevin Su <[email protected]> Signed-off-by: Eduardo Apolinario <[email protected]> Co-authored-by: Ketan Umare <[email protected]> Co-authored-by: Yee Hing Tong <[email protected]> Co-authored-by: Kevin Su <[email protected]> Co-authored-by: Eduardo Apolinario <[email protected]>
1 parent 1f8a273 commit 945d2df

File tree

7 files changed

+575
-389
lines changed

7 files changed

+575
-389
lines changed

flytekit/core/type_engine.py

Lines changed: 123 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,10 @@ def __init__(
307307
self._to_literal_transformer = to_literal_transformer
308308
self._from_literal_transformer = from_literal_transformer
309309

310+
@property
311+
def base_type(self) -> Type:
312+
return self._type
313+
310314
def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType:
311315
return LiteralType.from_flyte_idl(self._lt.to_flyte_idl())
312316

@@ -909,8 +913,9 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
909913

910914

911915
def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
912-
attribute_list: typing.List[tuple[Any, GenericAlias]] = []
916+
attribute_list: typing.List[typing.Tuple[Any, Any]] = []
913917
for property_key, property_val in schema["properties"].items():
918+
property_type = ""
914919
if property_val.get("anyOf"):
915920
property_type = property_val["anyOf"][0]["type"]
916921
elif property_val.get("enum"):
@@ -934,9 +939,8 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
934939
)
935940
)
936941
elif property_val.get("additionalProperties"):
937-
attribute_list.append(
938-
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore
939-
)
942+
elem_type = _get_element_type(property_val["additionalProperties"])
943+
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
940944
else:
941945
sub_schemea_name = property_val["title"]
942946
attribute_list.append(
@@ -1003,114 +1007,64 @@ def register_additional_type(cls, transformer: TypeTransformer[T], additional_ty
10031007
cls._REGISTRY[additional_type] = transformer
10041008

10051009
@classmethod
1006-
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
1007-
"""
1008-
The TypeEngine hierarchy for flyteKit. This method looks up and selects the type transformer. The algorithm is
1009-
as follows
1010-
1011-
d = dictionary of registered transformers, where is a python `type`
1012-
v = lookup type
1013-
1014-
Step 1:
1015-
If the type is annotated with a TypeTransformer instance, use that.
1016-
1017-
Step 2:
1018-
find a transformer that matches v exactly
1019-
1020-
Step 3:
1021-
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
1022-
1023-
Step 4:
1024-
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
1025-
This is potentially non-deterministic - will depend on the registration pattern.
1026-
1027-
Special case:
1028-
If v inherits from Enum, use the Enum transformer even if Enum is not the first base class.
1029-
1030-
TODO lets make this deterministic by using an ordered dict
1031-
1032-
Step 5:
1033-
if v is of type data class, use the dataclass transformer
1034-
1035-
Step 6:
1036-
Pickle transformer is used
1037-
"""
1010+
def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
10381011
cls.lazy_import_transformers()
1039-
# Step 1
10401012
if is_annotated(python_type):
10411013
args = get_args(python_type)
10421014
for annotation in args:
10431015
if isinstance(annotation, TypeTransformer):
10441016
return annotation
1017+
return cls.get_transformer(args[0])
10451018

1046-
python_type = args[0]
1047-
1048-
# Step 2
1049-
# this makes sure that if it's a list/dict of annotated types, we hit the unwrapping code in step 2
1050-
# see test_list_of_annotated in test_structured_dataset.py
1051-
if (
1052-
(not hasattr(python_type, "__origin__"))
1053-
or (
1054-
hasattr(python_type, "__origin__")
1055-
and (python_type.__origin__ is not list and python_type.__origin__ is not dict)
1056-
)
1057-
) and python_type in cls._REGISTRY:
1058-
return cls._REGISTRY[python_type]
1019+
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
1020+
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
1021+
return cls._ENUM_TRANSFORMER
10591022

1060-
# Step 3
10611023
if hasattr(python_type, "__origin__"):
1062-
# Handling of annotated generics, eg:
1063-
# Annotated[typing.List[int], 'foo']
1064-
if is_annotated(python_type):
1065-
return cls.get_transformer(get_args(python_type)[0])
1066-
1024+
# If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
1025+
# or List[int] has been specifically registered; we should check for the entire type.
1026+
# The challenge is for StructuredDataset, example List[StructuredDataset] the column names is an OrderedDict
1027+
# are not hashable, thus looking up this type is not possible.
1028+
# In such as case, we will have to skip the "type" lookup and use the origin type only
1029+
try:
1030+
if python_type in cls._REGISTRY:
1031+
return cls._REGISTRY[python_type]
1032+
except TypeError:
1033+
pass
10671034
if python_type.__origin__ in cls._REGISTRY:
10681035
return cls._REGISTRY[python_type.__origin__]
10691036

1070-
raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.")
1071-
1072-
# Step 4
1073-
# To facilitate cases where users may specify one transformer for multiple types that all inherit from one
1074-
# parent.
1075-
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
1076-
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
1077-
return cls._ENUM_TRANSFORMER
1037+
# Handling UnionType specially - PEP 604
1038+
if sys.version_info >= (3, 10):
1039+
import types
10781040

1079-
from flytekit.types.iterator.json_iterator import JSONIterator
1041+
if isinstance(python_type, types.UnionType):
1042+
return cls._REGISTRY[types.UnionType]
10801043

1081-
for base_type in cls._REGISTRY.keys():
1082-
if base_type is None:
1083-
continue # None is actually one of the keys, but isinstance/issubclass doesn't work on it
1084-
try:
1085-
origin_type: Optional[typing.Any] = base_type
1086-
if hasattr(base_type, "__args__"):
1087-
origin_base_type = get_origin(base_type)
1088-
if isinstance(origin_base_type, type) and issubclass(
1089-
origin_base_type, typing.Iterator
1090-
): # Iterator[JSON]
1091-
origin_type = origin_base_type
1092-
1093-
if isinstance(python_type, origin_type) or ( # type: ignore[arg-type]
1094-
inspect.isclass(python_type) and issubclass(python_type, origin_type) # type: ignore[arg-type]
1095-
):
1096-
# Consider Iterator[JSON] but not vanilla Iterator when the value is a JSON iterator.
1097-
if (
1098-
isinstance(python_type, type)
1099-
and issubclass(python_type, JSONIterator)
1100-
and not get_args(base_type)
1101-
):
1102-
continue
1103-
return cls._REGISTRY[base_type]
1104-
except TypeError:
1105-
# As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which
1106-
# is the case for one of the restricted types, namely NamedTuple.
1107-
logger.debug(f"Invalid base type {base_type} in call to isinstance", exc_info=True)
1044+
if python_type in cls._REGISTRY:
1045+
return cls._REGISTRY[python_type]
11081046

1109-
# Step 5
11101047
if dataclasses.is_dataclass(python_type):
11111048
return cls._DATACLASS_TRANSFORMER
11121049

1113-
# Step 6
1050+
return None
1051+
1052+
@classmethod
1053+
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
1054+
"""
1055+
Implements a recursive search for the transformer.
1056+
"""
1057+
v = cls._get_transformer(python_type)
1058+
if v is not None:
1059+
return v
1060+
1061+
if hasattr(python_type, "__mro__"):
1062+
class_tree = inspect.getmro(python_type)
1063+
for t in class_tree:
1064+
v = cls._get_transformer(t)
1065+
if v is not None:
1066+
return v
1067+
11141068
display_pickle_warning(str(python_type))
11151069
from flytekit.types.pickle.pickle import FlytePickleTransformer
11161070

@@ -2207,7 +2161,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
22072161
)
22082162
element_format = element_property["format"] if "format" in element_property else None
22092163

2210-
if type(element_type) == list:
2164+
if isinstance(element_type, list):
22112165
# Element type of Optional[int] is [integer, None]
22122166
return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore
22132167

@@ -2255,89 +2209,82 @@ def _check_and_convert_void(lv: Literal) -> None:
22552209
return None
22562210

22572211

2258-
def _register_default_type_transformers():
2259-
TypeEngine.register(
2260-
SimpleTransformer(
2261-
"int",
2262-
int,
2263-
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
2264-
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
2265-
lambda x: x.scalar.primitive.integer,
2266-
)
2267-
)
2268-
2269-
TypeEngine.register(
2270-
SimpleTransformer(
2271-
"float",
2272-
float,
2273-
_type_models.LiteralType(simple=_type_models.SimpleType.FLOAT),
2274-
lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
2275-
_check_and_covert_float,
2276-
)
2277-
)
2278-
2279-
TypeEngine.register(
2280-
SimpleTransformer(
2281-
"bool",
2282-
bool,
2283-
_type_models.LiteralType(simple=_type_models.SimpleType.BOOLEAN),
2284-
lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
2285-
lambda x: x.scalar.primitive.boolean,
2286-
)
2287-
)
2212+
IntTransformer = SimpleTransformer(
2213+
"int",
2214+
int,
2215+
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
2216+
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
2217+
lambda x: x.scalar.primitive.integer,
2218+
)
2219+
2220+
FloatTransformer = SimpleTransformer(
2221+
"float",
2222+
float,
2223+
_type_models.LiteralType(simple=_type_models.SimpleType.FLOAT),
2224+
lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
2225+
_check_and_covert_float,
2226+
)
2227+
2228+
BoolTransformer = SimpleTransformer(
2229+
"bool",
2230+
bool,
2231+
_type_models.LiteralType(simple=_type_models.SimpleType.BOOLEAN),
2232+
lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
2233+
lambda x: x.scalar.primitive.boolean,
2234+
)
2235+
2236+
StrTransformer = SimpleTransformer(
2237+
"str",
2238+
str,
2239+
_type_models.LiteralType(simple=_type_models.SimpleType.STRING),
2240+
lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
2241+
lambda x: x.scalar.primitive.string_value,
2242+
)
2243+
2244+
DatetimeTransformer = SimpleTransformer(
2245+
"datetime",
2246+
datetime.datetime,
2247+
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
2248+
lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
2249+
lambda x: x.scalar.primitive.datetime,
2250+
)
2251+
2252+
TimedeltaTransformer = SimpleTransformer(
2253+
"timedelta",
2254+
datetime.timedelta,
2255+
_type_models.LiteralType(simple=_type_models.SimpleType.DURATION),
2256+
lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
2257+
lambda x: x.scalar.primitive.duration,
2258+
)
2259+
2260+
DateTransformer = SimpleTransformer(
2261+
"date",
2262+
datetime.date,
2263+
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
2264+
lambda x: Literal(
2265+
scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
2266+
), # convert datetime to date
2267+
lambda x: x.scalar.primitive.datetime.date(), # get date from datetime
2268+
)
2269+
2270+
NoneTransformer = SimpleTransformer(
2271+
"none",
2272+
type(None),
2273+
_type_models.LiteralType(simple=_type_models.SimpleType.NONE),
2274+
lambda x: Literal(scalar=Scalar(none_type=Void())),
2275+
lambda x: _check_and_convert_void(x),
2276+
)
22882277

2289-
TypeEngine.register(
2290-
SimpleTransformer(
2291-
"str",
2292-
str,
2293-
_type_models.LiteralType(simple=_type_models.SimpleType.STRING),
2294-
lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
2295-
lambda x: x.scalar.primitive.string_value,
2296-
)
2297-
)
2298-
2299-
TypeEngine.register(
2300-
SimpleTransformer(
2301-
"datetime",
2302-
datetime.datetime,
2303-
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
2304-
lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
2305-
lambda x: x.scalar.primitive.datetime,
2306-
)
2307-
)
23082278

2309-
TypeEngine.register(
2310-
SimpleTransformer(
2311-
"timedelta",
2312-
datetime.timedelta,
2313-
_type_models.LiteralType(simple=_type_models.SimpleType.DURATION),
2314-
lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
2315-
lambda x: x.scalar.primitive.duration,
2316-
)
2317-
)
2318-
2319-
TypeEngine.register(
2320-
SimpleTransformer(
2321-
"date",
2322-
datetime.date,
2323-
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
2324-
lambda x: Literal(
2325-
scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
2326-
), # convert datetime to date
2327-
lambda x: x.scalar.primitive.datetime.date(), # get date from datetime
2328-
)
2329-
)
2330-
2331-
TypeEngine.register(
2332-
SimpleTransformer(
2333-
"none",
2334-
type(None),
2335-
_type_models.LiteralType(simple=_type_models.SimpleType.NONE),
2336-
lambda x: Literal(scalar=Scalar(none_type=Void())),
2337-
lambda x: _check_and_convert_void(x),
2338-
),
2339-
[None],
2340-
)
2279+
def _register_default_type_transformers():
2280+
TypeEngine.register(IntTransformer)
2281+
TypeEngine.register(FloatTransformer)
2282+
TypeEngine.register(StrTransformer)
2283+
TypeEngine.register(DatetimeTransformer)
2284+
TypeEngine.register(DateTransformer)
2285+
TypeEngine.register(TimedeltaTransformer)
2286+
TypeEngine.register(BoolTransformer)
2287+
TypeEngine.register(NoneTransformer, [None]) # noqa
23412288
TypeEngine.register(ListTransformer())
23422289
if sys.version_info >= (3, 10):
23432290
from types import UnionType

0 commit comments

Comments
 (0)