@@ -307,6 +307,10 @@ def __init__(
307
307
self ._to_literal_transformer = to_literal_transformer
308
308
self ._from_literal_transformer = from_literal_transformer
309
309
310
+ @property
311
+ def base_type (self ) -> Type :
312
+ return self ._type
313
+
310
314
def get_literal_type (self , t : Optional [Type [T ]] = None ) -> LiteralType :
311
315
return LiteralType .from_flyte_idl (self ._lt .to_flyte_idl ())
312
316
@@ -909,8 +913,9 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
909
913
910
914
911
915
def generate_attribute_list_from_dataclass_json_mixin (schema : dict , schema_name : typing .Any ):
912
- attribute_list : typing .List [tuple [Any , GenericAlias ]] = []
916
+ attribute_list : typing .List [typing . Tuple [Any , Any ]] = []
913
917
for property_key , property_val in schema ["properties" ].items ():
918
+ property_type = ""
914
919
if property_val .get ("anyOf" ):
915
920
property_type = property_val ["anyOf" ][0 ]["type" ]
916
921
elif property_val .get ("enum" ):
@@ -934,9 +939,8 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
934
939
)
935
940
)
936
941
elif property_val .get ("additionalProperties" ):
937
- attribute_list .append (
938
- (property_key , typing .Dict [str , _get_element_type (property_val ["additionalProperties" ])]) # type: ignore
939
- )
942
+ elem_type = _get_element_type (property_val ["additionalProperties" ])
943
+ attribute_list .append ((property_key , typing .Dict [str , elem_type ])) # type: ignore
940
944
else :
941
945
sub_schemea_name = property_val ["title" ]
942
946
attribute_list .append (
@@ -1003,114 +1007,64 @@ def register_additional_type(cls, transformer: TypeTransformer[T], additional_ty
1003
1007
cls ._REGISTRY [additional_type ] = transformer
1004
1008
1005
1009
@classmethod
1006
- def get_transformer (cls , python_type : Type ) -> TypeTransformer [T ]:
1007
- """
1008
- The TypeEngine hierarchy for flyteKit. This method looks up and selects the type transformer. The algorithm is
1009
- as follows
1010
-
1011
- d = dictionary of registered transformers, where is a python `type`
1012
- v = lookup type
1013
-
1014
- Step 1:
1015
- If the type is annotated with a TypeTransformer instance, use that.
1016
-
1017
- Step 2:
1018
- find a transformer that matches v exactly
1019
-
1020
- Step 3:
1021
- find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
1022
-
1023
- Step 4:
1024
- Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
1025
- This is potentially non-deterministic - will depend on the registration pattern.
1026
-
1027
- Special case:
1028
- If v inherits from Enum, use the Enum transformer even if Enum is not the first base class.
1029
-
1030
- TODO lets make this deterministic by using an ordered dict
1031
-
1032
- Step 5:
1033
- if v is of type data class, use the dataclass transformer
1034
-
1035
- Step 6:
1036
- Pickle transformer is used
1037
- """
1010
+ def _get_transformer (cls , python_type : Type ) -> Optional [TypeTransformer [T ]]:
1038
1011
cls .lazy_import_transformers ()
1039
- # Step 1
1040
1012
if is_annotated (python_type ):
1041
1013
args = get_args (python_type )
1042
1014
for annotation in args :
1043
1015
if isinstance (annotation , TypeTransformer ):
1044
1016
return annotation
1017
+ return cls .get_transformer (args [0 ])
1045
1018
1046
- python_type = args [0 ]
1047
-
1048
- # Step 2
1049
- # this makes sure that if it's a list/dict of annotated types, we hit the unwrapping code in step 2
1050
- # see test_list_of_annotated in test_structured_dataset.py
1051
- if (
1052
- (not hasattr (python_type , "__origin__" ))
1053
- or (
1054
- hasattr (python_type , "__origin__" )
1055
- and (python_type .__origin__ is not list and python_type .__origin__ is not dict )
1056
- )
1057
- ) and python_type in cls ._REGISTRY :
1058
- return cls ._REGISTRY [python_type ]
1019
+ if inspect .isclass (python_type ) and issubclass (python_type , enum .Enum ):
1020
+ # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
1021
+ return cls ._ENUM_TRANSFORMER
1059
1022
1060
- # Step 3
1061
1023
if hasattr (python_type , "__origin__" ):
1062
- # Handling of annotated generics, eg:
1063
- # Annotated[typing.List[int], 'foo']
1064
- if is_annotated (python_type ):
1065
- return cls .get_transformer (get_args (python_type )[0 ])
1066
-
1024
+ # If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
1025
+ # or List[int] has been specifically registered; we should check for the entire type.
1026
+ # The challenge is for StructuredDataset, example List[StructuredDataset] the column names is an OrderedDict
1027
+ # are not hashable, thus looking up this type is not possible.
1028
+ # In such as case, we will have to skip the "type" lookup and use the origin type only
1029
+ try :
1030
+ if python_type in cls ._REGISTRY :
1031
+ return cls ._REGISTRY [python_type ]
1032
+ except TypeError :
1033
+ pass
1067
1034
if python_type .__origin__ in cls ._REGISTRY :
1068
1035
return cls ._REGISTRY [python_type .__origin__ ]
1069
1036
1070
- raise ValueError (f"Generic Type { python_type .__origin__ } not supported currently in Flytekit." )
1071
-
1072
- # Step 4
1073
- # To facilitate cases where users may specify one transformer for multiple types that all inherit from one
1074
- # parent.
1075
- if inspect .isclass (python_type ) and issubclass (python_type , enum .Enum ):
1076
- # Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
1077
- return cls ._ENUM_TRANSFORMER
1037
+ # Handling UnionType specially - PEP 604
1038
+ if sys .version_info >= (3 , 10 ):
1039
+ import types
1078
1040
1079
- from flytekit .types .iterator .json_iterator import JSONIterator
1041
+ if isinstance (python_type , types .UnionType ):
1042
+ return cls ._REGISTRY [types .UnionType ]
1080
1043
1081
- for base_type in cls ._REGISTRY .keys ():
1082
- if base_type is None :
1083
- continue # None is actually one of the keys, but isinstance/issubclass doesn't work on it
1084
- try :
1085
- origin_type : Optional [typing .Any ] = base_type
1086
- if hasattr (base_type , "__args__" ):
1087
- origin_base_type = get_origin (base_type )
1088
- if isinstance (origin_base_type , type ) and issubclass (
1089
- origin_base_type , typing .Iterator
1090
- ): # Iterator[JSON]
1091
- origin_type = origin_base_type
1092
-
1093
- if isinstance (python_type , origin_type ) or ( # type: ignore[arg-type]
1094
- inspect .isclass (python_type ) and issubclass (python_type , origin_type ) # type: ignore[arg-type]
1095
- ):
1096
- # Consider Iterator[JSON] but not vanilla Iterator when the value is a JSON iterator.
1097
- if (
1098
- isinstance (python_type , type )
1099
- and issubclass (python_type , JSONIterator )
1100
- and not get_args (base_type )
1101
- ):
1102
- continue
1103
- return cls ._REGISTRY [base_type ]
1104
- except TypeError :
1105
- # As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which
1106
- # is the case for one of the restricted types, namely NamedTuple.
1107
- logger .debug (f"Invalid base type { base_type } in call to isinstance" , exc_info = True )
1044
+ if python_type in cls ._REGISTRY :
1045
+ return cls ._REGISTRY [python_type ]
1108
1046
1109
- # Step 5
1110
1047
if dataclasses .is_dataclass (python_type ):
1111
1048
return cls ._DATACLASS_TRANSFORMER
1112
1049
1113
- # Step 6
1050
+ return None
1051
+
1052
+ @classmethod
1053
+ def get_transformer (cls , python_type : Type ) -> TypeTransformer [T ]:
1054
+ """
1055
+ Implements a recursive search for the transformer.
1056
+ """
1057
+ v = cls ._get_transformer (python_type )
1058
+ if v is not None :
1059
+ return v
1060
+
1061
+ if hasattr (python_type , "__mro__" ):
1062
+ class_tree = inspect .getmro (python_type )
1063
+ for t in class_tree :
1064
+ v = cls ._get_transformer (t )
1065
+ if v is not None :
1066
+ return v
1067
+
1114
1068
display_pickle_warning (str (python_type ))
1115
1069
from flytekit .types .pickle .pickle import FlytePickleTransformer
1116
1070
@@ -2207,7 +2161,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
2207
2161
)
2208
2162
element_format = element_property ["format" ] if "format" in element_property else None
2209
2163
2210
- if type (element_type ) == list :
2164
+ if isinstance (element_type , list ) :
2211
2165
# Element type of Optional[int] is [integer, None]
2212
2166
return typing .Optional [_get_element_type ({"type" : element_type [0 ]})] # type: ignore
2213
2167
@@ -2255,89 +2209,82 @@ def _check_and_convert_void(lv: Literal) -> None:
2255
2209
return None
2256
2210
2257
2211
2258
- def _register_default_type_transformers ():
2259
- TypeEngine .register (
2260
- SimpleTransformer (
2261
- "int" ,
2262
- int ,
2263
- _type_models .LiteralType (simple = _type_models .SimpleType .INTEGER ),
2264
- lambda x : Literal (scalar = Scalar (primitive = Primitive (integer = x ))),
2265
- lambda x : x .scalar .primitive .integer ,
2266
- )
2267
- )
2268
-
2269
- TypeEngine .register (
2270
- SimpleTransformer (
2271
- "float" ,
2272
- float ,
2273
- _type_models .LiteralType (simple = _type_models .SimpleType .FLOAT ),
2274
- lambda x : Literal (scalar = Scalar (primitive = Primitive (float_value = x ))),
2275
- _check_and_covert_float ,
2276
- )
2277
- )
2278
-
2279
- TypeEngine .register (
2280
- SimpleTransformer (
2281
- "bool" ,
2282
- bool ,
2283
- _type_models .LiteralType (simple = _type_models .SimpleType .BOOLEAN ),
2284
- lambda x : Literal (scalar = Scalar (primitive = Primitive (boolean = x ))),
2285
- lambda x : x .scalar .primitive .boolean ,
2286
- )
2287
- )
2212
+ IntTransformer = SimpleTransformer (
2213
+ "int" ,
2214
+ int ,
2215
+ _type_models .LiteralType (simple = _type_models .SimpleType .INTEGER ),
2216
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (integer = x ))),
2217
+ lambda x : x .scalar .primitive .integer ,
2218
+ )
2219
+
2220
+ FloatTransformer = SimpleTransformer (
2221
+ "float" ,
2222
+ float ,
2223
+ _type_models .LiteralType (simple = _type_models .SimpleType .FLOAT ),
2224
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (float_value = x ))),
2225
+ _check_and_covert_float ,
2226
+ )
2227
+
2228
+ BoolTransformer = SimpleTransformer (
2229
+ "bool" ,
2230
+ bool ,
2231
+ _type_models .LiteralType (simple = _type_models .SimpleType .BOOLEAN ),
2232
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (boolean = x ))),
2233
+ lambda x : x .scalar .primitive .boolean ,
2234
+ )
2235
+
2236
+ StrTransformer = SimpleTransformer (
2237
+ "str" ,
2238
+ str ,
2239
+ _type_models .LiteralType (simple = _type_models .SimpleType .STRING ),
2240
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (string_value = x ))),
2241
+ lambda x : x .scalar .primitive .string_value ,
2242
+ )
2243
+
2244
+ DatetimeTransformer = SimpleTransformer (
2245
+ "datetime" ,
2246
+ datetime .datetime ,
2247
+ _type_models .LiteralType (simple = _type_models .SimpleType .DATETIME ),
2248
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (datetime = x ))),
2249
+ lambda x : x .scalar .primitive .datetime ,
2250
+ )
2251
+
2252
+ TimedeltaTransformer = SimpleTransformer (
2253
+ "timedelta" ,
2254
+ datetime .timedelta ,
2255
+ _type_models .LiteralType (simple = _type_models .SimpleType .DURATION ),
2256
+ lambda x : Literal (scalar = Scalar (primitive = Primitive (duration = x ))),
2257
+ lambda x : x .scalar .primitive .duration ,
2258
+ )
2259
+
2260
+ DateTransformer = SimpleTransformer (
2261
+ "date" ,
2262
+ datetime .date ,
2263
+ _type_models .LiteralType (simple = _type_models .SimpleType .DATETIME ),
2264
+ lambda x : Literal (
2265
+ scalar = Scalar (primitive = Primitive (datetime = datetime .datetime .combine (x , datetime .time .min )))
2266
+ ), # convert datetime to date
2267
+ lambda x : x .scalar .primitive .datetime .date (), # get date from datetime
2268
+ )
2269
+
2270
+ NoneTransformer = SimpleTransformer (
2271
+ "none" ,
2272
+ type (None ),
2273
+ _type_models .LiteralType (simple = _type_models .SimpleType .NONE ),
2274
+ lambda x : Literal (scalar = Scalar (none_type = Void ())),
2275
+ lambda x : _check_and_convert_void (x ),
2276
+ )
2288
2277
2289
- TypeEngine .register (
2290
- SimpleTransformer (
2291
- "str" ,
2292
- str ,
2293
- _type_models .LiteralType (simple = _type_models .SimpleType .STRING ),
2294
- lambda x : Literal (scalar = Scalar (primitive = Primitive (string_value = x ))),
2295
- lambda x : x .scalar .primitive .string_value ,
2296
- )
2297
- )
2298
-
2299
- TypeEngine .register (
2300
- SimpleTransformer (
2301
- "datetime" ,
2302
- datetime .datetime ,
2303
- _type_models .LiteralType (simple = _type_models .SimpleType .DATETIME ),
2304
- lambda x : Literal (scalar = Scalar (primitive = Primitive (datetime = x ))),
2305
- lambda x : x .scalar .primitive .datetime ,
2306
- )
2307
- )
2308
2278
2309
- TypeEngine .register (
2310
- SimpleTransformer (
2311
- "timedelta" ,
2312
- datetime .timedelta ,
2313
- _type_models .LiteralType (simple = _type_models .SimpleType .DURATION ),
2314
- lambda x : Literal (scalar = Scalar (primitive = Primitive (duration = x ))),
2315
- lambda x : x .scalar .primitive .duration ,
2316
- )
2317
- )
2318
-
2319
- TypeEngine .register (
2320
- SimpleTransformer (
2321
- "date" ,
2322
- datetime .date ,
2323
- _type_models .LiteralType (simple = _type_models .SimpleType .DATETIME ),
2324
- lambda x : Literal (
2325
- scalar = Scalar (primitive = Primitive (datetime = datetime .datetime .combine (x , datetime .time .min )))
2326
- ), # convert datetime to date
2327
- lambda x : x .scalar .primitive .datetime .date (), # get date from datetime
2328
- )
2329
- )
2330
-
2331
- TypeEngine .register (
2332
- SimpleTransformer (
2333
- "none" ,
2334
- type (None ),
2335
- _type_models .LiteralType (simple = _type_models .SimpleType .NONE ),
2336
- lambda x : Literal (scalar = Scalar (none_type = Void ())),
2337
- lambda x : _check_and_convert_void (x ),
2338
- ),
2339
- [None ],
2340
- )
2279
+ def _register_default_type_transformers ():
2280
+ TypeEngine .register (IntTransformer )
2281
+ TypeEngine .register (FloatTransformer )
2282
+ TypeEngine .register (StrTransformer )
2283
+ TypeEngine .register (DatetimeTransformer )
2284
+ TypeEngine .register (DateTransformer )
2285
+ TypeEngine .register (TimedeltaTransformer )
2286
+ TypeEngine .register (BoolTransformer )
2287
+ TypeEngine .register (NoneTransformer , [None ]) # noqa
2341
2288
TypeEngine .register (ListTransformer ())
2342
2289
if sys .version_info >= (3 , 10 ):
2343
2290
from types import UnionType
0 commit comments