Skip to content

Commit f329ea1

Browse files
committed
Add some tests (#819)
Signed-off-by: Yee Hing Tong <[email protected]>
1 parent d9f5106 commit f329ea1

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

tests/flytekit/unit/core/test_imperative.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,29 @@ def t2(a: typing.List[int]) -> int:
212212
assert wb() == [3, 6]
213213

214214

215+
def test_imperative_tuples():
216+
@task
217+
def t1() -> (int, str):
218+
return 3, "three"
219+
220+
@task
221+
def t3(a: int, b: str) -> typing.Tuple[int, str]:
222+
return a + 2, "world" + b
223+
224+
wb = ImperativeWorkflow(name="my.workflow.a")
225+
t1_node = wb.add_entity(t1)
226+
t3_node = wb.add_entity(t3, a=t1_node.outputs["o0"], b=t1_node.outputs["o1"])
227+
wb.add_workflow_output("wf0", t3_node.outputs["o0"], python_type=int)
228+
wb.add_workflow_output("wf1", t3_node.outputs["o1"], python_type=str)
229+
res = wb()
230+
assert res == (5, "worldthree")
231+
232+
with pytest.raises(KeyError):
233+
wb = ImperativeWorkflow(name="my.workflow.b")
234+
t1_node = wb.add_entity(t1)
235+
wb.add_entity(t3, a=t1_node.outputs["bad"], b=t1_node.outputs["o2"])
236+
237+
215238
def test_call_normal():
216239
@task
217240
def t1(a: int) -> (int, str):
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import tempfile
2+
from collections import OrderedDict
3+
4+
import mock
5+
6+
from flytekit import ContainerTask, kwtypes
7+
from flytekit.core import context_manager
8+
from flytekit.core.context_manager import Image, ImageConfig
9+
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask, TaskTemplateResolver
10+
from flytekit.core.utils import write_proto_to_file
11+
from flytekit.tools.translator import get_serializable
12+
13+
default_img = Image(name="default", fqn="test", tag="tag")
14+
serialization_settings = context_manager.SerializationSettings(
15+
project="project",
16+
domain="domain",
17+
version="version",
18+
env=None,
19+
image_config=ImageConfig(default_image=default_img, images=[default_img]),
20+
)
21+
22+
23+
class Placeholder(object):
24+
...
25+
26+
27+
def test_resolver_load_task():
28+
# any task is fine, just copied one
29+
square = ContainerTask(
30+
name="square",
31+
input_data_dir="/var/inputs",
32+
output_data_dir="/var/outputs",
33+
inputs=kwtypes(val=int),
34+
outputs=kwtypes(out=int),
35+
image="alpine",
36+
command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
37+
)
38+
39+
resolver = TaskTemplateResolver()
40+
ts = get_serializable(OrderedDict(), serialization_settings, square)
41+
with tempfile.NamedTemporaryFile() as f:
42+
write_proto_to_file(ts.template.to_flyte_idl(), f.name)
43+
# load_task should create an instance of the path to the object given, doesn't need to be a real executor
44+
shim_task = resolver.load_task([f.name, f"{Placeholder.__module__}.Placeholder"])
45+
assert isinstance(shim_task.executor, Placeholder)
46+
assert shim_task.task_template.id.name == "square"
47+
assert shim_task.task_template.interface.inputs["val"] is not None
48+
assert shim_task.task_template.interface.outputs["out"] is not None
49+
50+
51+
@mock.patch("flytekit.core.python_customized_container_task.PythonCustomizedContainerTask.get_config")
52+
@mock.patch("flytekit.core.python_customized_container_task.PythonCustomizedContainerTask.get_custom")
53+
def test_serialize_to_model(mock_custom, mock_config):
54+
mock_custom.return_value = {"a": "custom"}
55+
mock_config.return_value = {"a": "config"}
56+
ct = PythonCustomizedContainerTask(
57+
name="mytest", task_config=None, container_image="someimage", executor_type=Placeholder
58+
)
59+
tt = ct.serialize_to_model(serialization_settings)
60+
assert tt.container.image == "someimage"
61+
assert len(tt.config) == 1
62+
assert tt.id.name == "mytest"
63+
assert len(tt.custom) == 1

0 commit comments

Comments
 (0)