Skip to content

Commit

Permalink
Mtoledo/add interruptible parameter (#86)
Browse files Browse the repository at this point in the history
* add interruptible field [wip]

* interruptible to node metadata

* interruptible default to false and upd tests

* add interruptible to workflow

* upd idl

* upd tests

* add interruptible to tasks

* upd python and dynamic task

* upd taskmetadata

* upd workflow closure test

* upd version

* upd from_idl and add tests

* upd parameterizers

* dummy commit to retrigger build

* udp workflow metadata

* comment out tests

* merge conflict
  • Loading branch information
migueltol22 authored Mar 10, 2020
1 parent fd17913 commit 9f3b48f
Show file tree
Hide file tree
Showing 19 changed files with 219 additions and 95 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.6.0b1'
__version__ = '0.6.0b2'
3 changes: 2 additions & 1 deletion flytekit/common/tasks/hive_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
task_type,
discovery_version,
retries,
interruptible,
deprecated,
storage_request,
cpu_request,
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
:param dict[Text, Text] environment:
"""
self._task_function = task_function
super(SdkHiveTask, self).__init__(task_function, task_type, discovery_version, retries, deprecated,
super(SdkHiveTask, self).__init__(task_function, task_type, discovery_version, retries, interruptible, deprecated,
storage_request, cpu_request, gpu_request, memory_request, storage_limit,
cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, {})
self._validate_task_parameters(cluster_label, tags)
Expand Down
4 changes: 3 additions & 1 deletion flytekit/common/tasks/sdk_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
task_type,
discovery_version,
retries,
interruptible,
deprecated,
storage_request,
cpu_request,
Expand All @@ -70,6 +71,7 @@ def __init__(
:param Text task_type: string describing the task type
:param Text discovery_version: string describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param bool interruptible: Whether or not task is interruptible
:param Text deprecated:
:param Text storage_request:
:param Text cpu_request:
Expand All @@ -87,7 +89,7 @@ def __init__(
:param dict[Text, T] custom:
"""
super(SdkDynamicTask, self).__init__(
task_function, task_type, discovery_version, retries, deprecated,
task_function, task_type, discovery_version, retries, interruptible, deprecated,
storage_request, cpu_request, gpu_request, memory_request, storage_limit,
cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, custom)

Expand Down
3 changes: 3 additions & 0 deletions flytekit/common/tasks/sdk_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(
task_type,
discovery_version,
retries,
interruptible,
deprecated,
storage_request,
cpu_request,
Expand All @@ -183,6 +184,7 @@ def __init__(
:param Text task_type: string describing the task type
:param Text discovery_version: string describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param bool interruptible: Specify whether task is interruptible
:param Text deprecated:
:param Text storage_request:
:param Text cpu_request:
Expand Down Expand Up @@ -210,6 +212,7 @@ def __init__(
),
timeout,
_literal_models.RetryStrategy(retries),
interruptible,
discovery_version,
deprecated
),
Expand Down
2 changes: 2 additions & 0 deletions flytekit/common/tasks/sidecar_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self,
task_type,
discovery_version,
retries,
interruptible,
deprecated,
storage_request,
cpu_request,
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self,
task_type,
discovery_version,
retries,
interruptible,
deprecated,
storage_request,
cpu_request,
Expand Down
3 changes: 3 additions & 0 deletions flytekit/common/tasks/spark_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
task_type,
discovery_version,
retries,
interruptible,
deprecated,
discoverable,
timeout,
Expand All @@ -69,6 +70,7 @@ def __init__(
:param Text task_type: string describing the task type
:param Text discovery_version: string describing the version for task discovery purposes
:param int retries: Number of retries to attempt
:param bool interruptible: Whether or not task is interruptible
:param Text deprecated:
:param bool discoverable:
:param datetime.timedelta timeout:
Expand All @@ -92,6 +94,7 @@ def __init__(
task_type,
discovery_version,
retries,
interruptible,
deprecated,
"",
"",
Expand Down
2 changes: 1 addition & 1 deletion flytekit/common/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __call__(self, *args, **input_map):
# TODO: Remove DEADBEEF
return _nodes.SdkNode(
id=None,
metadata=_workflow_model.NodeMetadata("DEADBEEF", self.metadata.timeout, self.metadata.retries),
metadata=_workflow_model.NodeMetadata("DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
sdk_task=self
Expand Down
2 changes: 2 additions & 0 deletions flytekit/common/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, inputs, outputs, nodes, id=None, metadata=None, interface=Non
super(SdkWorkflow, self).__init__(
id=id,
metadata=metadata,
metadata_defaults=_workflow_models.WorkflowMetadataDefaults(),
interface=interface,
nodes=nodes,
outputs=output_bindings,
Expand Down Expand Up @@ -255,6 +256,7 @@ def promote_from_model(cls, base_model, sub_workflows=None, tasks=None):
inputs=None, outputs=None, nodes=list(node_map.values()),
id=_identifier.Identifier.promote_from_model(base_model.id),
metadata=base_model.metadata,
metadata_defaults=base_model.metadata_defaults,
interface=_interface.TypedInterface.promote_from_model(base_model.interface),
output_bindings=base_model.outputs,
)
Expand Down
3 changes: 3 additions & 0 deletions flytekit/contrib/sensors/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _execute_user_code(self, context, inputs):
def sensor_task(
_task_function=None,
retries=0,
interruptible=None,
deprecated='',
storage_request=None,
cpu_request=None,
Expand Down Expand Up @@ -57,6 +58,7 @@ def my_task(wf_params):
.. note::
If retries > 0, the task must be able to recover from any remote state created within the user code. It is
strongly recommended that tasks are written to be idempotent.
:param bool interruptible: Specify whether task is interruptible
:param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
will be logged as a warning so it should contain information regarding how to update to a newer task.
:param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
Expand Down Expand Up @@ -99,6 +101,7 @@ def wrapper(fn):
task_function=fn,
task_type=_common_constants.SdkTaskType.SENSOR_TASK,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
storage_request=storage_request,
cpu_request=cpu_request,
Expand Down
50 changes: 47 additions & 3 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def from_flyte_idl(cls, pb2_objct):

class NodeMetadata(_common.FlyteIdlEntity):

def __init__(self, name, timeout, retries):
def __init__(self, name, timeout, retries, interruptible=False):
"""
Defines extra information about the Node.
Expand All @@ -159,6 +159,7 @@ def __init__(self, name, timeout, retries):
self._name = name
self._timeout = timeout
self._retries = retries
self._interruptible = interruptible

@property
def name(self):
Expand All @@ -181,11 +182,18 @@ def retries(self):
"""
return self._retries

@property
def interruptible(self):
"""
:rtype: flytekit.models.
"""
return self._interruptible

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.workflow_pb2.NodeMetadata
"""
node_metadata = _core_workflow.NodeMetadata(name=self.name, retries=self.retries.to_flyte_idl())
node_metadata = _core_workflow.NodeMetadata(name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible)
node_metadata.timeout.FromTimedelta(self.timeout)
return node_metadata

Expand Down Expand Up @@ -458,10 +466,34 @@ def from_flyte_idl(cls, pb2_object):
"""
return cls()

class WorkflowMetadataDefaults(_common.FlyteIdlEntity):

def __init__(self, interruptible=None):
"""
Metadata Defaults for the workflow.
"""
self.interruptible_ = interruptible

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.workflow_pb2.WorkflowMetadataDefaults
"""
return _core_workflow.WorkflowMetadataDefaults(
interruptible=self.interruptible_
)

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.core.workflow_pb2.WorkflowMetadataDefaults pb2_object:
:rtype: WorkflowMetadata
"""
return cls(interruptible=pb2_object.interruptible)


class WorkflowTemplate(_common.FlyteIdlEntity):

def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
def __init__(self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None):
"""
A workflow template encapsulates all the task, branch, and subworkflow nodes to run a statically analyzable,
directed acyclic graph. It contains also metadata that tells the system how to execute the workflow (i.e.
Expand All @@ -470,6 +502,7 @@ def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
:param flytekit.models.core.identifier.Identifier id: This is an autogenerated id by the system. The id is
globally unique across Flyte.
:param WorkflowMetadata metadata: This contains information on how to run the workflow.
:param WorkflowMetadataDefaults metadata_defaults: This contains the default information on how to run the workflow.
:param flytekit.models.interface.TypedInterface interface: Defines a strongly typed interface for the
Workflow (inputs, outputs). This can include some optional parameters.
:param list[Node] nodes: A list of nodes. In addition, "globals" is a special reserved node id that
Expand All @@ -485,6 +518,7 @@ def __init__(self, id, metadata, interface, nodes, outputs, failure_node=None):
"""
self._id = id
self._metadata = metadata
self._metadata_defaults = metadata_defaults
self._interface = interface
self._nodes = nodes
self._outputs = outputs
Expand All @@ -506,6 +540,14 @@ def metadata(self):
"""
return self._metadata

@property
def metadata_defaults(self):
"""
This contains information on how to run the workflow.
:rtype: WorkflowMetadataDefaults
"""
return self._metadata_defaults

@property
def interface(self):
"""
Expand Down Expand Up @@ -552,6 +594,7 @@ def to_flyte_idl(self):
return _core_workflow.WorkflowTemplate(
id=self.id.to_flyte_idl(),
metadata=self.metadata.to_flyte_idl(),
metadata_defaults=self.metadata_defaults.to_flyte_idl(),
interface=self.interface.to_flyte_idl(),
nodes=[n.to_flyte_idl() for n in self.nodes],
outputs=[o.to_flyte_idl() for o in self.outputs],
Expand All @@ -567,6 +610,7 @@ def from_flyte_idl(cls, pb2_object):
return cls(
id=_identifier.Identifier.from_flyte_idl(pb2_object.id),
metadata=WorkflowMetadata.from_flyte_idl(pb2_object.metadata),
metadata_defaults=WorkflowMetadataDefaults.from_flyte_idl(pb2_object.metadata_defaults),
interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface),
nodes=[Node.from_flyte_idl(n) for n in pb2_object.nodes],
outputs=[_Binding.from_flyte_idl(b) for b in pb2_object.outputs],
Expand Down
14 changes: 13 additions & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def from_flyte_idl(cls, pb2_object):

class TaskMetadata(_common.FlyteIdlEntity):

def __init__(self, discoverable, runtime, timeout, retries, discovery_version, deprecated_error_message):
def __init__(self, discoverable, runtime, timeout, retries, interruptible, discovery_version, deprecated_error_message):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
and retries.
Expand All @@ -184,6 +184,7 @@ def __init__(self, discoverable, runtime, timeout, retries, discovery_version, d
:param RuntimeMetadata runtime: Metadata describing the runtime environment for this task.
:param datetime.timedelta timeout: The amount of time to wait before timing out. This includes queuing and
scheduler latency.
:param bool interruptible: Whether or not the task is interruptible.
:param flytekit.models.literals.RetryStrategy retries: Retry strategy for this task. 0 retries means only
try once.
:param Text discovery_version: This is the version used to create a logical version for data in the cache.
Expand All @@ -195,6 +196,7 @@ def __init__(self, discoverable, runtime, timeout, retries, discovery_version, d
self._discoverable = discoverable
self._runtime = runtime
self._timeout = timeout
self._interruptible = interruptible
self._retries = retries
self._discovery_version = discovery_version
self._deprecated_error_message = deprecated_error_message
Expand Down Expand Up @@ -231,6 +233,14 @@ def timeout(self):
"""
return self._timeout

@property
def interruptible(self):
"""
Whether or not the task is interruptible.
:rtype: bool
"""
return self._interruptible

@property
def discovery_version(self):
"""
Expand Down Expand Up @@ -258,6 +268,7 @@ def to_flyte_idl(self):
discoverable=self.discoverable,
runtime=self.runtime.to_flyte_idl(),
retries=self.retries.to_flyte_idl(),
interruptible=self.interruptible,
discovery_version=self.discovery_version,
deprecated_error_message=self.deprecated_error_message
)
Expand All @@ -274,6 +285,7 @@ def from_flyte_idl(cls, pb2_object):
discoverable=pb2_object.discoverable,
runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime),
timeout=pb2_object.timeout.ToTimedelta(),
interruptible=pb2_object.interruptible if pb2_object.HasField("interruptible") else None,
retries=_literals.RetryStrategy.from_flyte_idl(pb2_object.retries),
discovery_version=pb2_object.discovery_version,
deprecated_error_message=pb2_object.deprecated_error_message
Expand Down
Loading

0 comments on commit 9f3b48f

Please sign in to comment.