chore(backend): Migrated the Persistence Agent controller to use controller-runtime #11582

Draft: wants to merge 15 commits into master
51 changes: 30 additions & 21 deletions backend/src/agent/persistence/main.go
@@ -16,6 +16,7 @@ package main

import (
"flag"
"os"
"time"

"github.com/kubeflow/pipelines/backend/src/agent/persistence/client"
@@ -26,6 +27,7 @@ import (
log "github.com/sirupsen/logrus"
_ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
"k8s.io/client-go/tools/clientcmd"
ctrl "sigs.k8s.io/controller-runtime"
)

var (
@@ -72,9 +74,16 @@ const (
DefaultSATokenRefresherIntervalInSecs = 60 * 60 // 1 Hour in seconds
)

var (
persistenceAgentFlags = flag.NewFlagSet("persistence_agent", flag.ContinueOnError)
)

func main() {
flag.Parse()
if err := persistenceAgentFlags.Parse(os.Args[1:]); err != nil {
log.Fatalf("Failed to parse flags: %v", err)
}

flag.Parse()
// set up signals so we handle the first shutdown signal gracefully
stopCh := signals.SetupSignalHandler()

@@ -131,44 +140,44 @@ func main() {
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
}

controller, err := NewPersistenceAgent(
mgr, err := NewPersistenceAgent(
swfInformerFactory,
execInformer,
pipelineClient,
util.NewRealTime())
if err != nil {
log.Fatalf("Failed to instantiate the controller: %v", err)
}

go swfInformerFactory.Start(stopCh)
go execInformer.InformerFactoryStart(stopCh)

if err = controller.Run(numWorker, stopCh); err != nil {
log.Fatalf("Error running controller: %s", err.Error())
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
log.Fatalf("problem running manager: %v", err)
}

}

func init() {
flag.StringVar(&kubeconfig, kubeconfigFlagName, "", "Path to a kubeconfig. Only required if out-of-cluster.")
flag.StringVar(&masterURL, masterFlagName, "", "The address of the Kubernetes API server. Overrides any value in kubeconfig. Only required if out-of-cluster.")
flag.StringVar(&logLevel, logLevelFlagName, "", "Defines the log level for the application.")
flag.DurationVar(&initializeTimeout, initializationTimeoutFlagName, 2*time.Minute, "Duration to wait for initialization of the ML pipeline API server.")
flag.DurationVar(&timeout, timeoutFlagName, 1*time.Minute, "Duration to wait for calls to complete.")
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
persistenceAgentFlags.StringVar(&kubeconfig, kubeconfigFlagName, "", "Path to a kubeconfig. Only required if out-of-cluster.")
persistenceAgentFlags.StringVar(&masterURL, masterFlagName, "", "The address of the Kubernetes API server. Overrides any value in kubeconfig. Only required if out-of-cluster.")
persistenceAgentFlags.StringVar(&logLevel, logLevelFlagName, "", "Defines the log level for the application.")
persistenceAgentFlags.DurationVar(&initializeTimeout, initializationTimeoutFlagName, 2*time.Minute, "Duration to wait for initialization of the ML pipeline API server.")
persistenceAgentFlags.DurationVar(&timeout, timeoutFlagName, 1*time.Minute, "Duration to wait for calls to complete.")
persistenceAgentFlags.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
persistenceAgentFlags.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
persistenceAgentFlags.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
persistenceAgentFlags.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
"/apis/v1beta1", "The base path for the ML pipeline API server.")
flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")
flag.Int64Var(&ttlSecondsAfterWorkflowFinish, ttlSecondsAfterWorkflowFinishFlagName, 604800 /* 7 days */, "The TTL for Argo workflow to persist after workflow finish.")
flag.IntVar(&numWorker, numWorkerName, 2, "Number of worker for sync job.")
persistenceAgentFlags.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")
persistenceAgentFlags.Int64Var(&ttlSecondsAfterWorkflowFinish, ttlSecondsAfterWorkflowFinishFlagName, 604800 /* 7 days */, "The TTL for Argo workflow to persist after workflow finish.")
persistenceAgentFlags.IntVar(&numWorker, numWorkerName, 2, "Number of worker for sync job.")
// Use default value of client QPS (5) & burst (10) defined in
// k8s.io/client-go/rest/config.go#RESTClientFor
flag.Float64Var(&clientQPS, clientQPSFlagName, 5, "The maximum QPS to the master from this client.")
flag.IntVar(&clientBurst, clientBurstFlagName, 10, "Maximum burst for throttle from this client.")
flag.StringVar(&executionType, executionTypeFlagName, "Workflow", "Custom Resource's name of the backend Orchestration Engine")
persistenceAgentFlags.Float64Var(&clientQPS, clientQPSFlagName, 5, "The maximum QPS to the master from this client.")
persistenceAgentFlags.IntVar(&clientBurst, clientBurstFlagName, 10, "Maximum burst for throttle from this client.")
persistenceAgentFlags.StringVar(&executionType, executionTypeFlagName, "Workflow", "Custom Resource's name of the backend Orchestration Engine")
// TODO use viper/config file instead. Sync `saTokenRefreshIntervalFlagName` with the value from manifest file by using ENV var.
flag.Int64Var(&saTokenRefreshIntervalInSecs, saTokenRefreshIntervalFlagName, DefaultSATokenRefresherIntervalInSecs, "Persistence agent service account token read interval in seconds. "+
persistenceAgentFlags.Int64Var(&saTokenRefreshIntervalInSecs, saTokenRefreshIntervalFlagName, DefaultSATokenRefresherIntervalInSecs, "Persistence agent service account token read interval in seconds. "+
"Defines how often `/var/run/secrets/kubeflow/tokens/kubeflow-persistent_agent-api-token` to be read")

}
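
For context on the flag changes above: moving the definitions from the global flag.CommandLine onto a dedicated flag.FlagSet isolates the agent's flags from anything registered at init time by imported packages, and flag.ContinueOnError makes Parse report failure through its return value rather than exiting the process itself. A minimal, self-contained sketch of that pattern (all names here are illustrative, not the PR's):

    package main

    import (
        "flag"
        "fmt"
        "os"
    )

    // A dedicated FlagSet keeps these definitions off the global
    // flag.CommandLine, so flags registered by imported packages
    // cannot collide with ours.
    var agentFlags = flag.NewFlagSet("agent", flag.ContinueOnError)

    var logLevel = agentFlags.String("logLevel", "info", "log level for the application")

    func main() {
        // With flag.ContinueOnError, Parse reports failure through its
        // return value instead of calling os.Exit itself, so the caller
        // decides how to fail.
        if err := agentFlags.Parse(os.Args[1:]); err != nil {
            fmt.Fprintf(os.Stderr, "failed to parse flags: %v\n", err)
            os.Exit(1)
        }
        fmt.Println("logLevel =", *logLevel)
    }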
170 changes: 111 additions & 59 deletions backend/src/agent/persistence/persistence_agent.go
@@ -15,106 +15,158 @@
package main

import (
"context"
"fmt"
"time"

workflowregister "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow"
swfScheme "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/scheme"
swfinformers "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/informers/externalversions"

"github.com/kubeflow/pipelines/backend/src/agent/persistence/client"
"github.com/kubeflow/pipelines/backend/src/agent/persistence/worker"
"github.com/kubeflow/pipelines/backend/src/common/util"
swfregister "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow"
swfScheme "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/clientset/versioned/scheme"
swfinformers "github.com/kubeflow/pipelines/backend/src/crd/pkg/client/informers/externalversions"
log "github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apimachinery/pkg/util/wait"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/scheme"
_ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
"k8s.io/client-go/tools/cache"
ctrl "sigs.k8s.io/controller-runtime"
k8sclient "sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
)

// PersistenceAgent is an agent to persist resources to a database.
type PersistenceAgent struct {
swfClient *client.ScheduledWorkflowClient
workflowClient *client.WorkflowClient
swfWorker *worker.PersistenceWorker
workflowWorker *worker.PersistenceWorker
k8sclient.Client
Scheme *runtime.Scheme
logger *log.Logger
swfsaver *worker.ScheduledWorkflowSaver
wfsaver *worker.WorkflowSaver
}

// NewPersistenceAgent returns a new persistence agent.
// NewPersistenceAgent creates a new controller-runtime Manager, sets up the scheme, initializes saver instances,
// creates a reconciler with the persistence logic, and registers the controller with the manager.
// It returns the configured manager.
func NewPersistenceAgent(
swfInformerFactory swfinformers.SharedInformerFactory,
execInformer util.ExecutionInformer,
pipelineClient *client.PipelineClient,
time util.TimeInterface,
) (*PersistenceAgent, error) {
) (ctrl.Manager, error) {
// obtain references to shared informers
swfInformer := swfInformerFactory.Scheduledworkflow().V1beta1().ScheduledWorkflows()

// Add controller types to the default Kubernetes Scheme so Events can be
// logged for controller types.
swfScheme.AddToScheme(scheme.Scheme)
sch := scheme.Scheme
if err := swfScheme.AddToScheme(sch); err != nil {
return nil, err
}

swfClient := client.NewScheduledWorkflowClient(swfInformer)
workflowClient := client.NewWorkflowClient(execInformer)

swfWorker, err := worker.NewPersistenceWorker(time, swfregister.Kind, swfInformer.Informer(), true,
worker.NewScheduledWorkflowSaver(swfClient, pipelineClient))
mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{
Scheme: sch,
})
if err != nil {
return nil, err
}

workflowWorker, err := worker.NewPersistenceWorker(time, workflowregister.WorkflowKind,
execInformer, true,
worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish))
// Create saver instances.
workflowSaver := worker.NewWorkflowSaver(workflowClient, pipelineClient, ttlSecondsAfterWorkflowFinish)
scheduledWorkflowSaver := worker.NewScheduledWorkflowSaver(swfClient, pipelineClient)

// Initialize the reconciler with saver logic.
reconciler := &PersistenceAgent{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
wfsaver: workflowSaver,
swfsaver: scheduledWorkflowSaver,
logger: log.New(),
}

// Set up the controller to watch both ScheduledWorkflow and Workflow resources.
_, err = builder.ControllerManagedBy(mgr).
For(&util.ScheduledWorkflow{}).
Owns(&Workflow{}).
Build(reconciler)
if err != nil {
return nil, err
}
return mgr, nil
}

agent := &PersistenceAgent{
swfClient: swfClient,
workflowClient: workflowClient,
swfWorker: swfWorker,
workflowWorker: workflowWorker,
// Reconcile persists the ScheduledWorkflow or Workflow named in the request,
// distinguishing transient save errors (requeued) from permanent ones (dropped).
func (r *PersistenceAgent) Reconcile(
ctx context.Context,
req reconcile.Request,
) (reconcile.Result, error) {
nowEpoch := time.Now().Unix()
// Attempt to fetch a ScheduledWorkflow first.
swf := &util.ScheduledWorkflow{}
err := r.Get(ctx, req.NamespacedName, swf)
if err == nil {
r.logger.Info("Reconciling ScheduledWorkflow", "namespace", swf.Namespace, "name", swf.Name)
key := fmt.Sprintf("%s/%s", swf.Namespace, swf.Name)
if err := r.swfsaver.Save(key, swf.Namespace, swf.Name, nowEpoch); err != nil {
if util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT) {
r.logger.Error(err, "Transient error saving ScheduledWorkflow; will retry", "key", key)
return reconcile.Result{RequeueAfter: 1 * time.Second}, err
}
// For permanent errors, log and do not requeue.
r.logger.Error(err, "Permanent error saving ScheduledWorkflow", "key", key)
return reconcile.Result{}, nil
}
// Successfully saved ScheduledWorkflow; nothing more to do.
return reconcile.Result{}, nil
}

log.Info("Setting up event handlers")
// If not found, try to fetch a Workflow.
wf := &Workflow{}
err = r.Get(ctx, req.NamespacedName, wf)
if err == nil {
r.logger.Info("Reconciling Workflow", "namespace", wf.Namespace, "name", wf.Name)
key := fmt.Sprintf("%s/%s", wf.Namespace, wf.Name)
if err := r.wfsaver.Save(key, wf.Namespace, wf.Name, nowEpoch); err != nil {
if util.HasCustomCode(err, util.CUSTOM_CODE_TRANSIENT) {
r.logger.Error(err, "Transient error saving Workflow; will retry", "key", key)
return reconcile.Result{RequeueAfter: 1 * time.Second}, err
}
r.logger.Error(err, "Permanent error saving Workflow", "key", key)
return reconcile.Result{}, nil
}
return reconcile.Result{}, nil
}

return agent, nil
r.logger.Info("Object not found; may have been deleted", "key", req.NamespacedName)
return reconcile.Result{}, nil
}

// Run will set up the event handlers for types we are interested in, as well
// as syncing informer caches and starting workers. It will block until stopCh
// is closed, at which point it will shutdown the workqueue and wait for
// workers to finish processing their current work items.
func (p *PersistenceAgent) Run(threadiness int, stopCh <-chan struct{}) error {
defer runtime.HandleCrash()
defer p.swfWorker.Shutdown()
defer p.workflowWorker.Shutdown()

// Start the informer factories to begin populating the informer caches
log.Info("Starting The persistence agent")

// Wait for the caches to be synced before starting workers
log.Info("Waiting for informer caches to sync")

if ok := cache.WaitForCacheSync(stopCh,
p.workflowClient.HasSynced(),
p.swfClient.HasSynced()); !ok {
return fmt.Errorf("Failed to wait for caches to sync")
}
type Workflow struct {
util.Workflow
}

// Launch multiple workers to process ScheduledWorkflows
log.Info("Starting workers")
for i := 0; i < threadiness; i++ {
go wait.Until(p.swfWorker.RunWorker, time.Second, stopCh)
go wait.Until(p.workflowWorker.RunWorker, time.Second, stopCh)
// Override SetAnnotations on our Workflow type.
// This method will be used instead of util.Workflow's version.
func (w *Workflow) SetAnnotations(val map[string]string) {
if w.Annotations == nil {
w.Annotations = make(map[string]string)
}
log.Info("Started workers")

log.Info("Wait for shut down")
<-stopCh
log.Info("Shutting down workers")
for key, value := range val {
w.Annotations[key] = value
}
}

return nil
// Override SetLabels on our Workflow type.
// This method will be used instead of util.Workflow's version.
func (w *Workflow) SetLabels(val map[string]string) {
if w.Labels == nil {
w.Labels = make(map[string]string)
}
for key, value := range val {
w.Labels[key] = value
}
}

func (w *Workflow) SetOwnerReferences(refs []metav1.OwnerReference) {
w.ObjectMeta.OwnerReferences = refs
}
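
A note on the retry semantics in the Reconcile method above: in controller-runtime, returning a non-nil error causes the request to be requeued with rate-limited backoff, and a RequeueAfter set alongside a non-nil error is ignored, so the transient-error branches effectively rely on the error backoff rather than the one-second delay. A minimal sketch of the reconcile contract this PR adopts; the ConfigMap stand-in and all names are illustrative, not the PR's actual types:

    package main

    import (
        "context"
        "time"

        corev1 "k8s.io/api/core/v1"
        apierrors "k8s.io/apimachinery/pkg/api/errors"
        ctrl "sigs.k8s.io/controller-runtime"
        "sigs.k8s.io/controller-runtime/pkg/client"
    )

    // minimalReconciler sketches the contract: Reconcile receives only a
    // namespace/name key, re-fetches the object, and expresses its retry
    // policy through the returned Result and error.
    type minimalReconciler struct {
        client.Client
    }

    func (r *minimalReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
        obj := &corev1.ConfigMap{} // illustrative stand-in for ScheduledWorkflow/Workflow
        if err := r.Get(ctx, req.NamespacedName, obj); err != nil {
            if apierrors.IsNotFound(err) {
                // Deleted between enqueue and reconcile: nothing left to do.
                return ctrl.Result{}, nil
            }
            // A non-nil error requeues the key with exponential backoff;
            // any RequeueAfter set alongside it is ignored.
            return ctrl.Result{}, err
        }
        // A transient condition can instead request a clean retry without
        // counting as an error:
        return ctrl.Result{RequeueAfter: 30 * time.Second}, nil
    }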
3 changes: 2 additions & 1 deletion backend/src/crd/controller/scheduledworkflow/main.go
@@ -56,6 +56,7 @@ const (
)

func main() {
initFlags()
flag.Parse()

// set up signals so we handle the first shutdown signal gracefully
@@ -148,7 +149,7 @@ func initEnv() {
viper.AllowEmptyEnv(true)
}

func init() {
func initFlags() {
initEnv()

flag.StringVar(&logLevel, "logLevel", "", "Defines the log level for the application.")
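
The init-to-initFlags rename above is behavioral, not cosmetic: Go runs every init() automatically at package load, so flag registration previously happened implicitly; an explicitly invoked function makes the ordering visible and controllable from main (for example, relative to initEnv()). A toy illustration of the explicit pattern, with illustrative names:

    package main

    import (
        "flag"
        "fmt"
    )

    var logLevel string

    // initFlags runs only when called, so main controls exactly when flags
    // are registered relative to other setup (e.g. reading env vars).
    func initFlags() {
        flag.StringVar(&logLevel, "logLevel", "", "Defines the log level for the application.")
    }

    func main() {
        initFlags() // explicit: must precede flag.Parse()
        flag.Parse()
        fmt.Println("logLevel =", logLevel)
    }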
2 changes: 2 additions & 0 deletions backend/third_party_licenses/persistence_agent.csv
@@ -113,6 +113,7 @@
gopkg.in/inf.v0,https://github.com/go-inf/inf/blob/v0.9.1/LICENSE,BSD-3-Clause
gopkg.in/yaml.v2,https://github.com/go-yaml/yaml/blob/v2.4.0/LICENSE,Apache-2.0
gopkg.in/yaml.v3,https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE,MIT
k8s.io/api,https://github.com/kubernetes/api/blob/v0.30.1/LICENSE,Apache-2.0
k8s.io/apiextensions-apiserver/pkg/apis/apiextensions,https://github.com/kubernetes/apiextensions-apiserver/blob/v0.30.1/LICENSE,Apache-2.0
k8s.io/apimachinery/pkg,https://github.com/kubernetes/apimachinery/blob/v0.30.1/LICENSE,Apache-2.0
k8s.io/apimachinery/third_party/forked/golang,https://github.com/kubernetes/apimachinery/blob/v0.30.1/third_party/forked/golang/LICENSE,BSD-3-Clause
k8s.io/client-go,https://github.com/kubernetes/client-go/blob/v0.30.1/LICENSE,Apache-2.0
@@ -123,6 +124,7 @@
k8s.io/kube-openapi/pkg/validation/spec,https://github.com/kubernetes/kube-opena
k8s.io/utils,https://github.com/kubernetes/utils/blob/3b25d923346b/LICENSE,Apache-2.0
k8s.io/utils/internal/third_party/forked/golang/net,https://github.com/kubernetes/utils/blob/3b25d923346b/internal/third_party/forked/golang/LICENSE,BSD-3-Clause
knative.dev/pkg,https://github.com/knative/pkg/blob/56bfe0dd9626/LICENSE,Apache-2.0
sigs.k8s.io/controller-runtime,https://github.com/kubernetes-sigs/controller-runtime/blob/v0.18.6/LICENSE,Apache-2.0
sigs.k8s.io/json,https://github.com/kubernetes-sigs/json/blob/bc3834ca7abd/LICENSE,Apache-2.0
sigs.k8s.io/structured-merge-diff/v4,https://github.com/kubernetes-sigs/structured-merge-diff/blob/v4.4.1/LICENSE,Apache-2.0
sigs.k8s.io/yaml,https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/LICENSE,Apache-2.0