Skip to content

Commit fe5b620

Browse files
Serialization to produce registerable entities (#104)
1 parent 07f8e80 commit fe5b620

File tree

12 files changed

+344
-53
lines changed

12 files changed

+344
-53
lines changed

flytekit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import absolute_import
22
import flytekit.plugins
33

4-
__version__ = '0.7.0'
4+
__version__ = '0.7.1b0'

flytekit/clis/sdk_in_container/register.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import absolute_import
22

3+
import logging as _logging
4+
35
import click
46

57
from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_TEST, CTX_PACKAGES, CTX_VERSION
68
from flytekit.common import utils as _utils
9+
from flytekit.common.core import identifier as _identifier
710
from flytekit.common.tasks import task as _task
811
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \
912
IMAGE as _IMAGE
@@ -13,21 +16,31 @@
1316
def register_all(project, domain, pkgs, test, version):
1417
if test:
1518
click.echo('Test switch enabled, not doing anything...')
16-
1719
click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
1820
project, domain, pkgs, version))
1921

2022
# m = module (i.e. python file)
2123
# k = value of dir(m), type str
2224
# o = object (e.g. SdkWorkflow)
25+
loaded_entities = []
2326
for m, k, o in iterate_registerable_entities_in_order(pkgs):
2427
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
25-
28+
_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
29+
o._id = _identifier.Identifier(
30+
o.resource_type,
31+
project,
32+
domain,
33+
name,
34+
version
35+
)
36+
loaded_entities.append(o)
37+
38+
for o in loaded_entities:
2639
if test:
27-
click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), name))
40+
click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
2841
else:
29-
click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), name))
30-
o.register(project, domain, name, version)
42+
click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
43+
o.register(project, domain, o.id.name, version)
3144

3245

3346
def register_tasks_only(project, domain, pkgs, test, version):
@@ -47,6 +60,7 @@ def register_tasks_only(project, domain, pkgs, test, version):
4760
click.echo("Registering task {:20} {}".format("{}:".format(t.entity_type_text), name))
4861
t.register(project, domain, name, version)
4962

63+
5064
@click.group('register')
5165
# --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead
5266
@click.option('--pkgs', multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword")

flytekit/clis/sdk_in_container/serialize.py

Lines changed: 164 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,132 @@
11
from __future__ import absolute_import
22
from __future__ import print_function
33

4+
import logging as _logging
5+
import math as _math
6+
import os as _os
7+
48
import click
59

610
from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_PROJECT, CTX_DOMAIN, CTX_VERSION
7-
from flytekit.common import workflow as _workflow, utils as _utils
11+
from flytekit.common import utils as _utils
12+
from flytekit.common.core import identifier as _identifier
813
from flytekit.common.exceptions.scopes import system_entry_point
914
from flytekit.common.tasks import task as _sdk_task
1015
from flytekit.common.utils import write_proto_to_file as _write_proto_to_file
1116
from flytekit.configuration import TemporaryConfiguration
12-
from flytekit.configuration.internal import CONFIGURATION_PATH
13-
from flytekit.configuration.internal import IMAGE as _IMAGE
14-
from flytekit.models.workflow_closure import WorkflowClosure as _WorkflowClosure
17+
from flytekit.configuration import internal as _internal_configuration
1518
from flytekit.tools.module_loader import iterate_registerable_entities_in_order
1619

1720

1821
@system_entry_point
19-
def serialize_tasks(pkgs):
20-
# Serialize all tasks
21-
for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
22-
fname = '{}.pb'.format(_utils.fqdn(m.__name__, k, entity_type=t.resource_type))
23-
click.echo('Writing task {} to {}'.format(t.id, fname))
24-
pb = t.to_flyte_idl()
25-
_write_proto_to_file(pb, fname)
22+
def serialize_tasks_only(project, domain, pkgs, version, folder=None):
23+
"""
24+
:param Text project:
25+
:param Text domain:
26+
:param list[Text] pkgs:
27+
:param Text version:
28+
:param Text folder:
29+
30+
:return:
31+
"""
32+
# m = module (i.e. python file)
33+
# k = value of dir(m), type str
34+
# o = object (e.g. SdkWorkflow)
35+
loaded_entities = []
36+
for m, k, o in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
37+
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
38+
_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
39+
o._id = _identifier.Identifier(
40+
o.resource_type,
41+
project,
42+
domain,
43+
name,
44+
version
45+
)
46+
loaded_entities.append(o)
47+
48+
zero_padded_length = _determine_text_chars(len(loaded_entities))
49+
for i, entity in enumerate(loaded_entities):
50+
serialized = entity.serialize()
51+
fname_index = str(i).zfill(zero_padded_length)
52+
fname = '{}_{}.pb'.format(fname_index, entity._id.name)
53+
click.echo(' Writing {} to\n {}'.format(entity._id, fname))
54+
_write_proto_to_file(serialized, fname)
55+
56+
identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
57+
if folder:
58+
identifier_fname = _os.path.join(folder, identifier_fname)
59+
_write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)
2660

2761

2862
@system_entry_point
29-
def serialize_workflows(pkgs):
30-
# Create map to look up tasks by their unique identifier. This is so we can compile them into the workflow closure.
31-
tmap = {}
32-
for _, _, t in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
33-
tmap[t.id] = t
63+
def serialize_all(project, domain, pkgs, version, folder=None):
64+
"""
65+
In order to register, we have to comply with Admin's endpoints. Those endpoints take the following object. These
66+
flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
67+
flyteidl.admin.workflow_pb2.WorkflowSpec
68+
flyteidl.admin.task_pb2.TaskSpec
3469
35-
for m, k, w in iterate_registerable_entities_in_order(pkgs, include_entities={_workflow.SdkWorkflow}):
36-
click.echo('Serializing {}'.format(_utils.fqdn(m.__name__, k, entity_type=w.resource_type)))
37-
task_templates = []
38-
for n in w.nodes:
39-
if n.task_node is not None:
40-
task_templates.append(tmap[n.task_node.reference_id])
70+
However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:
71+
flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
72+
flyteidl.core.workflow_pb2.WorkflowTemplate
73+
flyteidl.core.tasks_pb2.TaskTemplate
74+
75+
For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.
76+
77+
:param Text project:
78+
:param Text domain:
79+
:param list[Text] pkgs:
80+
:param Text version:
81+
:param Text folder:
82+
83+
:return:
84+
"""
4185

42-
wc = _WorkflowClosure(workflow=w, tasks=task_templates)
43-
wc_pb = wc.to_flyte_idl()
86+
# m = module (i.e. python file)
87+
# k = value of dir(m), type str
88+
# o = object (e.g. SdkWorkflow)
89+
loaded_entities = []
90+
for m, k, o in iterate_registerable_entities_in_order(pkgs):
91+
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
92+
_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
93+
o._id = _identifier.Identifier(
94+
o.resource_type,
95+
project,
96+
domain,
97+
name,
98+
version
99+
)
100+
loaded_entities.append(o)
101+
102+
zero_padded_length = _determine_text_chars(len(loaded_entities))
103+
for i, entity in enumerate(loaded_entities):
104+
serialized = entity.serialize()
105+
fname_index = str(i).zfill(zero_padded_length)
106+
fname = '{}_{}.pb'.format(fname_index, entity._id.name)
107+
click.echo(' Writing {} to\n {}'.format(entity._id, fname))
108+
_write_proto_to_file(serialized, fname)
109+
110+
# Not everything serialized will necessarily have an identifier field in it, even though some do (like the
111+
# TaskTemplate). To be more rigorous, we write an explicit identifier file that reflects the choices (like
112+
# project/domain, etc.) made for this serialize call. We should not allow users to specify a different project
113+
# for instance come registration time, to avoid mismatches between potential internal ids like the TaskTemplate
114+
# and the registered entity.
115+
identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
116+
if folder:
117+
identifier_fname = _os.path.join(folder, identifier_fname)
118+
_write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)
119+
120+
121+
def _determine_text_chars(length):
122+
"""
123+
This function is used to help prefix files. If there are only 10 entries, then we just need one digit (0-9) to be
124+
the prefix. If there are 11, then we'll need two (00-10).
44125
45-
fname = '{}.pb'.format(_utils.fqdn(m.__name__, k, entity_type=w.resource_type))
46-
click.echo(' Writing workflow closure {}'.format(fname))
47-
_write_proto_to_file(wc_pb, fname)
126+
:param int length:
127+
:rtype: int
128+
"""
129+
return _math.ceil(_math.log(length, 10))
48130

49131

50132
@click.group('serialize')
@@ -57,37 +139,77 @@ def serialize(ctx):
57139
object contains the WorkflowTemplate, along with the relevant tasks for that workflow. In lieu of Admin,
58140
this serialization step will set the URN of the tasks to the fully qualified name of the task function.
59141
"""
60-
click.echo('Serializing Flyte elements with image {}'.format(_IMAGE.get()))
142+
click.echo('Serializing Flyte elements with image {}'.format(_internal_configuration.IMAGE.get()))
61143

62144

63145
@click.command('tasks')
146+
@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
147+
'image, but you can override here.')
148+
@click.option('-f', '--folder', type=click.Path(exists=True))
64149
@click.pass_context
65-
def tasks(ctx):
150+
def tasks(ctx, version=None, folder=None):
151+
project = ctx.obj[CTX_PROJECT]
152+
domain = ctx.obj[CTX_DOMAIN]
66153
pkgs = ctx.obj[CTX_PACKAGES]
154+
155+
if folder:
156+
click.echo(f"Writing output to {folder}")
157+
158+
version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
159+
_internal_configuration.IMAGE.get())
160+
67161
internal_settings = {
68-
'project': ctx.obj[CTX_PROJECT],
69-
'domain': ctx.obj[CTX_DOMAIN],
70-
'version': ctx.obj[CTX_VERSION]
162+
'project': project,
163+
'domain': domain,
164+
'version': version,
71165
}
72166
# Populate internal settings for project/domain/version from the environment so that the file names are resolved
73-
# with the correct strings. The file itself doesn't need to change though.
74-
with TemporaryConfiguration(CONFIGURATION_PATH.get(), internal_settings):
75-
serialize_tasks(pkgs)
167+
# with the correct strings. The file itself doesn't need to change though.
168+
with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings):
169+
_logging.debug("Serializing with settings\n"
170+
"\n Project: {}"
171+
"\n Domain: {}"
172+
"\n Version: {}"
173+
"\n\nover the following packages {}".format(project, domain, version, pkgs)
174+
)
175+
serialize_tasks_only(project, domain, pkgs, version, folder)
76176

77177

78178
@click.command('workflows')
179+
@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
180+
'image, but you can override here.')
181+
# For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the
182+
# directory for you so it shouldn't be a problem.
183+
@click.option('-f', '--folder', type=click.Path(exists=True))
79184
@click.pass_context
80-
def workflows(ctx):
185+
def workflows(ctx, version=None, folder=None):
186+
_logging.getLogger().setLevel(_logging.DEBUG)
187+
188+
if folder:
189+
click.echo(f"Writing output to {folder}")
190+
191+
project = ctx.obj[CTX_PROJECT]
192+
domain = ctx.obj[CTX_DOMAIN]
81193
pkgs = ctx.obj[CTX_PACKAGES]
194+
195+
version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
196+
_internal_configuration.IMAGE.get())
197+
82198
internal_settings = {
83-
'project': ctx.obj[CTX_PROJECT],
84-
'domain': ctx.obj[CTX_DOMAIN],
85-
'version': ctx.obj[CTX_VERSION]
199+
'project': project,
200+
'domain': domain,
201+
'version': version,
86202
}
87203
# Populate internal settings for project/domain/version from the environment so that the file names are resolved
88-
# with the correct strings. The file itself doesn't need to change though.
89-
with TemporaryConfiguration(CONFIGURATION_PATH.get(), internal_settings):
90-
serialize_workflows(pkgs)
204+
# with the correct strings. The file itself doesn't need to change though.
205+
with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings):
206+
_logging.debug("Serializing with settings\n"
207+
"\n Project: {}"
208+
"\n Domain: {}"
209+
"\n Version: {}"
210+
"\n\nover the following packages {}".format(project, domain, version, pkgs)
211+
)
212+
serialize_all(project, domain, pkgs, version, folder)
91213

92214

93215
serialize.add_command(tasks)

flytekit/common/launch_plan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,14 @@ def register(self, project, domain, name, version):
335335
self._id = id_to_register
336336
return _six.text_type(self.id)
337337

338+
@_exception_scopes.system_entry_point
339+
def serialize(self):
340+
"""
341+
Unlike the SdkWorkflow serialize call, nothing special needs to be done here.
342+
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
343+
"""
344+
return self.to_flyte_idl()
345+
338346
@classmethod
339347
def from_flyte_idl(cls, _):
340348
raise _user_exceptions.FlyteAssertion(

flytekit/common/mixins/registerable.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,20 @@ def register(self, project, domain, name, version):
5252
"""
5353
pass
5454

55+
@_abc.abstractmethod
56+
def serialize(self, project, domain, name, version):
57+
"""
58+
Registerable entities also are required to be serialized. This allows flytekit to separate serialization from
59+
the network call to Admin (mostly at least, if a Launch Plan is fetched for instance as part of another
60+
workflow, it will still hit Admin.
61+
62+
:param Text project: The project in which to serialize this task.
63+
:param Text domain: The domain in which to serialize this task.
64+
:param Text name: The name to give this task.
65+
:param Text version: The version in which to serialize this task.
66+
"""
67+
pass
68+
5569
@_abc.abstractproperty
5670
def resource_type(self):
5771
"""

flytekit/common/tasks/task.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ def register(self, project, domain, name, version):
149149
self._id = old_id
150150
raise
151151

152+
@_exception_scopes.system_entry_point
153+
def serialize(self):
154+
"""
155+
:rtype: flyteidl.admin.task_pb2.TaskSpec
156+
"""
157+
return _task_model.TaskSpec(self).to_flyte_idl()
158+
152159
@classmethod
153160
@_exception_scopes.system_entry_point
154161
def fetch(cls, project, domain, name, version):

flytekit/common/workflow.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from flytekit.models.core import workflow as _workflow_models, identifier as _identifier_model
2020
from flytekit.common.exceptions import system as _system_exceptions
2121
from flytekit.common import constants as _constants
22+
from flytekit.models.admin import workflow as _admin_workflow_model
2223

2324

2425
class Output(object):
@@ -286,6 +287,20 @@ def register(self, project, domain, name, version):
286287
self._id = old_id
287288
raise
288289

290+
@_exception_scopes.system_entry_point
291+
def serialize(self):
292+
"""
293+
Serializing a workflow should produce an object similar to what the registration step produces, in preparation
294+
for actual registration to Admin.
295+
296+
:rtype: flyteidl.admin.workflow_pb2.WorkflowSpec
297+
"""
298+
sub_workflows = self.get_sub_workflows()
299+
return _admin_workflow_model.WorkflowSpec(
300+
self,
301+
sub_workflows,
302+
).to_flyte_idl()
303+
289304
@_exception_scopes.system_entry_point
290305
def validate(self):
291306
pass

0 commit comments

Comments
 (0)