Skip to content

Commit

Permalink
Merge pull request #1 from lyft/single-hive-query-nodes
Browse files Browse the repository at this point in the history
Single Hive query nodes
  • Loading branch information
matthewphsmith authored Sep 17, 2019
2 parents c16795e + f135e47 commit e2574d8
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 28 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.1.9'
__version__ = '0.2.0'
39 changes: 26 additions & 13 deletions flytekit/common/tasks/hive_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,34 @@ def __init__(
self._cluster_label = cluster_label
self._tags = tags

def _generate_hive_queries(self, context, inputs_dict):
def _generate_plugin_objects(self, context, inputs_dict):
"""
        Runs user code and produces hive queries
:param flytekit.engines.common.EngineContext context:
:param dict[Text, T] inputs:
:rtype: _qubole.QuboleHiveJob
:rtype: list[_qubole.QuboleHiveJob]
"""
queries_from_task = super(SdkHiveTask, self)._execute_user_code(context, inputs_dict) or []
if not isinstance(queries_from_task, list):
queries_from_task = [queries_from_task]

self._validate_queries(queries_from_task)
queries = _qubole.HiveQueryCollection(
[_qubole.HiveQuery(query=q, timeout_sec=self.metadata.timeout.seconds,
retry_count=self.metadata.retries.retries) for q in queries_from_task])
return _qubole.QuboleHiveJob(queries, self._cluster_label, self._tags)
plugin_objects = []

for q in queries_from_task:
hive_query = _qubole.HiveQuery(query=q, timeout_sec=self.metadata.timeout.seconds,
retry_count=self.metadata.retries.retries)

# TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are
# deprecated. This is only here for backwards compatibility - in addition to writing the query to the
# query field, we also construct a QueryCollection with only one query. This will ensure that the
# older plugin will continue to work.
query_collection = _qubole.HiveQueryCollection([hive_query])

plugin_objects.append(_qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags,
query_collection=query_collection))

return plugin_objects

@staticmethod
def _validate_task_parameters(cluster_label, tags):
Expand Down Expand Up @@ -146,28 +158,29 @@ def _produce_dynamic_job_spec(self, context, inputs):
# Add outputs to inputs
inputs_dict.update(outputs_dict)

# Note: Today a hive task corresponds to a dynamic job spec with one node, which contains multiple
# queries. We may change this in future.
nodes = []
tasks = []
generated_queries = self._generate_hive_queries(context, inputs_dict)
# One node per query
generated_queries = self._generate_plugin_objects(context, inputs_dict)

# Create output bindings always - this has to happen after user code has run
output_bindings = [_literal_models.Binding(var=name, binding=_interface.BindingData.from_python_std(
b.sdk_type.to_flyte_literal_type(), b.value))
for name, b in _six.iteritems(outputs_dict)]

if len(generated_queries.query_collection.queries) > 0:
i = 0
for quboleHiveJob in generated_queries:
hive_job_node = _create_hive_job_node(
"HiveQueries",
generated_queries.to_flyte_idl(),
"HiveQuery_{}".format(i),
quboleHiveJob.to_flyte_idl(),
self.metadata
)
nodes.append(hive_job_node)
tasks.append(hive_job_node.executable_sdk_object)
i += 1

dynamic_job_spec = _dynamic_job.DynamicJobSpec(
min_successes=len(nodes), # At most we only have one node for now, see above comment
min_successes=len(nodes),
tasks=tasks,
nodes=nodes,
outputs=output_bindings,
Expand Down
33 changes: 23 additions & 10 deletions flytekit/models/qubole.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def from_flyte_idl(cls, pb2_object):
retry_count=pb2_object.retryCount
)


class HiveQueryCollection(_common.FlyteIdlEntity):
def __init__(self, queries):
"""
Expand Down Expand Up @@ -101,17 +100,19 @@ def from_flyte_idl(cls, pb2_object):

class QuboleHiveJob(_common.FlyteIdlEntity):

def __init__(self, query_collection, cluster_label, tags):
def __init__(self, query, cluster_label, tags, query_collection=None):
"""
Initializes a HiveJob.
:param HiveQueryCollection query_collection: Queries to execute.
:param HiveQuery query: Single query to execute
:param Text cluster_label: The qubole cluster label to execute the query on
:param list[Text] tags: User tags for the queries
        :param HiveQueryCollection query_collection: Deprecated. Queries to execute; kept only for backwards compatibility.
"""
self._query_collection = query_collection
self._query = query
self._cluster_label = cluster_label
self._tags = tags
self._query_collection = query_collection

@property
def query_collection(self):
Expand All @@ -121,6 +122,15 @@ def query_collection(self):
"""
return self._query_collection

@property
def query(self):
"""
The query to be executed
:rtype: HiveQuery
"""
return self._query


@property
def cluster_label(self):
"""
Expand All @@ -142,19 +152,22 @@ def to_flyte_idl(self):
:rtype: _qubole.QuboleHiveJob
"""
return _qubole.QuboleHiveJob(
query_collection=self._query_collection.to_flyte_idl(),
query_collection=self._query_collection.to_flyte_idl() if self._query_collection else None,
query=self._query.to_flyte_idl() if self._query else None,
cluster_label=self._cluster_label,
tags=self._tags
)

@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, p):
"""
:param _qubole.QuboleHiveJob pb2_object:
:param _qubole.QuboleHiveJob p:
:rtype: QuboleHiveJob
"""
return cls(
query_collection=HiveQueryCollection.from_flyte_idl(pb2_object.query_collection),
cluster_label=pb2_object.cluster_label,
tags=pb2_object.tags,
query_collection=HiveQueryCollection.from_flyte_idl(p.query_collection) if p.HasField(
"query_collection") else None,
query=HiveQuery.from_flyte_idl(p.query) if p.HasField("query") else None,
cluster_label=p.cluster_label,
tags=p.tags,
)
20 changes: 20 additions & 0 deletions tests/flytekit/unit/models/test_qubole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import absolute_import

import pytest

from flytekit.models import qubole
from tests.flytekit.common.parameterizers import LIST_OF_ALL_LITERAL_TYPES


def test_hive_query():
    """Round-trip a HiveQuery through its flyteidl protobuf form and back."""
    original = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0)
    round_tripped = qubole.HiveQuery.from_flyte_idl(original.to_flyte_idl())
    # The model must survive serialization unchanged.
    assert original == round_tripped
    assert round_tripped.query == 'some query'


def test_hive_job():
    """Round-trip a single-query QuboleHiveJob through its protobuf form."""
    single_query = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0)
    job = qubole.QuboleHiveJob(query=single_query, cluster_label='default', tags=[])
    round_tripped = qubole.QuboleHiveJob.from_flyte_idl(job.to_flyte_idl())
    # Serialization must be lossless for the new single-query field.
    assert job == round_tripped
12 changes: 8 additions & 4 deletions tests/flytekit/unit/sdk/tasks/test_hive_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,18 @@ def test_hive_task_query_generation():
for name, variable in _six.iteritems(two_queries.interface.outputs)
}

qubole_hive_job = two_queries._generate_hive_queries(context, references)
assert (len(qubole_hive_job.query_collection.queries) == 2)
qubole_hive_jobs = two_queries._generate_plugin_objects(context, references)
assert len(qubole_hive_jobs) == 2

# deprecated, collection is only here for backwards compatibility
assert len(qubole_hive_jobs[0].query_collection.queries) == 1
assert len(qubole_hive_jobs[1].query_collection.queries) == 1

# The output references should now have the same fake S3 path as the formatted queries
assert references['hive_results'].value[0].uri != ''
assert references['hive_results'].value[1].uri != ''
assert references['hive_results'].value[0].uri in qubole_hive_job.query_collection.queries[0].query
assert references['hive_results'].value[1].uri in qubole_hive_job.query_collection.queries[1].query
assert references['hive_results'].value[0].uri in qubole_hive_jobs[0].query.query
assert references['hive_results'].value[1].uri in qubole_hive_jobs[1].query.query


def test_hive_task_dynamic_job_spec_generation():
Expand Down

0 comments on commit e2574d8

Please sign in to comment.