Skip to content

Commit e2574d8

Browse files
Merge pull request #1 from lyft/single-hive-query-nodes
Single Hive query nodes
2 parents c16795e + f135e47 commit e2574d8

File tree

5 files changed

+78
-28
lines changed

5 files changed

+78
-28
lines changed

flytekit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import absolute_import
22
import flytekit.plugins
33

4-
__version__ = '0.1.9'
4+
__version__ = '0.2.0'

flytekit/common/tasks/hive_task.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,34 @@ def __init__(
7878
self._cluster_label = cluster_label
7979
self._tags = tags
8080

81-
def _generate_hive_queries(self, context, inputs_dict):
81+
def _generate_plugin_objects(self, context, inputs_dict):
8282
"""
8383
Runs user code and produces hive queries
8484
:param flytekit.engines.common.EngineContext context:
8585
:param dict[Text, T] inputs:
86-
:rtype: _qubole.QuboleHiveJob
86+
:rtype: list[_qubole.QuboleHiveJob]
8787
"""
8888
queries_from_task = super(SdkHiveTask, self)._execute_user_code(context, inputs_dict) or []
8989
if not isinstance(queries_from_task, list):
9090
queries_from_task = [queries_from_task]
9191

9292
self._validate_queries(queries_from_task)
93-
queries = _qubole.HiveQueryCollection(
94-
[_qubole.HiveQuery(query=q, timeout_sec=self.metadata.timeout.seconds,
95-
retry_count=self.metadata.retries.retries) for q in queries_from_task])
96-
return _qubole.QuboleHiveJob(queries, self._cluster_label, self._tags)
93+
plugin_objects = []
94+
95+
for q in queries_from_task:
96+
hive_query = _qubole.HiveQuery(query=q, timeout_sec=self.metadata.timeout.seconds,
97+
retry_count=self.metadata.retries.retries)
98+
99+
# TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are
100+
# deprecated. This is only here for backwards compatibility - in addition to writing the query to the
101+
# query field, we also construct a QueryCollection with only one query. This will ensure that the
102+
# older plugin will continue to work.
103+
query_collection = _qubole.HiveQueryCollection([hive_query])
104+
105+
plugin_objects.append(_qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags,
106+
query_collection=query_collection))
107+
108+
return plugin_objects
97109

98110
@staticmethod
99111
def _validate_task_parameters(cluster_label, tags):
@@ -146,28 +158,29 @@ def _produce_dynamic_job_spec(self, context, inputs):
146158
# Add outputs to inputs
147159
inputs_dict.update(outputs_dict)
148160

149-
# Note: Today a hive task corresponds to a dynamic job spec with one node, which contains multiple
150-
# queries. We may change this in future.
151161
nodes = []
152162
tasks = []
153-
generated_queries = self._generate_hive_queries(context, inputs_dict)
163+
# One node per query
164+
generated_queries = self._generate_plugin_objects(context, inputs_dict)
154165

155166
# Create output bindings always - this has to happen after user code has run
156167
output_bindings = [_literal_models.Binding(var=name, binding=_interface.BindingData.from_python_std(
157168
b.sdk_type.to_flyte_literal_type(), b.value))
158169
for name, b in _six.iteritems(outputs_dict)]
159170

160-
if len(generated_queries.query_collection.queries) > 0:
171+
i = 0
172+
for quboleHiveJob in generated_queries:
161173
hive_job_node = _create_hive_job_node(
162-
"HiveQueries",
163-
generated_queries.to_flyte_idl(),
174+
"HiveQuery_{}".format(i),
175+
quboleHiveJob.to_flyte_idl(),
164176
self.metadata
165177
)
166178
nodes.append(hive_job_node)
167179
tasks.append(hive_job_node.executable_sdk_object)
180+
i += 1
168181

169182
dynamic_job_spec = _dynamic_job.DynamicJobSpec(
170-
min_successes=len(nodes), # At most we only have one node for now, see above comment
183+
min_successes=len(nodes),
171184
tasks=tasks,
172185
nodes=nodes,
173186
outputs=output_bindings,

flytekit/models/qubole.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def from_flyte_idl(cls, pb2_object):
6363
retry_count=pb2_object.retryCount
6464
)
6565

66-
6766
class HiveQueryCollection(_common.FlyteIdlEntity):
6867
def __init__(self, queries):
6968
"""
@@ -101,17 +100,19 @@ def from_flyte_idl(cls, pb2_object):
101100

102101
class QuboleHiveJob(_common.FlyteIdlEntity):
103102

104-
def __init__(self, query_collection, cluster_label, tags):
103+
def __init__(self, query, cluster_label, tags, query_collection=None):
105104
"""
106105
Initializes a HiveJob.
107106
108-
:param HiveQueryCollection query_collection: Queries to execute.
107+
:param HiveQuery query: Single query to execute
109108
:param Text cluster_label: The qubole cluster label to execute the query on
110109
:param list[Text] tags: User tags for the queries
110+
:param HiveQueryCollection query_collection: Deprecated. Queries to execute.
111111
"""
112-
self._query_collection = query_collection
112+
self._query = query
113113
self._cluster_label = cluster_label
114114
self._tags = tags
115+
self._query_collection = query_collection
115116

116117
@property
117118
def query_collection(self):
@@ -121,6 +122,15 @@ def query_collection(self):
121122
"""
122123
return self._query_collection
123124

125+
@property
126+
def query(self):
127+
"""
128+
The query to be executed
129+
:rtype: HiveQuery
130+
"""
131+
return self._query
132+
133+
124134
@property
125135
def cluster_label(self):
126136
"""
@@ -142,19 +152,22 @@ def to_flyte_idl(self):
142152
:rtype: _qubole.QuboleHiveJob
143153
"""
144154
return _qubole.QuboleHiveJob(
145-
query_collection=self._query_collection.to_flyte_idl(),
155+
query_collection=self._query_collection.to_flyte_idl() if self._query_collection else None,
156+
query=self._query.to_flyte_idl() if self._query else None,
146157
cluster_label=self._cluster_label,
147158
tags=self._tags
148159
)
149160

150161
@classmethod
151-
def from_flyte_idl(cls, pb2_object):
162+
def from_flyte_idl(cls, p):
152163
"""
153-
:param _qubole.QuboleHiveJob pb2_object:
164+
:param _qubole.QuboleHiveJob p:
154165
:rtype: QuboleHiveJob
155166
"""
156167
return cls(
157-
query_collection=HiveQueryCollection.from_flyte_idl(pb2_object.query_collection),
158-
cluster_label=pb2_object.cluster_label,
159-
tags=pb2_object.tags,
168+
query_collection=HiveQueryCollection.from_flyte_idl(p.query_collection) if p.HasField(
169+
"query_collection") else None,
170+
query=HiveQuery.from_flyte_idl(p.query) if p.HasField("query") else None,
171+
cluster_label=p.cluster_label,
172+
tags=p.tags,
160173
)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import absolute_import
2+
3+
import pytest
4+
5+
from flytekit.models import qubole
6+
from tests.flytekit.common.parameterizers import LIST_OF_ALL_LITERAL_TYPES
7+
8+
9+
def test_hive_query():
10+
q = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0)
11+
q2 = qubole.HiveQuery.from_flyte_idl(q.to_flyte_idl())
12+
assert q == q2
13+
assert q2.query == 'some query'
14+
15+
16+
def test_hive_job():
17+
query = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0)
18+
obj = qubole.QuboleHiveJob(query=query, cluster_label='default', tags=[])
19+
obj2 = qubole.QuboleHiveJob.from_flyte_idl(obj.to_flyte_idl())
20+
assert obj == obj2

tests/flytekit/unit/sdk/tasks/test_hive_tasks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,18 @@ def test_hive_task_query_generation():
108108
for name, variable in _six.iteritems(two_queries.interface.outputs)
109109
}
110110

111-
qubole_hive_job = two_queries._generate_hive_queries(context, references)
112-
assert (len(qubole_hive_job.query_collection.queries) == 2)
111+
qubole_hive_jobs = two_queries._generate_plugin_objects(context, references)
112+
assert len(qubole_hive_jobs) == 2
113+
114+
# deprecated, collection is only here for backwards compatibility
115+
assert len(qubole_hive_jobs[0].query_collection.queries) == 1
116+
assert len(qubole_hive_jobs[1].query_collection.queries) == 1
113117

114118
# The output references should now have the same fake S3 path as the formatted queries
115119
assert references['hive_results'].value[0].uri != ''
116120
assert references['hive_results'].value[1].uri != ''
117-
assert references['hive_results'].value[0].uri in qubole_hive_job.query_collection.queries[0].query
118-
assert references['hive_results'].value[1].uri in qubole_hive_job.query_collection.queries[1].query
121+
assert references['hive_results'].value[0].uri in qubole_hive_jobs[0].query.query
122+
assert references['hive_results'].value[1].uri in qubole_hive_jobs[1].query.query
119123

120124

121125
def test_hive_task_dynamic_job_spec_generation():

0 commit comments

Comments (0)