7
7
from datetime import datetime as _datetime
8
8
from six import moves as _six_moves
9
9
10
+ from google .protobuf .json_format import ParseDict as _ParseDict
11
+ from flyteidl .plugins import qubole_pb2 as _qubole_pb2
10
12
from flytekit .common import constants as _sdk_constants , utils as _common_utils
11
13
from flytekit .common .exceptions import user as _user_exceptions , system as _system_exception
12
14
from flytekit .common .types import helpers as _type_helpers
13
15
from flytekit .configuration import TemporaryConfiguration as _TemporaryConfiguration
14
16
from flytekit .engines import common as _common_engine
15
17
from flytekit .engines .unit .mock_stats import MockStats
16
18
from flytekit .interfaces .data import data_proxy as _data_proxy
17
- from flytekit .models import literals as _literals , array_job as _array_job
19
+ from flytekit .models import literals as _literals , array_job as _array_job , qubole as _qubole_models
18
20
from flytekit .models .core .identifier import WorkflowExecutionIdentifier
19
21
20
22
@@ -32,9 +34,12 @@ def get_task(self, sdk_task):
32
34
return ReturnOutputsTask (sdk_task )
33
35
elif sdk_task .type in {
34
36
_sdk_constants .SdkTaskType .DYNAMIC_TASK ,
35
- _sdk_constants .SdkTaskType .BATCH_HIVE_TASK
36
37
}:
37
38
return DynamicTask (sdk_task )
39
+ elif sdk_task .type in {
40
+ _sdk_constants .SdkTaskType .BATCH_HIVE_TASK ,
41
+ }:
42
+ return HiveTask (sdk_task )
38
43
else :
39
44
raise _user_exceptions .FlyteAssertion (
40
45
"Unit tests are not currently supported for tasks of type: {}" .format (
@@ -76,20 +81,20 @@ def execute(self, inputs, context=None):
76
81
Just execute the function and return the outputs as a user-readable dictionary.
77
82
:param flytekit.models.literals.LiteralMap inputs:
78
83
:param context:
79
- :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
84
+ :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
80
85
"""
81
86
with _TemporaryConfiguration (
82
87
_os .path .join (_os .path .dirname (__file__ ), 'unit.config' ),
83
88
internal_overrides = {'image' : 'unit_image' }
84
89
):
85
90
with _common_utils .AutoDeletingTempDir ("unit_test_dir" ) as working_directory :
86
91
with _data_proxy .LocalWorkingDirectoryContext (working_directory ):
87
- return self ._execute_user_code (inputs )
92
+ return self ._transform_for_user_output ( self . _execute_user_code (inputs ) )
88
93
89
94
def _execute_user_code (self , inputs ):
90
95
"""
91
96
:param flytekit.models.literals.LiteralMap inputs:
92
- :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
97
+ :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
93
98
"""
94
99
with _common_utils .AutoDeletingTempDir ("user_dir" ) as user_working_directory :
95
100
return self .sdk_task .execute (
@@ -107,24 +112,32 @@ def _execute_user_code(self, inputs):
107
112
inputs
108
113
)
109
114
115
+ def _transform_for_user_output (self , outputs ):
116
+ """
117
+ Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this
118
+ task's unit test.
119
+ :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
120
+ :rtype: T
121
+ """
122
+ return outputs
123
+
110
124
def register(self, identifier, version):
    """
    Unit test tasks cannot be registered, so this always raises.
    :param identifier: ignored
    :param version: ignored
    :raises flytekit.common.exceptions.user.FlyteAssertion: always
    """
    message = "You cannot register unit test tasks."
    raise _user_exceptions.FlyteAssertion(message)
112
126
113
127
114
128
class ReturnOutputsTask(UnitTestEngineTask):
    def _transform_for_user_output(self, outputs):
        """
        Convert the literal map stored under the SDK's output file name into plain Python values, keyed by the
        output names declared on the task's interface.
        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
        :rtype: dict[Text,T]
        """
        literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME]
        python_values = {}
        for output_name, output_variable in _six.iteritems(self.sdk_task.interface.outputs):
            # Resolve the SDK type first, then promote the matching literal and unwrap it.
            sdk_type = _type_helpers.get_sdk_type_from_literal_type(output_variable.type)
            promoted = sdk_type.promote_from_model(literal_map.literals[output_name])
            python_values[output_name] = promoted.to_python_std()
        return python_values
@@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask):
135
148
def _execute_user_code (self , inputs ):
136
149
"""
137
150
:param flytekit.models.literals.LiteralMap inputs:
138
- :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
151
+ :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
139
152
"""
140
153
results = super (DynamicTask , self )._execute_user_code (inputs )
141
154
if _sdk_constants .FUTURES_FILE_NAME in results :
@@ -151,7 +164,7 @@ def _execute_user_code(self, inputs):
151
164
# TODO: futures.outputs should have the Schema instances.
152
165
# After schema is implemented, fill out random data into the random locations
153
166
# then check output in test function
154
- # From Haytham even though we recommend people use typed schemas, they might not always do so...
167
+ # Even though we recommend people use typed schemas, they might not always do so...
155
168
# in which case it'll be impossible to predict the actual schema, we should support a
156
169
# way for unit test authors to provide fake data regardless
157
170
sub_task_output = None
@@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises):
201
214
fulfilled_promises
202
215
203
216
:param _interface.BindingData binding_data:
204
- :param dict[Text, T] fulfilled_promises:
217
+ :param dict[Text,T] fulfilled_promises:
205
218
:rtype:
206
219
"""
207
220
if binding_data .scalar :
@@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises):
228
241
k : DynamicTask .fulfil_bindings (sub_binding_data , fulfilled_promises ) for k , sub_binding_data in
229
242
_six .iteritems (binding_data .map .bindings )
230
243
}))
244
+
245
+
246
class HiveTask(DynamicTask):
    def _transform_for_user_output(self, outputs):
        """
        Extract the Hive query strings from the futures document produced by the task execution.

        Every task template in the futures document is parsed into a QuboleHiveJob, and the queries returned
        are those of the job referenced by the first node. An empty list is returned when there is no futures
        document or when it contains no nodes.
        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
        :rtype: list[Text]
        """
        futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME)
        # Guard against a futures document without nodes so indexing nodes[0] below cannot raise IndexError.
        if not futures or not futures.nodes:
            return []

        hive_jobs_by_task_name = {
            t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl(
                _ParseDict(t.custom, _qubole_pb2.QuboleHiveJob())
            )
            for t in futures.tasks
        }
        first_task_name = futures.nodes[0].task_node.reference_id.name
        return [
            q.query
            for q in hive_jobs_by_task_name[first_task_name].query_collection.queries
        ]
0 commit comments