diff --git a/pkg/webhooks/jax/jaxjob_webhook.go b/pkg/webhooks/jax/jaxjob_webhook.go index 12888b3d3c..430286d8c4 100644 --- a/pkg/webhooks/jax/jaxjob_webhook.go +++ b/pkg/webhooks/jax/jaxjob_webhook.go @@ -19,7 +19,6 @@ package jax import ( "context" "fmt" - "slices" "strings" apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" @@ -31,6 +30,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/webhooks/utils" ) var ( @@ -84,41 +84,13 @@ func validateSpec(spec trainingoperator.JAXJobSpec) field.ErrorList { } func validateJAXReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { - var allErrs field.ErrorList - - if rSpecs == nil { - allErrs = append(allErrs, field.Required(jaxReplicaSpecPath, "must be required")) - } - for rType, rSpec := range rSpecs { - rolePath := jaxReplicaSpecPath.Key(string(rType)) - containersPath := rolePath.Child("template").Child("spec").Child("containers") - - // Make sure the replica type is valid. - validRoleTypes := []trainingoperator.ReplicaType{ - trainingoperator.JAXJobReplicaTypeWorker, - } - if !slices.Contains(validRoleTypes, rType) { - allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) - } - - if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { - allErrs = append(allErrs, field.Required(containersPath, "must be specified")) - } - - // Make sure the image is defined in the container - defaultContainerPresent := false - for idx, container := range rSpec.Template.Spec.Containers { - if container.Image == "" { - allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) - } - if container.Name == trainingoperator.JAXJobDefaultContainerName { - defaultContainerPresent = true - } - } - // Make sure there has at least one container named "jax" - if !defaultContainerPresent { - allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.JAXJobDefaultContainerName))) - } + // Make sure the replica type is valid. + validReplicaTypes := []trainingoperator.ReplicaType{ + trainingoperator.JAXJobReplicaTypeWorker, } - return allErrs + + return utils.ValidateReplicaSpecs(rSpecs, + trainingoperator.JAXJobDefaultContainerName, + validReplicaTypes, + jaxReplicaSpecPath) } diff --git a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go index fedc95b5f7..08b56d1c02 100644 --- a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go +++ b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go @@ -19,10 +19,8 @@ package paddlepaddle import ( "context" "fmt" - "slices" "strings" - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" @@ -32,6 +30,8 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/common/util" + "github.com/kubeflow/training-operator/pkg/webhooks/utils" + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" ) var ( @@ -85,42 +85,14 @@ func validatePaddleJob(oldJob, newJob *trainingoperator.PaddleJob) field.ErrorLi } func validateSpec(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { - var allErrs field.ErrorList - - if rSpecs == nil { - allErrs = append(allErrs, field.Required(paddleReplicaSpecPath, "must be required")) + // Make sure the replica type is valid. + validReplicaTypes := []trainingoperator.ReplicaType{ + trainingoperator.PaddleJobReplicaTypeMaster, + trainingoperator.PaddleJobReplicaTypeWorker, } - for rType, rSpec := range rSpecs { - rolePath := paddleReplicaSpecPath.Key(string(rType)) - containersPath := rolePath.Child("template").Child("spec").Child("containers") - - // Make sure the replica type is valid. - validReplicaTypes := []trainingoperator.ReplicaType{ - trainingoperator.PaddleJobReplicaTypeMaster, - trainingoperator.PaddleJobReplicaTypeWorker, - } - if !slices.Contains(validReplicaTypes, rType) { - allErrs = append(allErrs, field.NotSupported(rolePath, rType, validReplicaTypes)) - } - - if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { - allErrs = append(allErrs, field.Required(containersPath, "must be specified")) - } - - // Make sure the image is defined in the container - defaultContainerPresent := false - for idx, container := range rSpec.Template.Spec.Containers { - if container.Image == "" { - allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) - } - if container.Name == trainingoperator.PaddleJobDefaultContainerName { - defaultContainerPresent = true - } - } - // Make sure there has at least one container named "paddle" - if !defaultContainerPresent { - allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %q", trainingoperator.PaddleJobDefaultContainerName))) - } - } - return allErrs + + return utils.ValidateReplicaSpecs(rSpecs, + trainingoperator.PaddleJobDefaultContainerName, + validReplicaTypes, + paddleReplicaSpecPath) } diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go index 2459815935..a8dc8fb30c 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -19,10 +19,8 @@ package pytorch import ( "context" "fmt" - "slices" "strings" - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" @@ -32,6 +30,8 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/common/util" + "github.com/kubeflow/training-operator/pkg/webhooks/utils" + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" ) var ( @@ -107,45 +107,21 @@ func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, fie } func validatePyTorchReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { - var allErrs field.ErrorList - - if rSpecs == nil { - allErrs = append(allErrs, field.Required(pytorchReplicaSpecPath, "must be required")) + // Make sure the replica type is valid. + validReplicaTypes := []trainingoperator.ReplicaType{ + trainingoperator.PyTorchJobReplicaTypeMaster, + trainingoperator.PyTorchJobReplicaTypeWorker, } - for rType, rSpec := range rSpecs { - rolePath := pytorchReplicaSpecPath.Key(string(rType)) - containersPath := rolePath.Child("template").Child("spec").Child("containers") - - // Make sure the replica type is valid. - validRoleTypes := []trainingoperator.ReplicaType{ - trainingoperator.PyTorchJobReplicaTypeMaster, - trainingoperator.PyTorchJobReplicaTypeWorker, - } - if !slices.Contains(validRoleTypes, rType) { - allErrs = append(allErrs, field.NotSupported(rolePath, rType, validRoleTypes)) - } - if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { - allErrs = append(allErrs, field.Required(containersPath, "must be specified")) - } + allErrs := utils.ValidateReplicaSpecs(rSpecs, + trainingoperator.PyTorchJobDefaultContainerName, + validReplicaTypes, + pytorchReplicaSpecPath) - // Make sure the image is defined in the container - defaultContainerPresent := false - for idx, container := range rSpec.Template.Spec.Containers { - if container.Image == "" { - allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) - } - if container.Name == trainingoperator.PyTorchJobDefaultContainerName { - defaultContainerPresent = true - } - } - // Make sure there has at least one container named "pytorch" - if !defaultContainerPresent { - allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.PyTorchJobDefaultContainerName))) - } + for rType, rSpec := range rSpecs { if rType == trainingoperator.PyTorchJobReplicaTypeMaster { if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 { - allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1")) + allErrs = append(allErrs, field.Forbidden(pytorchReplicaSpecPath.Key(string(rType)).Child("replicas"), "must be 1")) } } } diff --git a/pkg/webhooks/tensorflow/tfjob_webhook.go b/pkg/webhooks/tensorflow/tfjob_webhook.go index 95f187f44f..248184b362 100644 --- a/pkg/webhooks/tensorflow/tfjob_webhook.go +++ b/pkg/webhooks/tensorflow/tfjob_webhook.go @@ -21,7 +21,6 @@ import ( "fmt" "strings" - apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" @@ -31,6 +30,8 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/common/util" + "github.com/kubeflow/training-operator/pkg/webhooks/utils" + apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" ) var ( @@ -88,37 +89,16 @@ func validateSpec(spec trainingoperator.TFJobSpec) field.ErrorList { } func validateTFReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { - var allErrs field.ErrorList - - if rSpecs == nil { - allErrs = append(allErrs, field.Required(tfReplicaSpecPath, "must be required")) - } + allErrs := utils.ValidateReplicaSpecs(rSpecs, + trainingoperator.TFJobDefaultContainerName, + nil, + tfReplicaSpecPath) chiefOrMaster := 0 - for rType, rSpec := range rSpecs { - rolePath := tfReplicaSpecPath.Key(string(rType)) - containerPath := rolePath.Child("template").Child("spec").Child("containers") - - if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { - allErrs = append(allErrs, field.Required(containerPath, "must be specified")) - } + for rType := range rSpecs { if trainingoperator.IsChiefOrMaster(rType) { chiefOrMaster++ } - // Make sure the image is defined in the container. - defaultContainerPresent := false - for idx, container := range rSpec.Template.Spec.Containers { - if container.Image == "" { - allErrs = append(allErrs, field.Required(containerPath.Index(idx).Child("image"), "must be required")) - } - if container.Name == trainingoperator.TFJobDefaultContainerName { - defaultContainerPresent = true - } - } - // Make sure there has at least one container named "tensorflow". - if !defaultContainerPresent { - allErrs = append(allErrs, field.Required(containerPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.TFJobDefaultContainerName))) - } } if chiefOrMaster > 1 { allErrs = append(allErrs, field.Forbidden(tfReplicaSpecPath, "must not have more than 1 Chief or Master role")) diff --git a/pkg/webhooks/utils/utils.go b/pkg/webhooks/utils/utils.go new file mode 100644 index 0000000000..10294e361c --- /dev/null +++ b/pkg/webhooks/utils/utils.go @@ -0,0 +1,55 @@ +package utils + +import ( + "fmt" + "slices" + + trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + + "k8s.io/apimachinery/pkg/util/validation/field" +) + +func ValidateReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec, + defaultContainerName string, + validReplicaTypes []trainingoperator.ReplicaType, + replicaSpecPath *field.Path) field.ErrorList { + + var allErrs field.ErrorList + + if rSpecs == nil { + allErrs = append(allErrs, field.Required(replicaSpecPath, "must be required")) + } + + for rType, rSpec := range rSpecs { + rolePath := replicaSpecPath.Key(string(rType)) + containersPath := rolePath.Child("template").Child("spec").Child("containers") + + if len(validReplicaTypes) > 0 { + if !slices.Contains(validReplicaTypes, rType) { + allErrs = append(allErrs, field.NotSupported(rolePath, rType, validReplicaTypes)) + } + } + + if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { + allErrs = append(allErrs, field.Required(containersPath, "must be specified")) + } + + // Make sure the image is defined in the container + defaultContainerPresent := false + for idx, container := range rSpec.Template.Spec.Containers { + if container.Image == "" { + allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) + } + if container.Name == defaultContainerName { + defaultContainerPresent = true + } + } + + if !defaultContainerPresent { + allErrs = append(allErrs, field.Required(containersPath, + fmt.Sprintf("must have at least one container with name %s", defaultContainerName))) + } + } + + return allErrs +} diff --git a/pkg/webhooks/xgboost/xgboostjob_webhook.go b/pkg/webhooks/xgboost/xgboostjob_webhook.go index 5372317487..3a34d928d4 100644 --- a/pkg/webhooks/xgboost/xgboostjob_webhook.go +++ b/pkg/webhooks/xgboost/xgboostjob_webhook.go @@ -19,7 +19,6 @@ package xgboost import ( "context" "fmt" - "slices" "strings" apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" @@ -32,6 +31,7 @@ import ( trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/common/util" + "github.com/kubeflow/training-operator/pkg/webhooks/utils" ) var ( @@ -89,47 +89,23 @@ func validateSpec(spec trainingoperator.XGBoostJobSpec) field.ErrorList { } func validateXGBReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { - var allErrs field.ErrorList - - if rSpecs == nil { - allErrs = append(allErrs, field.Required(xgbReplicaSpecPath, "must be required")) + // Make sure the replica type is valid. + validReplicaTypes := []trainingoperator.ReplicaType{ + trainingoperator.XGBoostJobReplicaTypeMaster, + trainingoperator.XGBoostJobReplicaTypeWorker, } - masterExists := false - for rType, rSpec := range rSpecs { - rolePath := xgbReplicaSpecPath.Key(string(rType)) - containersPath := rolePath.Child("template").Child("spec").Child("containers") - - // Make sure the replica type is valid. - validReplicaTypes := []trainingoperator.ReplicaType{ - trainingoperator.XGBoostJobReplicaTypeMaster, - trainingoperator.XGBoostJobReplicaTypeWorker, - } - if !slices.Contains(validReplicaTypes, rType) { - allErrs = append(allErrs, field.NotSupported(rolePath, rType, validReplicaTypes)) - } - if rSpec == nil || len(rSpec.Template.Spec.Containers) == 0 { - allErrs = append(allErrs, field.Required(containersPath, "must be specified")) - } + allErrs := utils.ValidateReplicaSpecs(rSpecs, + trainingoperator.XGBoostJobDefaultContainerName, + validReplicaTypes, + xgbReplicaSpecPath) - // Make sure the image is defined in the container - defaultContainerPresent := false - for idx, container := range rSpec.Template.Spec.Containers { - if container.Image == "" { - allErrs = append(allErrs, field.Required(containersPath.Index(idx).Child("image"), "must be required")) - } - if container.Name == trainingoperator.XGBoostJobDefaultContainerName { - defaultContainerPresent = true - } - } - // Make sure there has at least one container named "xgboost" - if !defaultContainerPresent { - allErrs = append(allErrs, field.Required(containersPath, fmt.Sprintf("must have at least one container with name %s", trainingoperator.XGBoostJobDefaultContainerName))) - } + masterExists := false + for rType, rSpec := range rSpecs { if rType == trainingoperator.XGBoostJobReplicaTypeMaster { masterExists = true if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 { - allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1")) + allErrs = append(allErrs, field.Forbidden(xgbReplicaSpecPath.Key(string(rType)).Child("replicas"), "must be 1")) } } }