Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] KEP-2170: Implement MPI Plugin for Kubeflow Trainer #2394

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api.v2/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
"type": "boolean"
},
"sshAuthMountPath": {
"description": "Directory where SSH keys are mounted.",
"description": "Directory where SSH keys are mounted. Defaults to /root/.ssh.",
"type": "string"
}
}
Expand Down
9 changes: 5 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.31.0
k8s.io/api v0.31.3
k8s.io/apimachinery v0.31.3
k8s.io/client-go v0.31.3
Expand Down Expand Up @@ -70,10 +71,10 @@ require (
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/term v0.25.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.6.0 // indirect
golang.org/x/tools v0.24.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
Expand Down
18 changes: 10 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
Expand All @@ -135,20 +137,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
4 changes: 3 additions & 1 deletion manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ spec:
Defaults to false.
type: boolean
sshAuthMountPath:
description: Directory where SSH keys are mounted.
description: |-
Directory where SSH keys are mounted.
Defaults to /root/.ssh.
type: string
type: object
numNodes:
Expand Down
2 changes: 2 additions & 0 deletions manifests/v2/base/manager/manager.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ spec:
spec:
containers:
- name: manager
args:
- "--zap-log-level=5"
image: kubeflow/training-operator-v2
env:
- name: MY_POD_NAMESPACE
Expand Down
2 changes: 2 additions & 0 deletions manifests/v2/base/rbac/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ rules:
- apiGroups:
- ""
resources:
- configmaps
- secrets
verbs:
- create
- get
- list
- update
Expand Down
1 change: 1 addition & 0 deletions manifests/v2/base/runtimes/pre-training/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
- torch-distributed.yaml
- mpi-distributed.yaml
45 changes: 45 additions & 0 deletions manifests/v2/base/runtimes/pre-training/mpi-distributed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ClusterTrainingRuntime "mpi-distributed": a placeholder MPI runtime made of
# two replicated Jobs, "launcher" and "trainer-node". Both containers currently
# run busybox and only print a message and sleep; no MPI workload is started.
# TODO (andreyvelich): Change this to DeepSpeed or MLX runtime.
apiVersion: kubeflow.org/v2alpha1
kind: ClusterTrainingRuntime
metadata:
  name: mpi-distributed
  labels:
    training.kubeflow.org/phase: pre-training
spec:
  mlPolicy:
    numNodes: 1
    mpi:
      numProcPerNode: 1
      mpiImplementation: OpenMPI
      # Directory where SSH keys are mounted.
      sshAuthMountPath: /root/.ssh
  template:
    spec:
      # TODO (andreyvelich): Use dependsOn when it is released.
      startupPolicy:
        startupPolicyOrder: InOrder
      replicatedJobs:
        - name: launcher
          template:
            spec:
              template:
                spec:
                  # TODO (andreyvelich): Change the command with mpirun.
                  containers:
                    - name: launcher
                      image: busybox
                      command:
                        - /bin/sh
                        - -c
                        # The message states the same duration as the sleep.
                        - "echo 'launcher runs for 100 seconds' && sleep 100"
        - name: trainer-node
          template:
            spec:
              template:
                spec:
                  containers:
                    - name: trainer
                      image: busybox
                      command:
                        - /bin/sh
                        - -c
                        # The trainer identifies itself as the trainer, not the launcher.
                        - "echo 'trainer runs for 100 seconds' && sleep 100"
4 changes: 3 additions & 1 deletion manifests/v2/overlays/standalone/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ resources:
- https://github.com/kubernetes-sigs/jobset/releases/download/v0.6.0/manifests.yaml
images:
- name: kubeflow/training-operator-v2
newTag: latest
# TODO (andreyvelich): Change it back to the kubeflow/training-operator-v2 image with `newTag: latest`.
newName: training-operator-local
newTag: v1
secretGenerator:
- name: training-operator-v2-webhook-cert
namespace: kubeflow-system
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,11 @@ type MPIMLPolicySource struct {

// Implementation name for the MPI to create the appropriate hostfile.
// Defaults to OpenMPI.
MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"`
MPIImplementation MPIImplementation `json:"mpiImplementation,omitempty"`

// Directory where SSH keys are mounted.
SSHAuthMountPath *string `json:"sshAuthMountPath,omitempty"`
// Defaults to /root/.ssh.
SSHAuthMountPath string `json:"sshAuthMountPath,omitempty"`

// Whether to run training process on the launcher Job.
// Defaults to false.
Expand Down
10 changes: 0 additions & 10 deletions pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 55 additions & 15 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ const (
// PodGroupKind is the Kind name for the PodGroup.
PodGroupKind string = "PodGroup"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"

// Distributed envs for torchrun.
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
// TorchEnvNumNodes is the env name for the number of training nodes.
Expand All @@ -52,25 +72,45 @@ const (
// TorchEnvMasterPort is the env name for the master node port.
TorchEnvMasterPort string = "PET_MASTER_PORT"

// TrainJobJobsCreationSucceededMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationSucceeded"} condition.
TrainJobJobsCreationSucceededMessage = "Succeeded to create Jobs"
// JobLauncher is the Job name for the launcher.
JobLauncher string = "launcher"

// TrainJobJobsBuildFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsBuildFailed"} condition.
TrainJobJobsBuildFailedMessage = "Failed to build Jobs"
// ContainerLauncher is the container name for the launcher.
ContainerLauncher string = "launcher"

// TrainJobJobsCreationFailedMessage is status condition message for the
// {"type": "Created", "status": "True", "reason": "JobsCreationFailed"} condition.
TrainJobJobsCreationFailedMessage = "Failed to create Jobs"
// MPISSHAuthSecretSuffix is the name suffix for Secret with MPI SSH keys.
MPISSHAuthSecretSuffix string = "-mpi-ssh-auth"

// TrainJobSuspendedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Suspended"} condition.
TrainJobSuspendedMessage = "TrainJob is suspended"
// MPISSHAuthVolumeName is the volume name for Secret with MPI SSH keys.
MPISSHAuthVolumeName string = "mpi-ssh-auth"

// TrainJobResumedMessage is status condition message for the
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
TrainJobResumedMessage = "TrainJob is resumed"
// MPISSHPrivateKeyFile is the file name for the private key.
MPISSHPrivateKeyFile string = "id_rsa"

// MPISSHPublicKey is the value in Secret data for the public key.
MPISSHPublicKey string = "ssh-publickey"

// MPISSHPublicKeyFile is the file name for the public key.
MPISSHPublicKeyFile string = MPISSHPrivateKeyFile + ".pub"

// MPISSHAuthorizedKeys is the file name for authorized keys.
MPISSHAuthorizedKeys string = "authorized_keys"

// MPIHostfileDir is the directory for the MPI hostfile.
MPIHostfileDir string = "/etc/mpi"

// MPIHostfileName is the file name for the MPI hostfile.
MPIHostfileName string = "hostfile"

// MPIHostfileConfigMapSuffix is the name suffix for ConfigMap with MPI hostfile.
MPIHostfileConfigMapSuffix string = "-mpi-hostfile"

// MPIHostfileVolumeName is the volume name for ConfigMap with MPI hostfile.
MPIHostfileVolumeName string = "mpi-hostfile"

// Distributed envs for mpirun.
// Values for OpenMPI implementation.
OpenMPIEnvHostFileLocation string = "OMPI_MCA_orte_default_hostfile"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, runtimeJobTe
return nil, err
}
if obj != nil {
objs = append(objs, obj)
objs = append(objs, obj...)
}
}
return objs, nil
Expand Down
12 changes: 6 additions & 6 deletions pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type Plugin interface {
Name() string
}

// CustomValidationPlugin is a Plugin that additionally validates TrainJob objects.
type CustomValidationPlugin interface {
Plugin
// Validate takes the old and new TrainJob and returns admission warnings
// together with a list of field errors.
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type WatchExtensionPlugin interface {
Plugin
ReconcilerBuilders() []runtime.ReconcilerBuilder
Expand All @@ -47,14 +52,9 @@ type EnforceMLPolicyPlugin interface {
EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJob) error
}

// CustomValidationPlugin is a Plugin that additionally validates TrainJob objects.
type CustomValidationPlugin interface {
Plugin
// Validate takes the old and new TrainJob and returns admission warnings
// together with a list of field errors.
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

// ComponentBuilderPlugin is a Plugin that builds client.Object values for a
// TrainJob from a runtime Job template and the runtime Info.
type ComponentBuilderPlugin interface {
Plugin
// Build receives the context, the runtime Job template, the runtime Info and
// the TrainJob, and returns the built object(s) or an error.
// NOTE(review): two Build signatures are listed below; the first returns a
// single client.Object and the second a slice of them. This looks like the
// before/after pair of a diff; an interface can declare Build only once.
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) (client.Object, error)
Build(ctx context.Context, runtimeJobTemplate client.Object, info *runtime.Info, trainJob *kubeflowv2.TrainJob) ([]client.Object, error)
}

type TerminalConditionPlugin interface {
Expand Down
Loading
Loading