Skip to content

Commit 3860d3d

Browse files
authored
Add volume and volume mounts arguments to TrainingClient.create_job API (#2449)
Signed-off-by: Antonin Stefanutti <[email protected]>
1 parent 078ec30 commit 3860d3d

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

sdk/python/kubeflow/training/api/training_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ def create_job(
354354
env_vars: Optional[
355355
Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]]
356356
] = None,
357+
volumes: Optional[List[models.V1Volume]] = None,
358+
volume_mounts: Optional[List[models.V1VolumeMount]] = None,
357359
):
358360
"""Create the Training Job.
359361
Job can be created using one of the following options:
@@ -418,6 +420,8 @@ def create_job(
418420
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
419421
or a kubernetes.client.models.V1EnvFromSource (documented here:
420422
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
423+
volumes: Volume(s) to be attached to the replicas.
424+
volume_mounts: VolumeMount(s) specifying where to mount the volume(s) into the replicas.
421425
422426
Raises:
423427
ValueError: Invalid input parameters.
@@ -448,6 +452,12 @@ def create_job(
448452
f"Job kind must be one of these: {constants.JOB_PARAMETERS.keys()}"
449453
)
450454

455+
if len(volumes or []) != len(volume_mounts or []):
456+
raise ValueError(
457+
"Volumes and VolumeMounts must be the same length: "
458+
f"{len(volumes or [])} vs. {len(volume_mounts or [])}"
459+
)
460+
451461
# If Training function or base image is set, configure Job template.
452462
if job is None and (train_func is not None or base_image is not None):
453463
# Job name must be set to configure Job template.
@@ -496,11 +506,13 @@ def create_job(
496506
args=args,
497507
resources=resources_per_worker,
498508
env_vars=env_vars,
509+
volume_mounts=volume_mounts,
499510
)
500511

501512
# Get Pod template spec using the above container.
502513
pod_template_spec = utils.get_pod_template_spec(
503514
containers=[container_spec],
515+
volumes=volumes,
504516
)
505517

506518
# Configure template for different Jobs.

sdk/python/kubeflow/training/api/training_client_test.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
V1ObjectMeta,
2323
V1PodSpec,
2424
V1PodTemplateSpec,
25+
V1Volume,
26+
V1VolumeMount,
2527
)
2628

2729
TEST_NAME = "test"
@@ -142,6 +144,8 @@ def create_job(
142144
args=None,
143145
num_workers=2,
144146
env_vars=None,
147+
volumes=None,
148+
volume_mounts=None,
145149
):
146150
# Handle env_vars as either a dict or a list
147151
if env_vars:
@@ -158,6 +162,7 @@ def create_job(
158162
command=command,
159163
args=args,
160164
env=env_vars,
165+
volume_mounts=volume_mounts,
161166
)
162167

163168
master = KubeflowOrgV1ReplicaSpec(
@@ -166,7 +171,10 @@ def create_job(
166171
metadata=V1ObjectMeta(
167172
annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}
168173
),
169-
spec=V1PodSpec(containers=[container]),
174+
spec=V1PodSpec(
175+
containers=[container],
176+
volumes=volumes,
177+
),
170178
),
171179
)
172180

@@ -180,7 +188,10 @@ def create_job(
180188
metadata=V1ObjectMeta(
181189
annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}
182190
),
183-
spec=V1PodSpec(containers=[container]),
191+
spec=V1PodSpec(
192+
containers=[container],
193+
volumes=volumes,
194+
),
184195
),
185196
)
186197

@@ -530,6 +541,35 @@ def __init__(self):
530541
env_vars=[V1EnvVar(name="ENV_VAR", value="env_value")], num_workers=2
531542
),
532543
),
544+
(
545+
"create job with a volume and a volume mount",
546+
{
547+
"name": TEST_NAME,
548+
"namespace": TEST_NAME,
549+
"base_image": TEST_IMAGE,
550+
"num_workers": 1,
551+
"volumes": [V1Volume(name="vol")],
552+
"volume_mounts": [V1VolumeMount(name="vol", mount_path="/mnt")],
553+
},
554+
SUCCESS,
555+
create_job(
556+
num_workers=1,
557+
volumes=[V1Volume(name="vol")],
558+
volume_mounts=[V1VolumeMount(name="vol", mount_path="/mnt")],
559+
),
560+
),
561+
(
562+
"invalid number of volume mount",
563+
{
564+
"name": TEST_NAME,
565+
"namespace": TEST_NAME,
566+
"base_image": TEST_IMAGE,
567+
"num_workers": 1,
568+
"volumes": [V1Volume(name="vol")],
569+
},
570+
ValueError,
571+
None,
572+
),
533573
]
534574

535575
test_data_get_job_pods = [

0 commit comments

Comments
 (0)