-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpipeline.py
71 lines (65 loc) · 1.78 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from kfp import dsl
from kfp.aws import use_aws_secret
from helpers.kfp_auth import KfpAuth
@dsl.pipeline(name="CIFAR Pytorch", description="hello world")
def cifar_pipeline(
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    momentum: float,
    bucket: str,
    path: str,
):
    """Run a CIFAR training step followed by an evaluation step.

    Both steps run the ``hermesribeiro/cifar:latest`` image with
    ``use_aws_secret()`` applied and an ``Always`` image pull policy.

    Args:
        num_epochs: Passed to ``train.py`` as ``-n``.
        batch_size: Passed to ``train.py`` as ``-b``.
        learning_rate: Passed to ``train.py`` as ``-l``.
        momentum: Passed to ``train.py`` as ``-m``.
        bucket: Passed to both ``train.py`` and ``eval.py`` as ``-u``;
            presumably the S3 bucket holding the model (TODO confirm).
        path: Passed to both ``train.py`` and ``eval.py`` as ``-p``;
            presumably the object key of the saved model (TODO confirm).
    """
    # Both steps use the same image; define it once so they cannot drift.
    image = "hermesribeiro/cifar:latest"

    train_op = (
        dsl.ContainerOp(
            name="Model Train",
            image=image,
            command=["python", "train.py"],
            arguments=[
                "-n", num_epochs,
                "-b", batch_size,
                "-l", learning_rate,
                "-m", momentum,
                "-u", bucket,
                "-p", path,
            ],
        )
        .apply(use_aws_secret())
        .set_image_pull_policy("Always")
        .set_memory_request('2G')
        .set_cpu_request('4')
    )

    # The op is added to the pipeline when it is constructed, so no local
    # binding is needed.
    (
        dsl.ContainerOp(
            name="Model Eval",
            image=image,
            command=["python", "eval.py"],
            arguments=["-u", bucket, "-p", path],
        )
        .apply(use_aws_secret())
        .set_image_pull_policy("Always")
        # Eval receives only pipeline parameters, not train_op outputs, so
        # KFP sees no data dependency; the ordering must be declared
        # explicitly. Presumably eval.py reads the model train.py writes to
        # bucket/path.
        .after(train_op)
    )
if __name__ == "__main__":
    # Parameter values submitted with the run when this file is executed
    # directly; keys match cifar_pipeline's parameter names.
    run_arguments = {
        "num_epochs": 2,
        "batch_size": 4,
        "learning_rate": 0.001,
        "momentum": 0.0,
        "bucket": "hermes-freestyle",
        "path": "cifar/cifar_net.pth",
    }
    # Obtain an authenticated KFP client, then compile and submit the pipeline.
    kfp_client = KfpAuth().client()
    kfp_client.create_run_from_pipeline_func(
        cifar_pipeline,
        arguments=run_arguments,
    )