diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 5ce4d8983..3ef7eb1ab 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -20,21 +20,19 @@ import ( "context" "crypto/tls" "errors" - goflag "flag" + "flag" "fmt" "net/http" "net/http/pprof" "os" "regexp" "runtime" - "strconv" - "strings" "sync/atomic" "time" "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" - flag "github.com/spf13/pflag" + "github.com/spf13/pflag" uberzap "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" @@ -42,7 +40,6 @@ import ( "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/rest" ctrl "sigs.k8s.io/controller-runtime" @@ -79,7 +76,6 @@ import ( testfilter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/test/filter" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" "sigs.k8s.io/gateway-api-inference-extension/version" ) @@ -112,48 +108,6 @@ var flowControlConfig = flowcontrol.Config{ } var ( - grpcPort = flag.Int("grpc-port", runserver.DefaultGrpcPort, "The gRPC port used for communicating with Envoy proxy") - grpcHealthPort = flag.Int("grpc-health-port", runserver.DefaultGrpcHealthPort, "The port used for gRPC liveness and readiness probes") - metricsPort = flag.Int("metrics-port", runserver.DefaultMetricsPort, "The metrics port") - metricsEndpointAuth = flag.Bool("metrics-endpoint-auth", true, "Enables authentication and authorization of the metrics endpoint") - enablePprof = flag.Bool("enable-pprof", runserver.DefaultEnablePprof, "Enables pprof handlers. Defaults to true. 
Set to false to disable pprof handlers.") - poolName = flag.String("pool-name", runserver.DefaultPoolName, "Name of the InferencePool this Endpoint Picker is associated with.") - poolGroup = flag.String("pool-group", runserver.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.") - poolNamespace = flag.String("pool-namespace", "", "Namespace of the InferencePool this Endpoint Picker is associated with.") - endpointSelector = flag.String("endpoint-selector", "", "selector to filter model server pods on, only key=value paris is supported. Format: a comma-separated list of key value paris, e.g., 'app=vllm-llama3-8b-instruct,env=prod'.") - endpointTargetPorts = flag.String("endpoint-target-ports", "", "target ports of model server pods. Format: a comma-separated list of numbers, e.g., '3000,3001,3002'") - logVerbosity = flag.Int("v", logging.DEFAULT, "number for the log level verbosity") - secureServing = flag.Bool("secure-serving", runserver.DefaultSecureServing, "Enables secure serving. Defaults to true.") - healthChecking = flag.Bool("health-checking", runserver.DefaultHealthChecking, "Enables health checking") - certPath = flag.String("cert-path", runserver.DefaultCertPath, "The path to the certificate for secure serving. The certificate and private key files "+ - "are assumed to be named tls.crt and tls.key, respectively. 
If not set, and secureServing is enabled, "+ - "then a self-signed certificate is used.") - enableCertReload = flag.Bool("enable-cert-reload", runserver.DefaultCertReload, "Enables certificate reloading of the certificates specified in --cert-path") - // metric flags - totalQueuedRequestsMetric = flag.String("total-queued-requests-metric", runserver.DefaultTotalQueuedRequestsMetric, "Prometheus metric for the number of queued requests.") - totalRunningRequestsMetric = flag.String("total-running-requests-metric", runserver.DefaultTotalRunningRequestsMetric, "Prometheus metric for the number of running requests.") - kvCacheUsagePercentageMetric = flag.String("kv-cache-usage-percentage-metric", runserver.DefaultKvCacheUsagePercentageMetric, "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).") - // LoRA metrics - loraInfoMetric = flag.String("lora-info-metric", runserver.DefaultLoraInfoMetric, "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") - // Cache info metrics - cacheInfoMetric = flag.String("cache-info-metric", runserver.DefaultCacheInfoMetric, "Prometheus metric for the cache info metrics.") - // metrics related flags - refreshMetricsInterval = flag.Duration("refresh-metrics-interval", runserver.DefaultRefreshMetricsInterval, "interval to refresh metrics") - refreshPrometheusMetricsInterval = flag.Duration("refresh-prometheus-metrics-interval", runserver.DefaultRefreshPrometheusMetricsInterval, "interval to flush prometheus metrics") - metricsStalenessThreshold = flag.Duration("metrics-staleness-threshold", runserver.DefaultMetricsStalenessThreshold, "Duration after which metrics are considered stale. 
This is used to determine if a pod's metrics are fresh enough.") - // configuration flags - configFile = flag.String("config-file", runserver.DefaultConfigFile, "The path to the configuration file") - configText = flag.String("config-text", runserver.DefaultConfigText, "The configuration specified as text, in lieu of a file") - - modelServerMetricsPort = flag.Int("model-server-metrics-port", 0, "[DEPRECATED] Port to scrape metrics from pods. "+ - "Default value will be set to the InferencePool.Spec.TargetPorts[0].Number if not set."+ - "This option will be removed in the next release.") - modelServerMetricsPath = flag.String("model-server-metrics-path", "/metrics", "Path to scrape metrics from pods") - modelServerMetricsScheme = flag.String("model-server-metrics-scheme", "http", "Scheme to scrape metrics from pods") - modelServerMetricsHttpsInsecureSkipVerify = flag.Bool("model-server-metrics-https-insecure-skip-verify", true, "When using 'https' scheme for 'model-server-metrics-scheme', configure 'InsecureSkipVerify' (default to true)") - haEnableLeaderElection = flag.Bool("ha-enable-leader-election", false, "Enables leader election for high availability. When enabled, readiness probes will only pass on the leader.") - tracing = flag.Bool("tracing", true, "Enables emitting traces") - setupLog = ctrl.Log.WithName("setup") ) @@ -198,18 +152,27 @@ func (r *Runner) WithCustomCollectors(collectors ...prometheus.Collector) *Runne } func (r *Runner) Run(ctx context.Context) error { - opts := zap.Options{ + opts := runserver.NewOptions() + zapopts := zap.Options{ Development: true, } - gfs := goflag.NewFlagSet("zap", goflag.ExitOnError) - opts.BindFlags(gfs) // zap expects a standard Go FlagSet and pflag.FlagSet is not compatible. - flag.CommandLine.AddGoFlagSet(gfs) - flag.Parse() - initLogging(&opts) + gfs := flag.NewFlagSet("zap", flag.ExitOnError) + zapopts.BindFlags(gfs) // zap expects a standard Go FlagSet and pflag.FlagSet is not compatible. 
+ pflag.CommandLine.AddGoFlagSet(gfs) + opts.AddFlags(pflag.CommandLine) + pflag.Parse() - r.deprecatedFlagsHandler(setupLog) + if err := opts.Complete(); err != nil { + return err + } + if err := opts.Validate(); err != nil { + setupLog.Error(err, "Failed to validate flags") + return err + } - if *tracing { + initLogging(&zapopts, opts) + + if opts.Tracing { err := common.InitTracing(ctx, setupLog) if err != nil { return err @@ -218,12 +181,6 @@ func (r *Runner) Run(ctx context.Context) error { setupLog.Info(r.eppExecutableName+" build", "commit-sha", version.CommitSHA, "build-ref", version.BuildRef) - // Validate flags - if err := validateFlags(); err != nil { - setupLog.Error(err, "Failed to validate flags") - return err - } - // Print all flag values flags := make(map[string]any) - flag.VisitAll(func(f *flag.Flag) { + pflag.VisitAll(func(f *pflag.Flag) { // options live on pflag.CommandLine; "flag" is now the stdlib package @@ -238,25 +195,26 @@ func (r *Runner) Run(ctx context.Context) error { return err } - rawConfig, err := r.parseConfigurationPhaseOne(ctx) + rawConfig, err := r.parseConfigurationPhaseOne(ctx, opts) if err != nil { setupLog.Error(err, "Failed to parse configuration") return err } // --- Setup Datastore --- - epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.FeatureGate]) + epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.FeatureGate], opts) if err != nil { return err } - gknn, err := extractGKNN(*poolName, *poolGroup, *poolNamespace, *endpointSelector) + gknn, err := extractGKNN(opts.PoolName, opts.PoolGroup, opts.PoolNamespace, opts.EndpointSelector) if err != nil { setupLog.Error(err, "Failed to extract GKNN") return err } - disableK8sCrdReconcile := *endpointSelector != "" - ds, err := setupDatastore(setupLog, ctx, epf, int32(*modelServerMetricsPort), disableK8sCrdReconcile, *poolName, *poolNamespace, *endpointSelector, *endpointTargetPorts) + disableK8sCrdReconcile := opts.EndpointSelector != "" + ds, err := setupDatastore(setupLog, ctx, epf, int32(opts.ModelServerMetricsPort), disableK8sCrdReconcile, + 
opts.PoolName, opts.PoolNamespace, opts.EndpointSelector, opts.EndpointTargetPorts) if err != nil { setupLog.Error(err, "Failed to setup datastore") return err @@ -277,9 +235,9 @@ func (r *Runner) Run(ctx context.Context) error { // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.19.1/pkg/metrics/server // - https://book.kubebuilder.io/reference/metrics.html metricsServerOptions := metricsserver.Options{ - BindAddress: fmt.Sprintf(":%d", *metricsPort), + BindAddress: fmt.Sprintf(":%d", opts.MetricsPort), FilterProvider: func() func(c *rest.Config, httpClient *http.Client) (metricsserver.Filter, error) { - if *metricsEndpointAuth { + if opts.MetricsEndpointAuth { return filters.WithAuthenticationAndAuthorization } @@ -290,13 +248,13 @@ func (r *Runner) Run(ctx context.Context) error { isLeader := &atomic.Bool{} isLeader.Store(false) - mgr, err := runserver.NewDefaultManager(disableK8sCrdReconcile, *gknn, cfg, metricsServerOptions, *haEnableLeaderElection) + mgr, err := runserver.NewDefaultManager(disableK8sCrdReconcile, *gknn, cfg, metricsServerOptions, opts.EnableLeaderElection) if err != nil { setupLog.Error(err, "Failed to create controller manager") return err } - if *haEnableLeaderElection { + if opts.EnableLeaderElection { setupLog.Info("Leader election enabled") go func() { <-mgr.Elected() @@ -308,7 +266,7 @@ func (r *Runner) Run(ctx context.Context) error { isLeader.Store(true) } - if *enablePprof { + if opts.EnablePprof { setupLog.Info("Enabling pprof handlers") err = setupPprofHandlers(mgr) if err != nil { @@ -375,16 +333,16 @@ func (r *Runner) Run(ctx context.Context) error { // --- Setup ExtProc Server Runner --- serverRunner := &runserver.ExtProcServerRunner{ - GrpcPort: *grpcPort, + GrpcPort: opts.GRPCPort, GKNN: *gknn, Datastore: ds, DisableK8sCrdReconcile: disableK8sCrdReconcile, - SecureServing: *secureServing, - HealthChecking: *healthChecking, - CertPath: *certPath, - EnableCertReload: *enableCertReload, - RefreshPrometheusMetricsInterval: 
*refreshPrometheusMetricsInterval, - MetricsStalenessThreshold: *metricsStalenessThreshold, + SecureServing: opts.SecureServing, + HealthChecking: opts.HealthChecking, + CertPath: opts.CertPath, + EnableCertReload: opts.EnableCertReload, + RefreshPrometheusMetricsInterval: opts.RefreshPrometheusMetricsInterval, + MetricsStalenessThreshold: opts.MetricsStalenessThreshold, Director: director, SaturationDetector: saturationDetector, UseExperimentalDatalayerV2: r.featureGates[datalayer.FeatureGate], // pluggable data layer feature flag @@ -396,7 +354,7 @@ func (r *Runner) Run(ctx context.Context) error { // --- Add Runnables to Manager --- // Register health server. - if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), ds, *grpcHealthPort, isLeader, *haEnableLeaderElection); err != nil { + if err := registerHealthServer(mgr, ctrl.Log.WithName("health"), ds, opts.GRPCHealthPort, isLeader, opts.EnableLeaderElection); err != nil { return err } @@ -416,7 +374,8 @@ func (r *Runner) Run(ctx context.Context) error { return nil } -func setupDatastore(setupLog logr.Logger, ctx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32, disableK8sCrdReconcile bool, namespace, name, endpointSelector, endpointTargetPorts string) (datastore.Datastore, error) { +func setupDatastore(setupLog logr.Logger, ctx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32, + disableK8sCrdReconcile bool, namespace, name, endpointSelector string, endpointTargetPorts []int) (datastore.Datastore, error) { if !disableK8sCrdReconcile { return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort), nil } else { @@ -427,11 +386,7 @@ func setupDatastore(setupLog logr.Logger, ctx context.Context, epFactory datalay return nil, err } endpointPool.Selector = labelsMap - endpointPool.TargetPorts, err = strToUniqueIntSlice(endpointTargetPorts) - if err != nil { - setupLog.Error(err, "Failed to parse flag %q with error: %w", 
"endpoint-target-ports", err) - return nil, err - } + endpointPool.TargetPorts = append([]int(nil), endpointTargetPorts...) // assign a detached copy; copy() only fills len(dst) elements and would drop the ports endpointPoolOption := datastore.WithEndpointPool(endpointPool) return datastore.NewDatastore(ctx, epFactory, modelServerMetricsPort, endpointPoolOption), nil @@ -460,21 +415,21 @@ func (r *Runner) registerInTreePlugins() { plugins.Register(dlmetrics.MetricsExtractorType, dlmetrics.ModelServerExtractorFactory) } -func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.EndpointPickerConfig, error) { - if *configText == "" && *configFile == "" { +func (r *Runner) parseConfigurationPhaseOne(ctx context.Context, opts *runserver.Options) (*configapi.EndpointPickerConfig, error) { + if opts.ConfigText == "" && opts.ConfigFile == "" { return nil, nil // configuring through code, not through file } logger := log.FromContext(ctx) var configBytes []byte - if *configText != "" { - configBytes = []byte(*configText) - } else if *configFile != "" { // if config was specified through a file + if opts.ConfigText != "" { + configBytes = []byte(opts.ConfigText) + } else if opts.ConfigFile != "" { // if config was specified through a file var err error - configBytes, err = os.ReadFile(*configFile) + configBytes, err = os.ReadFile(opts.ConfigFile) if err != nil { - return nil, fmt.Errorf("failed to load config from a file '%s' - %w", *configFile, err) + return nil, fmt.Errorf("failed to load config from a file '%s' - %w", opts.ConfigFile, err) } } @@ -531,14 +486,6 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf return cfg, nil } -func (r *Runner) deprecatedFlagsHandler(logger logr.Logger) { - flag.Visit(func(f *flag.Flag) { - if f.Name == "model-server-metrics-port" { // future: use map/set to store deprecated flags (and replacements?) 
- logger.Info("deprecated option will be removed in the next release.", "option", f.Name) - } - }) -} - func (r *Runner) deprecatedConfigurationHelper(cfg *config.Config, logger logr.Logger) { // Handle deprecated environment variable based feature flags @@ -577,24 +524,24 @@ func (r *Runner) deprecatedConfigurationHelper(cfg *config.Config, logger logr.L } } -func (r *Runner) setupMetricsCollection(setupLog logr.Logger, useExperimentalDatalayer bool) (datalayer.EndpointFactory, error) { +func (r *Runner) setupMetricsCollection(setupLog logr.Logger, useExperimentalDatalayer bool, opts *runserver.Options) (datalayer.EndpointFactory, error) { if useExperimentalDatalayer { - return setupDatalayer(setupLog) + return setupDatalayer(setupLog, opts) } if len(datalayer.GetSources()) != 0 { setupLog.Info("data sources registered but pluggable datalayer is disabled") } - return setupMetricsV1(setupLog) + return setupMetricsV1(setupLog, opts) } -func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { +func setupMetricsV1(setupLog logr.Logger, opts *runserver.Options) (datalayer.EndpointFactory, error) { mapping, err := backendmetrics.NewMetricMapping( - *totalQueuedRequestsMetric, - *totalRunningRequestsMetric, - *kvCacheUsagePercentageMetric, - *loraInfoMetric, - *cacheInfoMetric, + opts.TotalQueuedRequestsMetric, + opts.TotalRunningRequestsMetric, + opts.KVCacheUsagePercentageMetric, + opts.LoRAInfoMetric, + opts.CacheInfoMetric, ) if err != nil { setupLog.Error(err, "Failed to create metric mapping from flags.") @@ -603,11 +550,11 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { verifyMetricMapping(*mapping, setupLog) var metricsHttpClient *http.Client - if *modelServerMetricsScheme == "https" { + if opts.ModelServerMetricsScheme == "https" { metricsHttpClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ - InsecureSkipVerify: *modelServerMetricsHttpsInsecureSkipVerify, + 
InsecureSkipVerify: opts.ModelServerMetricsHTTPSInsecure, }, }, } @@ -617,11 +564,11 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{ MetricMapping: mapping, - ModelServerMetricsPath: *modelServerMetricsPath, - ModelServerMetricsScheme: *modelServerMetricsScheme, + ModelServerMetricsPath: opts.ModelServerMetricsPath, + ModelServerMetricsScheme: opts.ModelServerMetricsScheme, Client: metricsHttpClient, }, - *refreshMetricsInterval) + opts.RefreshMetricsInterval) return pmf, nil } @@ -633,15 +580,15 @@ func setupMetricsV1(setupLog logr.Logger) (datalayer.EndpointFactory, error) { // endpoint factory) should be moved accordingly. // Regardless, registration of all sources (e.g., if additional sources // are to be configured), must be done before the EndpointFactory is initialized. -func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) { +func setupDatalayer(logger logr.Logger, opts *runserver.Options) (datalayer.EndpointFactory, error) { // create and register a metrics data source and extractor. 
- source := dlmetrics.NewMetricsDataSource(*modelServerMetricsScheme, - *modelServerMetricsPath, - *modelServerMetricsHttpsInsecureSkipVerify) - extractor, err := dlmetrics.NewModelServerExtractor(*totalQueuedRequestsMetric, - *totalRunningRequestsMetric, - *kvCacheUsagePercentageMetric, - *loraInfoMetric, *cacheInfoMetric) + source := dlmetrics.NewMetricsDataSource(opts.ModelServerMetricsScheme, + opts.ModelServerMetricsPath, + opts.ModelServerMetricsHTTPSInsecure) + extractor, err := dlmetrics.NewModelServerExtractor(opts.TotalQueuedRequestsMetric, + opts.TotalRunningRequestsMetric, + opts.KVCacheUsagePercentageMetric, + opts.LoRAInfoMetric, opts.CacheInfoMetric) if err != nil { return nil, err @@ -658,11 +605,11 @@ func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) { for _, src := range sources { logger.Info("data layer configuration", "source", src.TypedName().String(), "extractors", src.Extractors()) } - factory := datalayer.NewEndpointFactory(sources, *refreshMetricsInterval) + factory := datalayer.NewEndpointFactory(sources, opts.RefreshMetricsInterval) return factory, nil } -func initLogging(opts *zap.Options) { +func initLogging(opts *zap.Options, cliopts *runserver.Options) { // Unless -zap-log-level is explicitly set, use -v useV := true - flag.Visit(func(f *flag.Flag) { + pflag.Visit(func(f *pflag.Flag) { // only pflag.CommandLine is parsed; the stdlib flag set never reports zap-log-level as set @@ -672,7 +619,7 @@ func initLogging(opts *zap.Options) { }) if useV { // See https://pkg.go.dev/sigs.k8s.io/controller-runtime/pkg/log/zap#Options.Level - lvl := -1 * (*logVerbosity) + lvl := -1 * (cliopts.LogVerbosity) opts.Level = uberzap.NewAtomicLevelAt(zapcore.Level(int8(lvl))) } @@ -707,58 +654,6 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore. 
return nil } -func validateFlags() error { - if (*poolName != "" && *endpointSelector != "") || (*poolName == "" && *endpointSelector == "") { - return errors.New("either pool-name or endpoint-selector must be set") - } - if *endpointSelector != "" { - targetPortsList, err := strToUniqueIntSlice(*endpointTargetPorts) - if err != nil { - return fmt.Errorf("unexpected value for %q flag with error %w", "endpoint-target-ports", err) - } - if len(targetPortsList) == 0 || len(targetPortsList) > 8 { - return fmt.Errorf("flag %q should have length from 1 to 8", "endpoint-target-ports") - } - } - - if *configText != "" && *configFile != "" { - return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile") - } - if *modelServerMetricsScheme != "http" && *modelServerMetricsScheme != "https" { - return fmt.Errorf("unexpected %q value for %q flag, it can only be set to 'http' or 'https'", *modelServerMetricsScheme, "model-server-metrics-scheme") - } - - return nil -} - -func strToUniqueIntSlice(s string) ([]int, error) { - seen := sets.NewInt() - var intList []int - - if s == "" { - return intList, nil - } - - strList := strings.Split(s, ",") - - for _, str := range strList { - trimmedStr := strings.TrimSpace(str) - if trimmedStr == "" { - continue - } - portInt, err := strconv.Atoi(trimmedStr) - if err != nil { - return nil, fmt.Errorf("invalid number: '%s' is not an integer", trimmedStr) - } - - if _, ok := seen[portInt]; !ok { - seen[portInt] = struct{}{} - intList = append(intList, portInt) - } - } - return intList, nil -} - func verifyMetricMapping(mapping backendmetrics.MetricMapping, logger logr.Logger) { if mapping.TotalQueuedRequests == nil { logger.Info("Not scraping metric: TotalQueuedRequests") diff --git a/pkg/epp/server/options.go b/pkg/epp/server/options.go new file mode 100644 index 000000000..5e6f92f4b --- /dev/null +++ b/pkg/epp/server/options.go @@ -0,0 +1,218 @@ +/* +Copyright 2025 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "errors" + "fmt" + "time" + + "github.com/spf13/pflag" + "k8s.io/apimachinery/pkg/util/sets" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + DefaultGrpcPort = 9002 + DefaultPoolNamespace = "default" // default when pool namespace is empty (CLI flag default is empty) +) + +// Options contains configuration values necessary to create and run the EPP. +type Options struct { + // + // ext_proc configuration. + // + GRPCPort int // gRPC port used for communicating with Envoy proxy. (TODO: uint16?) + EnableLeaderElection bool // Enables leader election for high availability + // + // InferencePool. + // + PoolGroup string // Kubernetes resource group of the InferencePool this Endpoint Picker is associated with. + PoolNamespace string // Namespace of the InferencePool this Endpoint Picker is associated with. + PoolName string // Name of the InferencePool this Endpoint Picker is associated with. + // + // Endpoints (in lieu of using an InferencePool for service discovery). + // + EndpointSelector string // Selector to filter model server pods on, only 'key=value' pairs are supported. (TODO: k8s.Selector, pflag.StringSlice?) + EndpointTargetPorts []int // Target ports of model server pods. + // + // MSP metrics scraping. + // + ModelServerMetricsScheme string // Protocol scheme used in scraping metrics from endpoints. 
+ ModelServerMetricsPath string // URL path used in scraping metrics from endpoints. + ModelServerMetricsPort int // Port to scrape metrics from endpoints. (TODO: Deprecated, uint16) + ModelServerMetricsHTTPSInsecure bool // Disable certificate verification when using 'https' scheme for 'model-server-metrics-scheme'. + RefreshMetricsInterval time.Duration // Interval to refresh metrics. + RefreshPrometheusMetricsInterval time.Duration // Interval to flush Prometheus metrics. + MetricsStalenessThreshold time.Duration // Duration after which metrics are considered stale. + TotalQueuedRequestsMetric string // Prometheus metric specification for the number of queued requests. + TotalRunningRequestsMetric string // Prometheus metric specification for the number of running requests. + KVCacheUsagePercentageMetric string // Prometheus metric specification for the fraction of KV-cache blocks currently in use. + LoRAInfoMetric string // Prometheus metric specification for the LoRA info metrics. + CacheInfoMetric string // Prometheus metric specification for the cache info metrics. + // + // Diagnostics. + // + LogVerbosity int // Number for the log level verbosity. + Tracing bool // Enables emitting traces. + HealthChecking bool // Enables health checking. + MetricsPort int // The metrics port exposed by EPP. (TODO: uint16) + GRPCHealthPort int // The port used for gRPC liveness and readiness probes. (TODO: uint16) + EnablePprof bool // Enables pprof handlers. + CertPath string // The path to the certificate for secure serving. + EnableCertReload bool // Enables certificate reloading of the certificates specified in --cert-path. + SecureServing bool // Enables secure serving. + MetricsEndpointAuth bool // Enables authentication and authorization of the metrics endpoint. + // + // Configuration. + // + ConfigFile string // The path to the configuration file. + ConfigText string // The configuration specified as text, in lieu of a file. 
+} + +// NewOptions returns a new Options struct initialized with the default values. +func NewOptions() *Options { + return &Options{ // "zero" values are not explicitly set + GRPCPort: DefaultGrpcPort, + PoolGroup: "inference.networking.k8s.io", + EndpointTargetPorts: []int{}, + ModelServerMetricsScheme: "http", + ModelServerMetricsPath: "/metrics", + ModelServerMetricsHTTPSInsecure: true, + RefreshMetricsInterval: 50 * time.Millisecond, + RefreshPrometheusMetricsInterval: 5 * time.Second, + MetricsStalenessThreshold: 2 * time.Second, + TotalQueuedRequestsMetric: "vllm:num_requests_waiting", + TotalRunningRequestsMetric: "vllm:num_requests_running", + KVCacheUsagePercentageMetric: "vllm:kv_cache_usage_perc", + LoRAInfoMetric: "vllm:lora_requests_info", + CacheInfoMetric: "vllm:cache_config_info", + LogVerbosity: logging.DEFAULT, + Tracing: true, + MetricsPort: 9090, + GRPCHealthPort: 9003, + EnablePprof: true, + SecureServing: true, + MetricsEndpointAuth: true, + } +} + +func (opts *Options) AddFlags(fs *pflag.FlagSet) { + if fs == nil { + fs = pflag.CommandLine + } + + fs.IntVar(&opts.GRPCPort, "grpc-port", opts.GRPCPort, "gRPC port used for communicating with Envoy proxy.") + fs.BoolVar(&opts.EnableLeaderElection, "ha-enable-leader-election", opts.EnableLeaderElection, + "Enables leader election for high availability. 
When enabled, readiness probes will only pass on the leader.") + fs.StringVar(&opts.PoolGroup, "pool-group", opts.PoolGroup, + "Kubernetes resource group of the InferencePool this Endpoint Picker is associated with.") + fs.StringVar(&opts.PoolNamespace, "pool-namespace", opts.PoolNamespace, + "Namespace of the InferencePool this Endpoint Picker is associated with.") + fs.StringVar(&opts.PoolName, "pool-name", opts.PoolName, "Name of the InferencePool this Endpoint Picker is associated with.") + fs.StringVar(&opts.EndpointSelector, "endpoint-selector", opts.EndpointSelector, + "Selector to filter model server pods on, only 'key=value' pairs are supported. "+ + "Format: a comma-separated list of key=value pairs without whitespace (e.g., 'app=vllm-llama3-8b-instruct,env=prod').") + fs.IntSliceVar(&opts.EndpointTargetPorts, "endpoint-target-ports", opts.EndpointTargetPorts, "Target ports of model server pods. "+ + "Format: a comma-separated list of numbers without whitespace (e.g., '3000,3001,3002').") + fs.StringVar(&opts.ModelServerMetricsScheme, "model-server-metrics-scheme", opts.ModelServerMetricsScheme, + "Protocol scheme used in scraping metrics from endpoints.") + fs.StringVar(&opts.ModelServerMetricsPath, "model-server-metrics-path", opts.ModelServerMetricsPath, + "URL path used in scraping metrics from endpoints.") + fs.IntVar(&opts.ModelServerMetricsPort, "model-server-metrics-port", opts.ModelServerMetricsPort, + "Port to scrape metrics from endpoints. 
Set to the InferencePool.Spec.TargetPorts[0].Number if not defined.") + _ = fs.MarkDeprecated("model-server-metrics-port", "This flag is deprecated and will be removed in a future release.") + fs.BoolVar(&opts.ModelServerMetricsHTTPSInsecure, "model-server-metrics-https-insecure-skip-verify", opts.ModelServerMetricsHTTPSInsecure, + "Disable certificate verification when using 'https' scheme for 'model-server-metrics-scheme'.") + fs.DurationVar(&opts.RefreshMetricsInterval, "refresh-metrics-interval", opts.RefreshMetricsInterval, "Interval to refresh metrics.") + fs.DurationVar(&opts.RefreshPrometheusMetricsInterval, "refresh-prometheus-metrics-interval", opts.RefreshPrometheusMetricsInterval, + "Interval to flush Prometheus metrics.") + fs.DurationVar(&opts.MetricsStalenessThreshold, "metrics-staleness-threshold", opts.MetricsStalenessThreshold, + "Duration after which metrics are considered stale. This is used to determine if an endpoint's metrics are fresh enough.") + fs.StringVar(&opts.TotalQueuedRequestsMetric, "total-queued-requests-metric", opts.TotalQueuedRequestsMetric, + "Prometheus metric for the number of queued requests.") + fs.StringVar(&opts.TotalRunningRequestsMetric, "total-running-requests-metric", opts.TotalRunningRequestsMetric, + "Prometheus metric for the number of running requests.") + fs.StringVar(&opts.KVCacheUsagePercentageMetric, "kv-cache-usage-percentage-metric", opts.KVCacheUsagePercentageMetric, + "Prometheus metric for the fraction of KV-cache blocks currently in use (from 0 to 1).") + fs.StringVar(&opts.LoRAInfoMetric, "lora-info-metric", opts.LoRAInfoMetric, + "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") + fs.StringVar(&opts.CacheInfoMetric, "cache-info-metric", opts.CacheInfoMetric, "Prometheus metric for the cache info metrics.") + fs.IntVar(&opts.LogVerbosity, "v", opts.LogVerbosity, "Number for the log level verbosity.") + fs.BoolVar(&opts.Tracing, "tracing", opts.Tracing, "Enables emitting 
traces.") + fs.BoolVar(&opts.HealthChecking, "health-checking", opts.HealthChecking, "Enables health checking.") + fs.IntVar(&opts.MetricsPort, "metrics-port", opts.MetricsPort, "The metrics port exposed by EPP.") + fs.IntVar(&opts.GRPCHealthPort, "grpc-health-port", opts.GRPCHealthPort, + "The port used for gRPC liveness and readiness probes.") + fs.BoolVar(&opts.EnablePprof, "enable-pprof", opts.EnablePprof, + "Enables pprof handlers. Defaults to true. Set to false to disable pprof handlers.") + fs.StringVar(&opts.CertPath, "cert-path", opts.CertPath, + "The path to the certificate for secure serving. The certificate and private key files "+ + "are assumed to be named tls.crt and tls.key, respectively. If not set, and secureServing is enabled, "+ + "then a self-signed certificate is used.") + fs.BoolVar(&opts.EnableCertReload, "enable-cert-reload", opts.EnableCertReload, + "Enables certificate reloading of the certificates specified in --cert-path.") + fs.BoolVar(&opts.SecureServing, "secure-serving", opts.SecureServing, "Enables secure serving.") + fs.BoolVar(&opts.MetricsEndpointAuth, "metrics-endpoint-auth", opts.MetricsEndpointAuth, + "Enables authentication and authorization of the metrics endpoint.") + fs.StringVar(&opts.ConfigFile, "config-file", opts.ConfigFile, "The path to the configuration file.") + fs.StringVar(&opts.ConfigText, "config-text", opts.ConfigText, "The configuration specified as text, in lieu of a file.") +} + +func (opts *Options) Complete() error { + // TODO: postprocessing of command line arguments. For example, convert EndpointSelector + // from raw string to k8s.LabelSelector, load ConfigFile into ConfigText, etc. 
+ + opts.EndpointTargetPorts = removeDuplicatePorts(opts.EndpointTargetPorts) + + return nil +} + +func (opts *Options) Validate() error { + if (opts.PoolName != "" && opts.EndpointSelector != "") || (opts.PoolName == "" && opts.EndpointSelector == "") { + return errors.New("either pool-name or endpoint-selector must be set") + } + if opts.EndpointSelector != "" { + if len(opts.EndpointTargetPorts) == 0 || len(opts.EndpointTargetPorts) > 8 { + return fmt.Errorf("flag %q should have length from 1 to 8", "endpoint-target-ports") + } + } + + if opts.ConfigText != "" && opts.ConfigFile != "" { + return fmt.Errorf("both the %q and %q flags can not be set at the same time", "configText", "configFile") + } + if opts.ModelServerMetricsScheme != "http" && opts.ModelServerMetricsScheme != "https" { + return fmt.Errorf("unexpected %q value for %q flag, it can only be set to 'http' or 'https'", + opts.ModelServerMetricsScheme, "model-server-metrics-scheme") + } + + return nil +} + +func removeDuplicatePorts(ports []int) []int { + seen := sets.NewInt() + unique := make([]int, 0, len(ports)) + + for _, val := range ports { + if !seen.Has(val) { + unique = append(unique, val) + seen.Insert(val) + } + } + return unique +} diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index d35e70c74..565f059f2 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -66,49 +66,29 @@ type ExtProcServerRunner struct { TestPodMetricsClient *backendmetrics.FakePodMetricsClient } -// Default values for CLI flags in main -const ( - DefaultGrpcPort = 9002 // default for --grpc-port - DefaultGrpcHealthPort = 9003 // default for --grpc-health-port - DefaultMetricsPort = 9090 // default for --metrics-port - DefaultPoolName = "" // required but no default - DefaultPoolNamespace = "default" // default for --pool-namespace - DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refresh-metrics-interval - DefaultRefreshPrometheusMetricsInterval = 5 
* time.Second // default for --refresh-prometheus-metrics-interval - DefaultSecureServing = true // default for --secure-serving - DefaultHealthChecking = false // default for --health-checking - DefaultEnablePprof = true // default for --enable-pprof - DefaultTotalQueuedRequestsMetric = "vllm:num_requests_waiting" // default for --total-queued-requests-metric - DefaultTotalRunningRequestsMetric = "vllm:num_requests_running" // default for --total-running-requests-metric - DefaultKvCacheUsagePercentageMetric = "vllm:kv_cache_usage_perc" // default for --kv-cache-usage-percentage-metric - DefaultLoraInfoMetric = "vllm:lora_requests_info" // default for --lora-info-metric - DefaultCacheInfoMetric = "vllm:cache_config_info" // default for --cache-info-metric - DefaultCertPath = "" // default for --cert-path - DefaultCertReload = false // default for --enable-cert-reload - DefaultConfigFile = "" // default for --config-file - DefaultConfigText = "" // default for --config-text - DefaultPoolGroup = "inference.networking.k8s.io" // default for --pool-group - DefaultMetricsStalenessThreshold = 2 * time.Second -) - // NewDefaultExtProcServerRunner creates a runner with default values. // Note: Dependencies like Datastore, Scheduler, SD need to be set separately. 
func NewDefaultExtProcServerRunner() *ExtProcServerRunner { + opts := NewOptions() + if opts.PoolNamespace == "" { + opts.PoolNamespace = DefaultPoolNamespace + } + gknn := common.GKNN{ - NamespacedName: types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace}, + NamespacedName: types.NamespacedName{Name: opts.PoolName, Namespace: opts.PoolNamespace}, GroupKind: schema.GroupKind{ - Group: DefaultPoolGroup, + Group: opts.PoolGroup, Kind: "InferencePool", }, } return &ExtProcServerRunner{ - GrpcPort: DefaultGrpcPort, + GrpcPort: opts.GRPCPort, GKNN: gknn, DisableK8sCrdReconcile: false, - SecureServing: DefaultSecureServing, - HealthChecking: DefaultHealthChecking, - RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval, - MetricsStalenessThreshold: DefaultMetricsStalenessThreshold, + SecureServing: opts.SecureServing, + HealthChecking: opts.HealthChecking, + RefreshPrometheusMetricsInterval: opts.RefreshPrometheusMetricsInterval, + MetricsStalenessThreshold: opts.MetricsStalenessThreshold, // Dependencies can be assigned later. } }