Skip to content

Commit 8274b2c

Browse files
Merge pull request #2 from lyft/fix-hive-unit-test
Implement Hive Unit Test Behavior
2 parents 5e97f30 + b1d31a1 commit 8274b2c

File tree

4 files changed, with 97 additions and 17 deletions.

flytekit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
from __future__ import absolute_import
22
import flytekit.plugins
33

4-
__version__ = '0.1.5'
4+
__version__ = '0.1.6'

flytekit/engines/unit/engine.py

Lines changed: 51 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,16 @@
77
from datetime import datetime as _datetime
88
from six import moves as _six_moves
99

10+
from google.protobuf.json_format import ParseDict as _ParseDict
11+
from flyteidl.plugins import qubole_pb2 as _qubole_pb2
1012
from flytekit.common import constants as _sdk_constants, utils as _common_utils
1113
from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception
1214
from flytekit.common.types import helpers as _type_helpers
1315
from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
1416
from flytekit.engines import common as _common_engine
1517
from flytekit.engines.unit.mock_stats import MockStats
1618
from flytekit.interfaces.data import data_proxy as _data_proxy
17-
from flytekit.models import literals as _literals, array_job as _array_job
19+
from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models
1820
from flytekit.models.core.identifier import WorkflowExecutionIdentifier
1921

2022

@@ -32,9 +34,12 @@ def get_task(self, sdk_task):
3234
return ReturnOutputsTask(sdk_task)
3335
elif sdk_task.type in {
3436
_sdk_constants.SdkTaskType.DYNAMIC_TASK,
35-
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK
3637
}:
3738
return DynamicTask(sdk_task)
39+
elif sdk_task.type in {
40+
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK,
41+
}:
42+
return HiveTask(sdk_task)
3843
else:
3944
raise _user_exceptions.FlyteAssertion(
4045
"Unit tests are not currently supported for tasks of type: {}".format(
@@ -76,20 +81,20 @@ def execute(self, inputs, context=None):
7681
Just execute the function and return the outputs as a user-readable dictionary.
7782
:param flytekit.models.literals.LiteralMap inputs:
7883
:param context:
79-
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
84+
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
8085
"""
8186
with _TemporaryConfiguration(
8287
_os.path.join(_os.path.dirname(__file__), 'unit.config'),
8388
internal_overrides={'image': 'unit_image'}
8489
):
8590
with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory:
8691
with _data_proxy.LocalWorkingDirectoryContext(working_directory):
87-
return self._execute_user_code(inputs)
92+
return self._transform_for_user_output(self._execute_user_code(inputs))
8893

8994
def _execute_user_code(self, inputs):
9095
"""
9196
:param flytekit.models.literals.LiteralMap inputs:
92-
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
97+
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
9398
"""
9499
with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
95100
return self.sdk_task.execute(
@@ -107,24 +112,32 @@ def _execute_user_code(self, inputs):
107112
inputs
108113
)
109114

115+
def _transform_for_user_output(self, outputs):
116+
"""
117+
Take whatever is returned from the task execution and convert it into the output that is appropriate for this
118+
task's unit test.
119+
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
120+
:rtype: T
121+
"""
122+
return outputs
123+
110124
def register(self, identifier, version):
111125
raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.")
112126

113127

114128
class ReturnOutputsTask(UnitTestEngineTask):
115-
def execute(self, inputs, context=None):
129+
def _transform_for_user_output(self, outputs):
116130
"""
117-
Just execute the function and return the outputs as a user-readable dictionary.
118-
:param flytekit.models.literals.LiteralMap inputs:
119-
:param context:
120-
:rtype: dict[Text, T]
131+
Just return the outputs as a user-readable dictionary.
132+
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
133+
:rtype: T
121134
"""
122-
outputs = super(ReturnOutputsTask, self).execute(inputs)[_sdk_constants.OUTPUT_FILE_NAME]
135+
literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME]
123136
return {
124137
name: _type_helpers.get_sdk_type_from_literal_type(
125138
variable.type
126139
).promote_from_model(
127-
outputs.literals[name]
140+
literal_map.literals[name]
128141
).to_python_std()
129142
for name, variable in _six.iteritems(self.sdk_task.interface.outputs)
130143
}
@@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask):
135148
def _execute_user_code(self, inputs):
136149
"""
137150
:param flytekit.models.literals.LiteralMap inputs:
138-
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
151+
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
139152
"""
140153
results = super(DynamicTask, self)._execute_user_code(inputs)
141154
if _sdk_constants.FUTURES_FILE_NAME in results:
@@ -151,7 +164,7 @@ def _execute_user_code(self, inputs):
151164
# TODO: futures.outputs should have the Schema instances.
152165
# After schema is implemented, fill out random data into the random locations
153166
# then check output in test function
154-
# From Haytham even though we recommend people use typed schemas, they might not always do so...
167+
# Even though we recommend people use typed schemas, they might not always do so...
155168
# in which case it'll be impossible to predict the actual schema, we should support a
156169
# way for unit test authors to provide fake data regardless
157170
sub_task_output = None
@@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises):
201214
fulfilled_promises
202215
203216
:param _interface.BindingData binding_data:
204-
:param dict[Text, T] fulfilled_promises:
217+
:param dict[Text,T] fulfilled_promises:
205218
:rtype:
206219
"""
207220
if binding_data.scalar:
@@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises):
228241
k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
229242
_six.iteritems(binding_data.map.bindings)
230243
}))
244+
245+
246+
class HiveTask(DynamicTask):
247+
def _transform_for_user_output(self, outputs):
248+
"""
249+
Return the list of Hive queries found in the task execution's outputs (an empty list if there are none).
250+
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
251+
:rtype: list[Text]
252+
"""
253+
futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME)
254+
if futures:
255+
task_ids_to_defs = {
256+
t.id.name: _qubole_models.QuboleHiveJob.from_flyte_idl(
257+
_ParseDict(t.custom, _qubole_pb2.QuboleHiveJob())
258+
)
259+
for t in futures.tasks
260+
}
261+
return [
262+
q.query
263+
for q in task_ids_to_defs[futures.nodes[0].task_node.reference_id.name].query_collection.queries
264+
]
265+
else:
266+
return []

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ max-complexity=16
77
[tool:pytest]
88
norecursedirs = common workflows spark
99
log_cli = true
10-
log_cli_level = 100
10+
log_cli_level = 20
1111

1212
[pep8]
1313
max-line-length = 120
Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,44 @@
1+
from __future__ import absolute_import
2+
from flytekit.sdk.tasks import hive_task
3+
import pytest
4+
5+
6+
def test_no_queries():
7+
@hive_task
8+
def test_hive_task(wf_params):
9+
pass
10+
11+
assert test_hive_task.unit_test() == []
12+
13+
14+
def test_empty_list_queries():
15+
@hive_task
16+
def test_hive_task(wf_params):
17+
return []
18+
19+
assert test_hive_task.unit_test() == []
20+
21+
22+
def test_one_query():
23+
@hive_task
24+
def test_hive_task(wf_params):
25+
return "abc"
26+
27+
assert test_hive_task.unit_test() == ["abc"]
28+
29+
30+
def test_multiple_queries():
31+
@hive_task
32+
def test_hive_task(wf_params):
33+
return ["abc", "cde"]
34+
35+
assert test_hive_task.unit_test() == ["abc", "cde"]
36+
37+
38+
def test_raise_exception():
39+
@hive_task
40+
def test_hive_task(wf_params):
41+
raise FloatingPointError("Floating point error for some reason.")
42+
43+
with pytest.raises(FloatingPointError):
44+
test_hive_task.unit_test()

0 commit comments

Comments
 (0)