diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go index a61e4ee74..cc4a3c0db 100644 --- a/pkg/epp/backend/metrics/fake.go +++ b/pkg/epp/backend/metrics/fake.go @@ -25,14 +25,13 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) // FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop. type FakePodMetrics struct { - Pod *backend.Pod + Metadata *datalayer.EndpointMetadata Metrics *MetricsState Attributes *datalayer.Attributes } @@ -41,8 +40,8 @@ func (fpm *FakePodMetrics) String() string { return fmt.Sprintf("Metadata: %v; Metrics: %v", fpm.GetMetadata(), fpm.GetMetrics()) } -func (fpm *FakePodMetrics) GetMetadata() *backend.Pod { - return fpm.Pod +func (fpm *FakePodMetrics) GetMetadata() *datalayer.EndpointMetadata { + return fpm.Metadata } func (fpm *FakePodMetrics) GetMetrics() *MetricsState { @@ -50,7 +49,7 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState { } func (fpm *FakePodMetrics) UpdateMetadata(metadata *datalayer.EndpointMetadata) { - fpm.Pod = metadata + fpm.Metadata = metadata } func (fpm *FakePodMetrics) GetAttributes() *datalayer.Attributes { return fpm.Attributes @@ -72,7 +71,7 @@ type FakePodMetricsClient struct { Res map[types.NamespacedName]*MetricsState } -func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) { +func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) { f.errMu.RLock() err, ok := f.Err[pod.NamespacedName] f.errMu.RUnlock() diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index 4c7bb13a4..165b99f39 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -28,7 +28,7 @@ import ( "github.com/prometheus/common/model" "go.uber.org/multierr" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) const ( @@ -50,22 +50,22 @@ type PodMetricsClientImpl struct { } // FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one. 
-func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) { - url := p.getMetricEndpoint(pod) +func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, metadata *datalayer.EndpointMetadata, existing *MetricsState) (*MetricsState, error) { + url := p.getMetricEndpoint(metadata) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) } resp, err := p.Client.Do(req) if err != nil { - return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod.NamespacedName, err) + return nil, fmt.Errorf("failed to fetch metrics from %s: %w", metadata.NamespacedName, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code from %s: %v", pod.NamespacedName, resp.StatusCode) + return nil, fmt.Errorf("unexpected status code from %s: %v", metadata.NamespacedName, resp.StatusCode) } parser := expfmt.NewTextParser(model.LegacyValidation) @@ -76,8 +76,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po return p.promToPodMetrics(metricFamilies, existing) } -func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string { - return p.ModelServerMetricsScheme + "://" + pod.GetMetricsHost() + p.ModelServerMetricsPath +func (p *PodMetricsClientImpl) getMetricEndpoint(metadata *datalayer.EndpointMetadata) string { + return p.ModelServerMetricsScheme + "://" + metadata.GetMetricsHost() + p.ModelServerMetricsPath } // promToPodMetrics updates internal pod metrics with scraped Prometheus metrics. diff --git a/pkg/epp/backend/metrics/metrics_test.go b/pkg/epp/backend/metrics/metrics_test.go index f1256ec6b..dcea8621d 100644 --- a/pkg/epp/backend/metrics/metrics_test.go +++ b/pkg/epp/backend/metrics/metrics_test.go @@ -31,7 +31,7 @@ import ( "google.golang.org/protobuf/proto" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -577,7 +577,7 @@ func TestPromToPodMetrics(t *testing.T) { // there's no server running on the specified port. func TestFetchMetrics(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) - pod := &backend.Pod{ + metadata := &datalayer.EndpointMetadata{ Address: "127.0.0.1", Port: "9999", MetricsHost: "127.0.0.1:9999", @@ -594,7 +594,7 @@ func TestFetchMetrics(t *testing.T) { Client: http.DefaultClient, } - _, err := p.FetchMetrics(ctx, pod, existing) // Use a port that's unlikely to be in use + _, err := p.FetchMetrics(ctx, metadata, existing) // Use a port that's unlikely to be in use if err == nil { t.Errorf("FetchMetrics() expected error, got nil") } diff --git a/pkg/epp/backend/pod.go b/pkg/epp/backend/pod.go deleted file mode 100644 index e24494042..000000000 --- a/pkg/epp/backend/pod.go +++ /dev/null @@ -1,23 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package backend - -import ( - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" -) - -type Pod = datalayer.EndpointMetadata diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go index 51c20334b..210c5a21b 100644 --- a/pkg/epp/config/loader/configloader_test.go +++ b/pkg/epp/config/loader/configloader_test.go @@ -408,14 +408,14 @@ func (m *mockPlugin) TypedName() plugins.TypedName { return m.t } // Mock Scorer type mockScorer struct{ mockPlugin } -func (m *mockScorer) Score(context.Context, *types.CycleState, *types.LLMRequest, []types.Pod) map[types.Pod]float64 { +func (m *mockScorer) Score(context.Context, *types.CycleState, *types.LLMRequest, []types.Endpoint) map[types.Endpoint]float64 { return nil } // Mock Picker type mockPicker struct{ mockPlugin } -func (m *mockPicker) Pick(context.Context, *types.CycleState, []*types.ScoredPod) *types.ProfileRunResult { +func (m *mockPicker) Pick(context.Context, *types.CycleState, []*types.ScoredEndpoint) *types.ProfileRunResult { return nil } diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index bfc4e147a..5b5e92cc5 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -40,15 +40,15 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP // an EoS in the request headers means this request has no body or trailers. if req.RequestHeaders.EndOfStream { - // We will route this request to a random pod as this is assumed to just be a GET + // We will route this request to a random endpoint as this is assumed to just be a GET // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526 // The above PR will address endpoint admission, but currently any request without a body will be - // routed to a random upstream pod. - pod := s.director.GetRandomPod() - if pod == nil { + // routed to a random upstream endpoint. 
+ endpoint := s.director.GetRandomEndpoint() + if endpoint == nil { return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"} } - reqCtx.TargetEndpoint = pod.GetIPAddress() + ":" + pod.GetPort() + reqCtx.TargetEndpoint = endpoint.GetIPAddress() + ":" + endpoint.GetPort() reqCtx.RequestSize = 0 reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) return nil diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 188a4197e..0bf955750 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -23,7 +23,7 @@ import ( "github.com/google/go-cmp/cmp" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -103,8 +103,8 @@ func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *Reque func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { return reqCtx, nil } -func (m *mockDirector) GetRandomPod() *backend.Pod { - return &backend.Pod{} +func (m *mockDirector) GetRandomEndpoint() *datalayer.EndpointMetadata { + return &datalayer.EndpointMetadata{} } func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { return reqCtx, nil diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 0e8bb599a..bec4aec42 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -31,7 +31,6 @@ import ( "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -57,7 +56,7 @@ type Director interface { HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - GetRandomPod() *backend.Pod + GetRandomEndpoint() *datalayer.EndpointMetadata } type Datastore interface { @@ -76,7 +75,7 @@ type StreamingServer struct { // Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request. // We should split these apart as this monolithic object exposes too much data to too many layers. 
type RequestContext struct { - TargetPod *backend.Pod + TargetPod *datalayer.EndpointMetadata TargetEndpoint string IncomingModelName string TargetModelName string diff --git a/pkg/epp/requestcontrol/dag_test.go b/pkg/epp/requestcontrol/dag_test.go index b86581b1c..031cdfae6 100644 --- a/pkg/epp/requestcontrol/dag_test.go +++ b/pkg/epp/requestcontrol/dag_test.go @@ -45,8 +45,8 @@ func (m *mockPrepareRequestDataP) Consumes() map[string]any { return m.consumes } -func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error { - pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) +func (m *mockPrepareRequestDataP) PrepareRequestData(ctx context.Context, request *types.LLMRequest, endpoints []types.Endpoint) error { + endpoints[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) return nil } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 4517e8742..1dc56a31f 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -28,7 +28,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" @@ -57,7 +56,7 @@ type Datastore interface { // Scheduler defines the interface required by the Director for scheduling. type Scheduler interface { - Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error) + Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Endpoint) (result *schedulingtypes.SchedulingResult, err error) } // NewDirectorWithConfig creates a new Director instance with all dependencies. 
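For orientation before the director.go hunks that follow: the Scheduler interface above now takes and returns Endpoint-based types. Below is a minimal sketch of a conforming implementation under the renamed signature; the firstEndpointScheduler name is illustrative, while SchedulingResult, ProfileRunResult, and TargetEndpoints are the types and fields this diff uses, and the assumed imports (context, errors, and the scheduling/types package aliased as schedulingtypes) match the diff's own.

type firstEndpointScheduler struct{}

// Schedule trivially targets the first candidate endpoint under a single "default" profile.
func (s *firstEndpointScheduler) Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidates []schedulingtypes.Endpoint) (*schedulingtypes.SchedulingResult, error) {
	if len(candidates) == 0 {
		return nil, errors.New("no candidate endpoints")
	}
	return &schedulingtypes.SchedulingResult{
		PrimaryProfileName: "default",
		ProfileResults: map[string]*schedulingtypes.ProfileRunResult{
			"default": {TargetEndpoints: []schedulingtypes.Endpoint{candidates[0]}},
		},
	}, nil
}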
@@ -245,20 +244,20 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"} } // primary profile is used to set destination - targetPods := []*backend.Pod{} + targetMetadatas := []*datalayer.EndpointMetadata{} targetEndpoints := []string{} - for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetPods { - curPod := pod.GetPod() - curEndpoint := net.JoinHostPort(curPod.GetIPAddress(), curPod.GetPort()) - targetPods = append(targetPods, curPod) + for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetEndpoints { + curMetadata := pod.GetMetadata() + curEndpoint := net.JoinHostPort(curMetadata.GetIPAddress(), curMetadata.GetPort()) + targetMetadatas = append(targetMetadatas, curMetadata) targetEndpoints = append(targetEndpoints, curEndpoint) } multiEndpointString := strings.Join(targetEndpoints, ",") logger.V(logutil.VERBOSE).Info("Request handled", "objectiveKey", reqCtx.ObjectiveKey, "incomingModelName", reqCtx.IncomingModelName, "targetModel", reqCtx.TargetModelName, "endpoint", multiEndpointString) - reqCtx.TargetPod = targetPods[0] + reqCtx.TargetPod = targetMetadatas[0] reqCtx.TargetEndpoint = multiEndpointString d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result) @@ -266,13 +265,13 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC return reqCtx, nil } -func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Pod { - pm := make([]schedulingtypes.Pod, len(pods)) +func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []schedulingtypes.Endpoint { + pm := make([]schedulingtypes.Endpoint, len(pods)) for i, pod := range pods { if pod.GetAttributes() != nil { - pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()} + pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: pod.GetAttributes().Clone()} } else { - pm[i] = &schedulingtypes.PodMetrics{Pod: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()} + pm[i] = &schedulingtypes.PodMetrics{EndpointMetadata: pod.GetMetadata().Clone(), MetricsState: pod.GetMetrics().Clone(), AttributeMap: datalayer.NewAttributes()} } } @@ -323,7 +322,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl return reqCtx, nil } -func (d *Director) GetRandomPod() *backend.Pod { +func (d *Director) GetRandomEndpoint() *datalayer.EndpointMetadata { pods := d.datastore.PodList(datastore.AllPodsPredicate) if len(pods) == 0 { return nil @@ -346,19 +345,19 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling } func (d *Director) runPrepareDataPlugins(ctx context.Context, - request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { if len(d.requestControlPlugins.prepareDataPlugins) == 0 { return nil } - return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods) + return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, endpoints) } func (d *Director) runAdmissionPlugins(ctx context.Context, - request 
*schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) bool { + request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) bool { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.admissionPlugins { loggerDebug.Info("Running AdmitRequest plugin", "plugin", plugin.TypedName()) - if denyReason := plugin.AdmitRequest(ctx, request, pods); denyReason != nil { + if denyReason := plugin.AdmitRequest(ctx, request, endpoints); denyReason != nil { loggerDebug.Info("AdmitRequest plugin denied the request", "plugin", plugin.TypedName(), "reason", denyReason.Error()) return false } @@ -367,34 +366,34 @@ func (d *Director) runAdmissionPlugins(ctx context.Context, return true } -func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.responseReceivedPlugins { loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName()) before := time.Now() - plugin.ResponseReceived(ctx, request, response, targetPod) + plugin.ResponseReceived(ctx, request, response, targetEndpoint) metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName()) } } -func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) for _, plugin := range d.requestControlPlugins.responseStreamingPlugins { loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName()) before := time.Now() - plugin.ResponseStreaming(ctx, request, response, targetPod) + plugin.ResponseStreaming(ctx, request, response, targetEndpoint) metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName()) } } -func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.responseCompletePlugins { loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName()) before := time.Now() - plugin.ResponseComplete(ctx, request, response, targetPod) + plugin.ResponseComplete(ctx, request, response, targetEndpoint) metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before)) loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName()) } diff --git 
a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 9b583cf2c..10bef6c81 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -36,7 +36,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" @@ -70,9 +69,9 @@ type mockScheduler struct { dataProduced bool // denotes whether data production is expected. } -func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) { - if pods != nil && m.dataProduced { - data, ok := pods[0].Get(mockProducedDataKey) +func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) (*schedulingtypes.SchedulingResult, error) { + if endpoints != nil && m.dataProduced { + data, ok := endpoints[0].Get(mockProducedDataKey) if !ok || data.(mockProducedDataType).value != 42 { return nil, errors.New("expected produced data not found in pod") } @@ -120,8 +119,8 @@ func (m *mockPrepareDataPlugin) Consumes() map[string]any { return m.consumes } -func (m *mockPrepareDataPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { - pods[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) +func (m *mockPrepareDataPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { + endpoints[0].Put(mockProducedDataKey, mockProducedDataType{value: 42}) return nil } @@ -149,7 +148,7 @@ func (m *mockAdmissionPlugin) TypedName() plugins.TypedName { return m.typedName } -func (m *mockAdmissionPlugin) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { +func (m *mockAdmissionPlugin) AdmitRequest(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { return m.denialError } @@ -258,11 +257,11 @@ func TestDirector_HandleRequest(t *testing.T) { defaultSuccessfulScheduleResults := &schedulingtypes.SchedulingResult{ ProfileResults: map[string]*schedulingtypes.ProfileRunResult{ "testProfile": { - TargetPods: []schedulingtypes.Pod{ - &schedulingtypes.ScoredPod{ - Pod: &schedulingtypes.PodMetrics{ + TargetEndpoints: []schedulingtypes.Endpoint{ + &schedulingtypes.ScoredEndpoint{ + Endpoint: &schedulingtypes.PodMetrics{ AttributeMap: datalayer.NewAttributes(), - Pod: &backend.Pod{ + EndpointMetadata: &datalayer.EndpointMetadata{ Address: "192.168.1.100", Port: "8000", MetricsHost: "192.168.1.100:8000", @@ -270,10 +269,10 @@ func TestDirector_HandleRequest(t *testing.T) { }, }, }, - &schedulingtypes.ScoredPod{ - Pod: &schedulingtypes.PodMetrics{ + &schedulingtypes.ScoredEndpoint{ + Endpoint: &schedulingtypes.PodMetrics{ AttributeMap: datalayer.NewAttributes(), - Pod: &backend.Pod{ + EndpointMetadata: &datalayer.EndpointMetadata{ Address: "192.168.2.100", Port: "8000", MetricsHost: "192.168.2.100:8000", @@ -281,10 +280,10 @@ func TestDirector_HandleRequest(t *testing.T) { }, }, }, - &schedulingtypes.ScoredPod{ - Pod: 
&schedulingtypes.PodMetrics{ + &schedulingtypes.ScoredEndpoint{ + Endpoint: &schedulingtypes.PodMetrics{ AttributeMap: datalayer.NewAttributes(), - Pod: &backend.Pod{ + EndpointMetadata: &datalayer.EndpointMetadata{ Address: "192.168.4.100", Port: "8000", MetricsHost: "192.168.4.100:8000", @@ -326,7 +325,7 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -350,7 +349,7 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ ObjectiveKey: model, TargetModelName: modelRewritten, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -378,7 +377,7 @@ func TestDirector_HandleRequest(t *testing.T) { initialTargetModelName: model, wantReqCtx: &handlers.RequestContext{ TargetModelName: model, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -407,7 +406,7 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantReqCtx: &handlers.RequestContext{ TargetModelName: model, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -436,7 +435,7 @@ func TestDirector_HandleRequest(t *testing.T) { }, wantReqCtx: &handlers.RequestContext{ TargetModelName: model, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -491,7 +490,7 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveName, TargetModelName: model, - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -514,7 +513,7 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ ObjectiveKey: objectiveNameResolve, TargetModelName: "resolved-target-model-A", - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -534,7 +533,7 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ ObjectiveKey: "food-review-1", TargetModelName: "food-review-1", - TargetPod: &backend.Pod{ + TargetPod: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, Address: "192.168.1.100", Port: "8000", @@ -714,7 +713,7 @@ func TestDirector_HandleRequest(t *testing.T) { } } -func TestGetRandomPod(t *testing.T) { +func TestGetRandomEndpoint(t *testing.T) { tests := []struct { name string storePods []*corev1.Pod @@ -776,12 +775,12 @@ func TestGetRandomPod(t *testing.T) { ds.PodUpdateOrAddIfNotExist(pod) } d := &Director{datastore: ds} - gotPod := d.GetRandomPod() + gotEndpoint := d.GetRandomEndpoint() - if test.expectNil && gotPod != nil { - t.Errorf("expected nil pod, got: %v", gotPod) + if test.expectNil && gotEndpoint != nil { + 
t.Errorf("expected nil pod, got: %v", gotEndpoint) } - if !test.expectNil && gotPod == nil { + if !test.expectNil && gotEndpoint == nil { t.Errorf("expected non-nil pod, got nil") } }) @@ -1075,7 +1074,7 @@ func TestDirector_HandleResponseReceived(t *testing.T) { Headers: map[string]string{"X-Test-Response-Header": "TestValue"}, }, - TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + TargetPod: &datalayer.EndpointMetadata{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } _, err := director.HandleResponseReceived(ctx, reqCtx) @@ -1112,7 +1111,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) { Response: &handlers.Response{ Headers: map[string]string{"X-Test-Streaming-Header": "StreamValue"}, }, - TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + TargetPod: &datalayer.EndpointMetadata{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } _, err := director.HandleResponseBodyStreaming(ctx, reqCtx) @@ -1149,7 +1148,7 @@ func TestDirector_HandleResponseComplete(t *testing.T) { Response: &handlers.Response{ Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"}, }, - TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + TargetPod: &datalayer.EndpointMetadata{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } _, err := director.HandleResponseBodyComplete(ctx, reqCtx) @@ -1222,17 +1221,17 @@ func (p *testResponseComplete) TypedName() plugins.TypedName { return p.typedName } -func (p *testResponseReceived) ResponseReceived(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (p *testResponseReceived) ResponseReceived(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *datalayer.EndpointMetadata) { p.lastRespOnResponse = response p.lastTargetPodOnResponse = targetPod.NamespacedName.String() } -func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *datalayer.EndpointMetadata) { p.lastRespOnStreaming = response p.lastTargetPodOnStreaming = targetPod.NamespacedName.String() } -func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *datalayer.EndpointMetadata) { p.lastRespOnComplete = response p.lastTargetPodOnComplete = targetPod.NamespacedName.String() } diff --git a/pkg/epp/requestcontrol/locator_test.go b/pkg/epp/requestcontrol/locator_test.go index 0f5a5df1c..5c53bdf94 100644 --- a/pkg/epp/requestcontrol/locator_test.go +++ b/pkg/epp/requestcontrol/locator_test.go @@ -26,8 +26,8 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" ) @@ -276,7 +276,7 @@ func (m 
*mockPodLocator) callCount() int { func makeMockPodMetrics(name, ip string) backendmetrics.PodMetrics { return &backendmetrics.FakePodMetrics{ - Pod: &backend.Pod{ + Metadata: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Namespace: "default", Name: name}, Address: ip, }, diff --git a/pkg/epp/requestcontrol/plugin_executor.go b/pkg/epp/requestcontrol/plugin_executor.go index 7771c4f67..292ddb09c 100644 --- a/pkg/epp/requestcontrol/plugin_executor.go +++ b/pkg/epp/requestcontrol/plugin_executor.go @@ -27,9 +27,9 @@ import ( // executePluginsAsDAG executes PrepareData plugins as a DAG based on their dependencies asynchronously. // So, a plugin is executed only after all its dependencies have been executed. // If there is a cycle or any plugin fails with error, it returns an error. -func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { +func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { for _, plugin := range plugins { - if err := plugin.PrepareRequestData(ctx, request, pods); err != nil { + if err := plugin.PrepareRequestData(ctx, request, endpoints); err != nil { return errors.New("prepare data plugin " + plugin.TypedName().String() + " failed: " + err.Error()) } } @@ -38,10 +38,10 @@ func executePluginsAsDAG(plugins []PrepareDataPlugin, ctx context.Context, reque // prepareDataPluginsWithTimeout executes the PrepareRequestData plugins with retries and timeout. func prepareDataPluginsWithTimeout(timeout time.Duration, plugins []PrepareDataPlugin, - ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { + ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { errCh := make(chan error, 1) go func() { - errCh <- executePluginsAsDAG(plugins, ctx, request, pods) + errCh <- executePluginsAsDAG(plugins, ctx, request, endpoints) }() select { diff --git a/pkg/epp/requestcontrol/plugin_executor_test.go b/pkg/epp/requestcontrol/plugin_executor_test.go index 1825f7a4c..354bc6374 100644 --- a/pkg/epp/requestcontrol/plugin_executor_test.go +++ b/pkg/epp/requestcontrol/plugin_executor_test.go @@ -41,7 +41,7 @@ func (m *mockPrepareRequestDataPlugin) TypedName() plugins.TypedName { return plugins.TypedName{Type: "mock", Name: m.name} } -func (m *mockPrepareRequestDataPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { +func (m *mockPrepareRequestDataPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { m.executed = true if m.delay > 0 { select { @@ -167,11 +167,11 @@ type dagTestPlugin struct { mu sync.Mutex } -func (p *dagTestPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error { +func (p *dagTestPlugin) PrepareRequestData(ctx context.Context, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) error { p.mu.Lock() defer p.mu.Unlock() p.execTime = time.Now() - return p.mockPrepareRequestDataPlugin.PrepareRequestData(ctx, request, pods) + return p.mockPrepareRequestDataPlugin.PrepareRequestData(ctx, request, endpoints) } func (p *dagTestPlugin) Produces() map[string]any { diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index 
8c6602049..4830dd1af 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -19,7 +19,7 @@ package requestcontrol import ( "context" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -43,19 +43,19 @@ type PreRequest interface { - // The given pod argument is the pod that served the request. + // The given endpoint argument is the endpoint that served the request. type ResponseReceived interface { plugins.Plugin - ResponseReceived(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) + ResponseReceived(ctx context.Context, request *types.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) } // ResponseStreaming is called by the director after each chunk of streaming response is sent. type ResponseStreaming interface { plugins.Plugin - ResponseStreaming(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) + ResponseStreaming(ctx context.Context, request *types.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) } // ResponseComplete is called by the director after the complete response is sent. type ResponseComplete interface { plugins.Plugin - ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) + ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetEndpoint *datalayer.EndpointMetadata) } // PrepareRequestData is called by the director before scheduling requests. @@ -63,7 +63,7 @@ type ResponseComplete interface { type PrepareDataPlugin interface { plugins.ProducerPlugin plugins.ConsumerPlugin - PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error + PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Endpoint) error } // AdmissionPlugin is called by the director after the prepare data phase and before scheduling. @@ -73,5 +73,5 @@ type AdmissionPlugin interface { plugins.Plugin // AdmitRequest returns the denial reason, wrapped as error if the request is denied. // If the request is allowed, it returns nil. - AdmitRequest(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error + AdmitRequest(ctx context.Context, request *types.LLMRequest, pods []types.Endpoint) error } diff --git a/pkg/epp/requestcontrol/plugins/test/responsereceived/destination_endpoint_served_verifier.go b/pkg/epp/requestcontrol/plugins/test/responsereceived/destination_endpoint_served_verifier.go index 216746ab1..50e07ab5f 100644 --- a/pkg/epp/requestcontrol/plugins/test/responsereceived/destination_endpoint_served_verifier.go +++ b/pkg/epp/requestcontrol/plugins/test/responsereceived/destination_endpoint_served_verifier.go @@ -22,7 +22,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" @@ -71,7 +71,7 @@ func NewDestinationEndpointServedVerifier() *DestinationEndpointServedVerifier { } // ResponseReceived is the handler for the ResponseReceived extension point.
-func (p *DestinationEndpointServedVerifier) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, _ *backend.Pod) { +func (p *DestinationEndpointServedVerifier) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, _ *datalayer.EndpointMetadata) { logger := log.FromContext(ctx).WithName(p.TypedName().String()) logger.V(logging.DEBUG).Info("Verifying destination endpoint served") diff --git a/pkg/epp/saturationdetector/saturationdetector_test.go b/pkg/epp/saturationdetector/saturationdetector_test.go index 833d0b245..4696b4247 100644 --- a/pkg/epp/saturationdetector/saturationdetector_test.go +++ b/pkg/epp/saturationdetector/saturationdetector_test.go @@ -24,13 +24,13 @@ import ( "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) func newMockPodMetrics(name string, metrics *backendmetrics.MetricsState) *backendmetrics.FakePodMetrics { return &backendmetrics.FakePodMetrics{ - Pod: &backend.Pod{ + Metadata: &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Name: name, Namespace: "ns1"}, }, Metrics: metrics, diff --git a/pkg/epp/scheduling/framework/plugins.go b/pkg/epp/scheduling/framework/plugins.go index 99397a4b3..b6f7aca3b 100644 --- a/pkg/epp/scheduling/framework/plugins.go +++ b/pkg/epp/scheduling/framework/plugins.go @@ -51,7 +51,7 @@ type ProfileHandler interface { // Filter defines the interface for filtering a list of pods based on context. type Filter interface { plugins.Plugin - Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod + Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Endpoint) []types.Endpoint } // Scorer defines the interface for scoring a list of pods based on context. @@ -60,11 +60,11 @@ type Filter interface { // If a scorer returns value lower than 0, it will be treated as score 0. type Scorer interface { plugins.Plugin - Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 + Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Endpoint) map[types.Endpoint]float64 } // Picker picks the final pod(s) to send the request to. type Picker interface { plugins.Plugin - Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult + Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredEndpoint) *types.ProfileRunResult } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index fd03e4bbb..65da7b2fb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -215,26 +215,26 @@ func (p *Plugin) Consumes() map[string]any { return map[string]any{} } -// PrepareRequestData hashes prompt, finds longest prefix match and stores it in pod as attribute. 
-func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error { - hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch) +// PrepareRequestData hashes the prompt, finds the longest prefix match, and stores it in the endpoint as an attribute. +func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, endpoints []types.Endpoint) error { + hashes := hashPrompt(ctx, request, getBlockSize(endpoints, p.config), p.config.MaxPrefixBlocksToMatch) state := &SchedulingContextState{ PrefixHashes: hashes, PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), } total := len(state.PrefixHashes) - for _, pod := range pods { - matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] - pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total)) + for _, endpoint := range endpoints { + matchLen := state.PrefixCacheServers[ServerID(endpoint.GetMetadata().NamespacedName)] + endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total)) } return nil } // Score returns the scoring result for the given list of pods based on context. -func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { // pre score step, hashing prompt and find longest prefix match. - hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch) + hashes := hashPrompt(ctx, request, getBlockSize(endpoints, p.config), p.config.MaxPrefixBlocksToMatch) state := &SchedulingContextState{ PrefixHashes: hashes, PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), @@ -244,18 +244,18 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques p.pluginState.Write(request.RequestId, plugins.StateKey(p.TypedName().String()), state) log.FromContext(ctx).V(logutil.TRACE).Info("prefix cached state", "cached-servers", state.PrefixCacheServers, "hashes", state.PrefixHashes) // calculate the scores of pods - scores := make(map[types.Pod]float64, len(pods)) + scores := make(map[types.Endpoint]float64, len(endpoints)) total := len(state.PrefixHashes) - podScoreFunc := func(pod types.Pod) float64 { + podScoreFunc := func(endpoint types.Endpoint) float64 { if total == 0 { return 0 } - matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] + matchLen := state.PrefixCacheServers[ServerID(endpoint.GetMetadata().NamespacedName)] return float64(matchLen) / float64(total) } - for _, pod := range pods { + for _, pod := range endpoints { scores[pod] = podScoreFunc(pod) } return scores @@ -264,11 +264,11 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques // PreRequest records in the plugin cache the result of the scheduling selection.
func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] - targetPod := primaryProfileResult.TargetPods[0] // get the first pod of the primary profile + targetEndpoint := primaryProfileResult.TargetEndpoints[0] // get the first endpoint of the primary profile gpuBlocks := p.config.LRUCapacityPerServer - if p.config.AutoTune && targetPod.GetMetrics().CacheNumGPUBlocks > 0 { - gpuBlocks = targetPod.GetMetrics().CacheNumGPUBlocks + if p.config.AutoTune && targetEndpoint.GetMetrics().CacheNumGPUBlocks > 0 { + gpuBlocks = targetEndpoint.GetMetrics().CacheNumGPUBlocks } state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) @@ -285,16 +285,16 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche p.wg.Add(1) go func() { p.indexer.Add(state.PrefixHashes, Server{ - ServerID(targetPod.GetPod().NamespacedName), + ServerID(targetEndpoint.GetMetadata().NamespacedName), gpuBlocks, }) p.wg.Done() }() total := len(state.PrefixHashes) - matchLen := state.PrefixCacheServers[ServerID(targetPod.GetPod().NamespacedName)] + matchLen := state.PrefixCacheServers[ServerID(targetEndpoint.GetMetadata().NamespacedName)] - blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config) + blockSize := getBlockSize(primaryProfileResult.TargetEndpoints, p.config) metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) } @@ -408,20 +408,20 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { return json.Marshal(request.Body.ChatCompletions.Messages) } -func getBlockSize(pods []types.Pod, config Config) int { +func getBlockSize(endpoints []types.Endpoint, config Config) int { if !config.AutoTune { return config.BlockSize } // Fallback to BlockSize if no metrics are available. - if len(pods) == 0 { + if len(endpoints) == 0 { return config.BlockSize } - // Since all PODs originate from the same inference pool, they are considered to have identical configurations. - // Therefore, using the CacheBlockSize value from the first POD suffices. - if pod := pods[0]; pod.GetMetrics() != nil { - cacheBlockSize := pod.GetMetrics().CacheBlockSize + // Since all Endpoints originate from the same inference pool, they are considered to have identical configurations. + // Therefore, using the CacheBlockSize value from the first Endpoint suffices. 
+ if endpoint := endpoints[0]; endpoint.GetMetrics() != nil { + cacheBlockSize := endpoint.GetMetrics().CacheBlockSize if cacheBlockSize > 0 { return cacheBlockSize * averageCharactersPerToken } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 31cb51edb..5b8af56fb 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -27,7 +27,6 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" dplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" @@ -47,9 +46,9 @@ func TestPrefixPluginCompletion(t *testing.T) { } plugin := New(context.Background(), config) - pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()} - pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()} - pods := []types.Pod{pod1, pod2} + endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()} + endpoint2 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()} + endpoints := []types.Endpoint{endpoint1, endpoint2} // First request. req1 := &types.LLMRequest{ @@ -61,7 +60,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, }, } - scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req1, endpoints) state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) @@ -69,14 +68,14 @@ func TestPrefixPluginCompletion(t *testing.T) { // Total hashes = 1 (the first one is for the prefix with model) assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") - assert.Equal(t, float64(0), scores[pod1], "score for pod1") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, float64(0), scores[endpoint1], "score for endpoint1") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") // Simulate pod1 was picked. 
schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req1, schedulingResult) @@ -93,7 +92,7 @@ func TestPrefixPluginCompletion(t *testing.T) { }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req2, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) @@ -101,14 +100,14 @@ func TestPrefixPluginCompletion(t *testing.T) { // Total hashes = 1 (the first one is for the prefix with model) assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") - assert.Equal(t, float64(0), scores[pod1], "score for pod1") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, float64(0), scores[endpoint1], "score for endpoint1") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") // Simulate pod2 was picked. schedulingResult = &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod2}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint2}}, }, } plugin.PreRequest(context.Background(), req2, schedulingResult) @@ -124,21 +123,21 @@ func TestPrefixPluginCompletion(t *testing.T) { }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req3, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. 
// Total hashes = 2 (the first one is for the prefix with model) assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") - assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, 1, len(state.PrefixCacheServers), "endpoint1 should have cached the aaaa prefix") + assert.Equal(t, 0.5, scores[endpoint1], "score should be 0.5 - the model and the first prefix block match") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") schedulingResult = &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req3, schedulingResult) @@ -154,21 +153,21 @@ }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req4, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 8, hash block size is 4, so 2 hashes will be calculated. // Total hashes = 2 (the first one is for the prefix with model) assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") - assert.Equal(t, float64(0), scores[pod1], "score for pod1") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") + assert.Equal(t, float64(0), scores[endpoint1], "score for endpoint1") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") schedulingResult = &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req4, schedulingResult) @@ -184,21 +183,21 @@ }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req5, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Input size is 12, hash block size is 4, so 3 hashes will be calculated.
// Total hashes = 3 (the first one is for the prefix with model) assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") - assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") - assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, 1, len(state.PrefixCacheServers), "endpoint1 should have cached the aaaa prefix") + assert.Equal(t, 2./3, scores[endpoint1], "score should be 2./3 - the model and the first 2 prefix blocks match") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") schedulingResult = &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req5, schedulingResult) @@ -213,8 +212,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) { } plugin := New(context.Background(), config) - pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}} - pods := []types.Pod{pod1} + endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}} + endpoints := []types.Endpoint{endpoint1} // Test with chat completions request req1 := &types.LLMRequest{ @@ -229,14 +228,14 @@ func TestPrefixPluginChatCompletions(t *testing.T) { }, }, } - scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req1, endpoints) state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Chat completions - Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) // Should have some hashes for the JSON-encoded messages assert.Greater(t, len(state.PrefixHashes), 1, "should have hashes for chat completions") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially") - assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[endpoint1], "score for endpoint1") } func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { @@ -247,9 +246,9 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { } plugin := New(context.Background(), config) - pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}} - pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{}} - pods := []types.Pod{pod1, pod2} + endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}} + endpoint2 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{}} + endpoints := []types.Endpoint{endpoint1, endpoint2} // First request with initial conversation req1 := &types.LLMRequest{ @@ -264,21 +263,21 @@ func 
TestPrefixPluginChatCompletionsGrowth(t *testing.T) { }, }, } - scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req1, endpoints) state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Initial conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers) initialHashCount := len(state.PrefixHashes) assert.Greater(t, initialHashCount, 1, "should have hashes for chat completions") assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially") - assert.Equal(t, float64(0), scores[pod1], "score for pod1") - assert.Equal(t, float64(0), scores[pod2], "score for pod2") + assert.Equal(t, float64(0), scores[endpoint1], "score for endpoint1") + assert.Equal(t, float64(0), scores[endpoint2], "score for endpoint2") // Simulate pod1 was picked schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req1, schedulingResult) @@ -299,7 +298,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req2, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Extended conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers) @@ -308,10 +307,10 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match") // Calculate expected score - pod1 should have cached the initial prefix - cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)] + cachedBlocks := state.PrefixCacheServers[ServerID(endpoint1.GetMetadata().NamespacedName)] expectedScore := float64(cachedBlocks) / float64(extendedHashCount) - assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit") - assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit") + assert.Equal(t, expectedScore, scores[endpoint1], "endpoint1 should have prefix cache hit") + assert.Equal(t, float64(0), scores[endpoint2], "endpoint2 should have no cache hit") // Simulate pod1 was picked again plugin.PreRequest(context.Background(), req2, schedulingResult) @@ -334,7 +333,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { }, }, } - scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods) + scores = plugin.Score(context.Background(), types.NewCycleState(), req3, endpoints) state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String())) assert.NoError(t, err) t.Logf("Long conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers) @@ -343,11 +342,11 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match") // pod1 
should have an even higher cache hit rate now - cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)] + cachedBlocks = state.PrefixCacheServers[ServerID(endpoint1.GetMetadata().NamespacedName)] expectedScore = float64(cachedBlocks) / float64(longHashCount) - assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit") - assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation") - assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit") + assert.Equal(t, expectedScore, scores[endpoint1], "endpoint1 should have higher prefix cache hit") + assert.Greater(t, scores[endpoint1], float64(0.5), "cache hit rate should be substantial for growing conversation") + assert.Equal(t, float64(0), scores[endpoint2], "endpoint2 should still have no cache hit") } // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. @@ -373,15 +372,15 @@ func BenchmarkPrefixPluginStress(b *testing.B) { b.Run(fmt.Sprintf("messages_%d_length_%d", i, v), func(b *testing.B) { // Generate increasing-length random prompts prompt := randomPrompt(4 + v) - pod := &types.PodMetrics{ - Pod: &backend.Pod{ + endpoint := &types.PodMetrics{ + EndpointMetadata: &datalayer.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{ Name: fmt.Sprintf("random-pod-%d", v), }, }, } - pods := []types.Pod{pod} + endpoints := []types.Endpoint{endpoint} req := &types.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "model-stress", @@ -394,7 +393,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { b.ResetTimer() // Benchmark the scoring operation - scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req, endpoints) _ = scores // Use the result to prevent optimization // Clean up state for next iteration @@ -475,14 +474,14 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) { func TestPrefixPluginAutoTune(t *testing.T) { // Setup common test data podName := "pod-autotune" - pod := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: podName}}, + endpoint := &types.PodMetrics{ + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: podName}}, MetricsState: &backendmetrics.MetricsState{ CacheBlockSize: 16, // 16 tokens * 4 chars/token = 64 chars per block CacheNumGPUBlocks: 1000, // 1000 blocks capacity }, } - pods := []types.Pod{pod} + endpoints := []types.Endpoint{endpoint} req := &types.LLMRequest{ RequestId: uuid.NewString(), @@ -508,7 +507,7 @@ func TestPrefixPluginAutoTune(t *testing.T) { plugin := New(context.Background(), config) // 1. 
Verify Score uses pod metrics for block size - scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req, endpoints) _ = scores state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String())) @@ -522,14 +521,14 @@ func TestPrefixPluginAutoTune(t *testing.T) { schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint}}, }, } plugin.PreRequest(context.Background(), req, schedulingResult) plugin.wg.Wait() // Check indexer state - assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName)) + assert.Contains(t, plugin.indexer.Pods(), ServerID(endpoint.GetMetadata().NamespacedName)) }) t.Run("AutoTune Disabled", func(t *testing.T) { @@ -543,7 +542,7 @@ func TestPrefixPluginAutoTune(t *testing.T) { // 1. Verify Score uses config BlockSize req.RequestId = uuid.NewString() // New request ID - scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req, endpoints) _ = scores state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String())) @@ -557,13 +556,13 @@ func TestPrefixPluginAutoTune(t *testing.T) { schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint}}, }, } plugin.PreRequest(context.Background(), req, schedulingResult) plugin.wg.Wait() - assert.Contains(t, plugin.indexer.Pods(), ServerID(pod.GetPod().NamespacedName)) + assert.Contains(t, plugin.indexer.Pods(), ServerID(endpoint.GetMetadata().NamespacedName)) }) } @@ -585,9 +584,9 @@ func TestPrepareRequestData(t *testing.T) { } plugin := New(context.Background(), config) - pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} - pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} - pods := []types.Pod{pod1, pod2} + endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} + endpoint2 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()} + endpoints := []types.Endpoint{endpoint1, endpoint2} // First request to populate cache. 
req1 := &types.LLMRequest{ @@ -599,11 +598,11 @@ func TestPrepareRequestData(t *testing.T) { }, }, } - _ = plugin.Score(context.Background(), types.NewCycleState(), req1, pods) + _ = plugin.Score(context.Background(), types.NewCycleState(), req1, endpoints) schedulingResult := &types.SchedulingResult{ PrimaryProfileName: "default", ProfileResults: map[string]*types.ProfileRunResult{ - "default": {TargetPods: []types.Pod{pod1}}, + "default": {TargetEndpoints: []types.Endpoint{endpoint1}}, }, } plugin.PreRequest(context.Background(), req1, schedulingResult) @@ -620,18 +619,18 @@ func TestPrepareRequestData(t *testing.T) { }, } - err := plugin.PrepareRequestData(context.Background(), req2, pods) + err := plugin.PrepareRequestData(context.Background(), req2, endpoints) assert.NoError(t, err) // Verify pod1 has the correct prefix match info - info1, ok := pod1.Get(dplugins.PrefixCacheMatchInfoKey) + info1, ok := endpoint1.Get(dplugins.PrefixCacheMatchInfoKey) assert.True(t, ok) prefixInfo1 := info1.(*dplugins.PrefixCacheMatchInfo) assert.Equal(t, 1, prefixInfo1.MatchLength()) // "aaaa" matches assert.Equal(t, 2, prefixInfo1.TotalLength()) // "aaaacccc" -> 2 blocks // Verify pod2 has no match info - info2, ok := pod2.Get(dplugins.PrefixCacheMatchInfoKey) + info2, ok := endpoint2.Get(dplugins.PrefixCacheMatchInfoKey) assert.True(t, ok) prefixInfo2 := info2.(*dplugins.PrefixCacheMatchInfo) assert.Equal(t, 0, prefixInfo2.MatchLength()) // No match for pod2 @@ -678,14 +677,14 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) { messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}} } - pod := &types.PodMetrics{ - Pod: &backend.Pod{ + endpoint := &types.PodMetrics{ + EndpointMetadata: &datalayer.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{ Name: fmt.Sprintf("chat-pod-%d-%d", scenario.messageCount, scenario.messageLength), }, }, } - pods := []types.Pod{pod} + endpoints := []types.Endpoint{endpoint} req := &types.LLMRequest{ RequestId: uuid.NewString(), @@ -700,7 +699,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // Benchmark the scoring operation - scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods) + scores := plugin.Score(context.Background(), types.NewCycleState(), req, endpoints) _ = scores // Use the result to prevent optimization // Clean up state for next iteration diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go index c28ec6f47..dd24a323b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/helpers.go @@ -26,7 +26,7 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []podPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Pod { +func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds []endpointPredictionResult, r *rand.Rand, strategy headroomStrategy) schedulingtypes.Endpoint { total := 0 choices := s.buildCompositeChoices( ctx, allPreds, s.config.CompositeKVWeight, s.config.CompositeQueueWeight, s.config.CompositePrefixWeight, &total, @@ -37,10 +37,10 @@ func (s *SLOAwareRouter) selectFromCompositeScores(ctx context.Context, allPreds choices[i].weight = minWeight + wMax - 
choices[i].weight } } - selectedPod := s.performWeightedRandomSelection(choices, total, allPreds, r) - return selectedPod + selectedEndpoint := s.performWeightedRandomSelection(choices, total, allPreds, r) + return selectedEndpoint } -func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []podPredictionResult, r *rand.Rand) schedulingtypes.Pod { +func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice, total int, candidates []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint { if total == 0 { return nil } @@ -50,43 +50,43 @@ func (s *SLOAwareRouter) performWeightedRandomSelection(weightedChoices []choice logger.V(logutil.DEBUG).Info("Pod selection mode: MAX - selecting pod with highest weight") maxWeight := 0 - var selectedPod schedulingtypes.Pod + var selectedEndpoint schedulingtypes.Endpoint for _, c := range weightedChoices { if c.weight > maxWeight { maxWeight = c.weight - selectedPod = c.podName + selectedEndpoint = c.endpointName } } - if selectedPod != nil { - return selectedPod + if selectedEndpoint != nil { + return selectedEndpoint } // Fallback to first pod if no selection made - return candidates[0].Pod + return candidates[0].Endpoint } // Original weighted random selection logic logger.V(logutil.DEBUG).Info("Pod selection mode: LINEAR - performing weighted random selection") idx := r.Intn(total) - var selectedPod schedulingtypes.Pod + var selectedEndpoint schedulingtypes.Endpoint for _, c := range weightedChoices { if idx < c.weight { - selectedPod = c.podName + selectedEndpoint = c.endpointName break } idx -= c.weight } - // If no pod was selected (shouldn't happen), fallback to first pod - if selectedPod == nil { - selectedPod = candidates[0].Pod + // If no endpoint was selected (shouldn't happen), fallback to first endpoint + if selectedEndpoint == nil { + selectedEndpoint = candidates[0].Endpoint } - return selectedPod + return selectedEndpoint } func (s *SLOAwareRouter) buildCompositeChoices( ctx context.Context, - candidates []podPredictionResult, + candidates []endpointPredictionResult, wkv, wq, wpref float64, total *int, ) []choice { @@ -104,9 +104,9 @@ func (s *SLOAwareRouter) buildCompositeChoices( // Precompute queue stats minQ, maxQ := math.MaxInt32, -1 queueCounts := make(map[string]int, len(candidates)) - for _, p := range candidates { - q := p.Pod.GetMetrics().WaitingQueueSize - queueCounts[p.Pod.GetPod().String()] = q + for _, e := range candidates { + q := e.Endpoint.GetMetrics().WaitingQueueSize + queueCounts[e.Endpoint.GetMetadata().String()] = q if q < minQ { minQ = q } @@ -118,23 +118,23 @@ func (s *SLOAwareRouter) buildCompositeChoices( choices := make([]choice, 0, len(candidates)) for _, p := range candidates { - q := queueCounts[p.Pod.GetPod().String()] + q := queueCounts[p.Endpoint.GetMetadata().String()] relQueue := 1.0 if den > 0 { relQueue = (float64(maxQ-q) / den) } - kvUsage := p.Pod.GetMetrics().KVCacheUsagePercent + kvUsage := p.Endpoint.GetMetrics().KVCacheUsagePercent kvFree := (1.0 - kvUsage) prefix := (p.PrefixCacheScore) composite := wkv*kvFree + wq*relQueue + wpref*prefix w := int(math.Round(float64(minWeight) + (float64(wMax-minWeight) * composite))) *total += w - choices = append(choices, choice{podName: p.Pod, weight: w}) + choices = append(choices, choice{endpointName: p.Endpoint, weight: w}) log.FromContext(ctx).V(logutil.TRACE).Info("Composite (neg/pos) score", - "pod", p.Pod.GetPod().String(), + "endpoint", p.Endpoint.GetMetadata().String(), 
"kvUsage", kvUsage, "kvFree", kvFree, "queue", q, "relQueue", relQueue, "prefix", prefix, diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go index 03d2ce59a..c70a19189 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go @@ -36,10 +36,10 @@ import ( // refreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result. func refreshLastSeenMetrics(ctx context.Context, sloCtx *sloRequestContext) { if sr := sloCtx.schedulingResult; sr != nil { - if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil { + if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetEndpoints != nil { for profileName, profileResult := range sr.ProfileResults { - if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 { - sloCtx.lastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone() + if profileResult != nil && profileResult.TargetEndpoints != nil && len(profileResult.TargetEndpoints) > 0 { + sloCtx.lastSeenMetrics[profileName] = profileResult.TargetEndpoints[0].GetMetrics().Clone() } } } @@ -80,8 +80,8 @@ func processHeaderForLatencyPrediction( return err } - targetPod := sloCtx.targetPod - prefix_cache_score := sloCtx.prefixCacheScoresForPods[targetPod.String()] + targetPod := sloCtx.targetMetadata + prefix_cache_score := sloCtx.prefixCacheScoresForEndpoints[targetPod.String()] in := latencypredictor.PredictionRequest{ KVCachePercentage: m.KVCacheUsagePercent, @@ -135,8 +135,8 @@ func processFirstTokenForLatencyPrediction( logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err) return } - targetPod := sloCtx.targetPod - prefixCacheScore := sloCtx.prefixCacheScoresForPods[targetPod.String()] + targetPod := sloCtx.targetMetadata + prefixCacheScore := sloCtx.prefixCacheScoresForEndpoints[targetPod.String()] recordTTFTTrainingData(ctx, predictor, sloCtx, m, now, prefixCacheScore) diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go index f31932538..c3096edd2 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/prediction.go @@ -27,8 +27,8 @@ import ( latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync" ) -type podPredictionResult struct { - Pod schedulingtypes.Pod +type endpointPredictionResult struct { + Endpoint schedulingtypes.Endpoint TTFT float64 TPOT float64 TTFTValid bool @@ -41,26 +41,26 @@ type podPredictionResult struct { } // generatePredictions creates prediction results for all candidate pods -func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidatePods []schedulingtypes.Pod) ([]podPredictionResult, error) { +func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, sloCtx *sloRequestContext, candidateEndpoints []schedulingtypes.Endpoint) ([]endpointPredictionResult, error) { logger := log.FromContext(ctx) - predictions := 
make([]podPredictionResult, 0, len(candidatePods)) + predictions := make([]endpointPredictionResult, 0, len(candidateEndpoints)) // Prepare inputs for bulk prediction - metricsStates := make([]*backendmetrics.MetricsState, len(candidatePods)) - prompts := make([]string, len(candidatePods)) - generatedTokenCounts := make([]int, len(candidatePods)) - prefixCacheScores := make([]float64, len(candidatePods)) + metricsStates := make([]*backendmetrics.MetricsState, len(candidateEndpoints)) + prompts := make([]string, len(candidateEndpoints)) + generatedTokenCounts := make([]int, len(candidateEndpoints)) + prefixCacheScores := make([]float64, len(candidateEndpoints)) - for i, pod := range candidatePods { - logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "pod", pod.GetPod().String(), "metrics", pod.GetMetrics().String()) + for i, endpoint := range candidateEndpoints { + logger.V(logutil.TRACE).Info("Candidate pod for scheduling", "endpoint", endpoint.GetMetadata().String(), "metrics", endpoint.GetMetrics().String()) // Get prefix cache score for the pod - prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod) - sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore + prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, endpoint) + sloCtx.prefixCacheScoresForEndpoints[endpoint.GetMetadata().String()] = prefixCacheScore - logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "prefixCacheScore", prefixCacheScore) + logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", endpoint.GetMetadata().String(), "prefixCacheScore", prefixCacheScore) - metricsStates[i] = pod.GetMetrics() + metricsStates[i] = endpoint.GetMetrics() prompts[i] = request.Body.Completions.Prompt generatedTokenCounts[i] = 1 prefixCacheScores[i] = prefixCacheScore @@ -74,19 +74,19 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul } // Process results - for i, pod := range candidatePods { + for i, endpoint := range candidateEndpoints { prediction := bulkPredictions[i] - predResult := podPredictionResult{Pod: pod} + predResult := endpointPredictionResult{Endpoint: endpoint} predResult.PrefixCacheScore = prefixCacheScores[i] predResult.TTFT = prediction.TTFT predResult.TPOT = prediction.TPOT - podMinTPOTSLO := s.getPodMinTPOTSLO(pod) + podMinTPOTSLO := s.getEndpointMinTPOTSLO(endpoint) predResult.TTFTValid, predResult.TPOTValid, predResult.IsValid, predResult.Headroom, predResult.TTFTHeadroom = s.validatePrediction(prediction, sloCtx, podMinTPOTSLO) logger.V(logutil.DEBUG).Info("Prediction for scheduling", - "pod", pod.GetPod().String(), + "endpoint", endpoint.GetMetadata().String(), "prefixCacheScore", predResult.PrefixCacheScore, "TTFT", prediction.TTFT, "TPOT", prediction.TPOT, @@ -107,18 +107,18 @@ func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedul } // updateRequestContextWithPredictions updates the request context with prediction data -func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []podPredictionResult) { +func (s *SLOAwareRouter) updateRequestContextWithPredictions(sloCtx *sloRequestContext, predictions []endpointPredictionResult) { for _, pred := range predictions { if pred.Error == nil { - podKey := pred.Pod.GetPod().String() + endpointKey := pred.Endpoint.GetMetadata().String() if sloCtx.predictedTTFTForScheduling == nil { sloCtx.predictedTTFTForScheduling = make(map[string]float64) } if 
sloCtx.predictedTPOTForScheduling == nil { sloCtx.predictedTPOTForScheduling = make(map[string]float64) } - sloCtx.predictedTTFTForScheduling[podKey] = pred.TTFT - sloCtx.predictedTPOTForScheduling[podKey] = pred.TPOT + sloCtx.predictedTTFTForScheduling[endpointKey] = pred.TTFT + sloCtx.predictedTPOTForScheduling[endpointKey] = pred.TPOT } } } diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go index d91aac2cf..60ce42899 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go +++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go @@ -26,8 +26,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -42,7 +42,7 @@ var _ requestcontrol.ResponseComplete = &SLOAwareRouter{} type sloRequestContext struct { schedulingRequest schedulingtypes.LLMRequest - targetPod *backend.Pod + targetMetadata *datalayer.EndpointMetadata schedulingResult *schedulingtypes.SchedulingResult lastSeenMetrics map[string]*backendmetrics.MetricsState lastTokenTimestamp time.Time @@ -57,7 +57,7 @@ type sloRequestContext struct { tpotObservations []float64 predictedTPOTObservations []float64 - prefixCacheScoresForPods map[string]float64 + prefixCacheScoresForEndpoints map[string]float64 // ttftSLO is the target time to first token SLO for the request. ttftSLO float64 @@ -66,22 +66,22 @@ type sloRequestContext struct { // predictorBasedScheduling indicates whether to use predictor based scheduling. predictorBasedScheduling bool - // predictedTTFTForScheduling is the map of pod names to predicted TTFT values for scheduling. + // predictedTTFTForScheduling is the map of endpoint names to predicted TTFT values for scheduling. predictedTTFTForScheduling map[string]float64 - // predictedTPOTForScheduling is the map of pod names to predicted TPOT values for scheduling. + // predictedTPOTForScheduling is the map of endpoint names to predicted TPOT values for scheduling. 
predictedTPOTForScheduling map[string]float64 - // boolean set if request has valid pod based on predictions - hasValidPod bool + // boolean set if request has valid endpoint based on predictions + hasValidEndpoint bool } func newSLORequestContext(request *schedulingtypes.LLMRequest) *sloRequestContext { return &sloRequestContext{ - schedulingRequest: *request, - lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), - prefixCacheScoresForPods: make(map[string]float64), - predictedTTFTForScheduling: make(map[string]float64), - predictedTPOTForScheduling: make(map[string]float64), + schedulingRequest: *request, + lastSeenMetrics: make(map[string]*backendmetrics.MetricsState), + prefixCacheScoresForEndpoints: make(map[string]float64), + predictedTTFTForScheduling: make(map[string]float64), + predictedTPOTForScheduling: make(map[string]float64), } } @@ -117,26 +117,26 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype return } - targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod() - if !t.checkPredictor(logger, targetPod) { + targetMetadata := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetEndpoints[0].GetMetadata() + if !t.checkPredictor(logger, targetMetadata) { return } - podName := types.NamespacedName{ - Name: targetPod.NamespacedName.Name, - Namespace: targetPod.NamespacedName.Namespace, + endpointName := types.NamespacedName{ + Name: targetMetadata.NamespacedName.Name, + Namespace: targetMetadata.NamespacedName.Namespace, } - logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName) + logger.V(logutil.TRACE).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "endpointName", endpointName) if request.Headers[requtil.RequestIdHeaderKey] == "" { logger.V(logutil.DEBUG).Error(errors.New("missing request ID"), "SLOAwareRouter.PreRequest: Request is missing request ID header") } id := request.Headers[requtil.RequestIdHeaderKey] - podRequestList, ok := t.runningRequestLists[podName] + endpointRequestList, ok := t.runningRequestLists[endpointName] if !ok { - podRequestList = newRequestPriorityQueue() - t.runningRequestLists[podName] = podRequestList + endpointRequestList = newRequestPriorityQueue() + t.runningRequestLists[endpointName] = endpointRequestList } sloCtx, err := t.getSLOContextForRequest(request) @@ -146,26 +146,26 @@ func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtype return } - added := podRequestList.Add(id, sloCtx.avgTPOTSLO) + added := endpointRequestList.Add(id, sloCtx.avgTPOTSLO) if !added { - logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id) + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item already exists in queue", "endpointName", endpointName, "requestID", id) } // Set up SLO request context - sloCtx.targetPod = targetPod + sloCtx.targetMetadata = targetMetadata sloCtx.schedulingResult = schedulingResult sloCtx.requestReceivedTimestamp = time.Now() refreshLastSeenMetrics(ctx, sloCtx) t.setSLOContextForRequest(request, sloCtx) } -func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { +func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetMetadata 
*datalayer.EndpointMetadata) { logger := log.FromContext(ctx) if request == nil { logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseReceived: request is nil, skipping") return } - if !t.checkPredictor(logger, targetPod) { + if !t.checkPredictor(logger, targetMetadata) { return } @@ -183,13 +183,13 @@ func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *scheduli } -func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { +func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, metadata *datalayer.EndpointMetadata) { logger := log.FromContext(ctx) if request == nil { logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseStreaming: request is nil, skipping") return } - if !t.checkPredictor(logger, pod) || response.EndOfStream { + if !t.checkPredictor(logger, metadata) || response.EndOfStream { return } @@ -209,14 +209,14 @@ func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedul } -func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) { +func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, metadata *datalayer.EndpointMetadata) { logger := log.FromContext(ctx) if request == nil { logger.V(logutil.DEBUG).Info("SLOAwareRouter.ResponseComplete: request is nil, skipping") return } - targetPod := pod - if !t.checkPredictor(logger, targetPod) { + targetMetadata := metadata + if !t.checkPredictor(logger, targetMetadata) { return } @@ -247,28 +247,28 @@ func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *scheduli logger.V(logutil.TRACE).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", sloCtx.predictorBasedScheduling) - podName := types.NamespacedName{ - Name: targetPod.NamespacedName.Name, - Namespace: targetPod.NamespacedName.Namespace, + endpointName := types.NamespacedName{ + Name: targetMetadata.NamespacedName.Name, + Namespace: targetMetadata.NamespacedName.Namespace, } id := request.Headers[requtil.RequestIdHeaderKey] - podRequestList, ok := t.runningRequestLists[podName] + endpointRequestList, ok := t.runningRequestLists[endpointName] if !ok { - err := fmt.Errorf("no running request list found for pod %s", podName.String()) + err := fmt.Errorf("no running request list found for endpoint %s", endpointName.String()) logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id) } - _, removed := podRequestList.Remove(id) + _, removed := endpointRequestList.Remove(id) if !removed { - logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id) + logger.V(logutil.TRACE).Info("SLOAwareRouter: Item not found in queue", "endpointName", endpointName, "requestID", id) } t.deleteSLOContextForRequest(request) } -func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, targetPod *backend.Pod) bool { - if targetPod == nil { - logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target pod was provided.") +func (t *SLOAwareRouter) checkPredictor(logger logr.Logger, metadata *datalayer.EndpointMetadata) bool { + if metadata == nil { + logger.V(logutil.TRACE).Info("SLOAwareRouter: Skipping hook because no target metadata was provided.") return false } if 
t.latencypredictor == nil {
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go index ac7344e30..be4f76168 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks_test.go
@@ -28,8 +28,8 @@ import (
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
@@ -44,15 +44,15 @@ const (
)
// Helper functions
-func createTestSchedulingResult(pod *backend.Pod) *schedulingtypes.SchedulingResult {
+func createTestSchedulingResult(metadata *datalayer.EndpointMetadata) *schedulingtypes.SchedulingResult {
- mockPod := createTestPod(pod.NamespacedName.Name, kvUsage, runningRequests, waitingQueue)
+ mockEndpoint := createTestEndpoint(metadata.NamespacedName.Name, kvUsage, runningRequests, waitingQueue)
return &schedulingtypes.SchedulingResult{
PrimaryProfileName: "default",
ProfileResults: map[string]*schedulingtypes.ProfileRunResult{
"default": {
- TargetPods: []schedulingtypes.Pod{mockPod},
+ TargetEndpoints: []schedulingtypes.Endpoint{mockEndpoint},
},
},
}
@@ -76,11 +76,11 @@ func TestNewSLORequestContext(t *testing.T) {
assert.NotNil(t, ctx)
assert.Equal(t, *request, ctx.schedulingRequest)
assert.NotNil(t, ctx.lastSeenMetrics)
- assert.NotNil(t, ctx.prefixCacheScoresForPods)
+ assert.NotNil(t, ctx.prefixCacheScoresForEndpoints)
assert.NotNil(t, ctx.predictedTTFTForScheduling)
assert.NotNil(t, ctx.predictedTPOTForScheduling)
assert.Empty(t, ctx.lastSeenMetrics)
- assert.Empty(t, ctx.prefixCacheScoresForPods)
+ assert.Empty(t, ctx.prefixCacheScoresForEndpoints)
}
func TestSLOAwareRouter_SetAndGetSLOContext(t *testing.T) {
@@ -161,9 +161,9 @@ func TestSLOAwareRouter_PreRequest_Success(t *testing.T) {
router.latencypredictor = mockPredictor
ctx := context.Background()
- pod := createTestPod("test-pod", 1, 1, 1)
+ endpoint := createTestEndpoint("test-pod", 1, 1, 1)
request := createTestLLMRequest("test", 100, 50, true)
- schedulingResult := createTestSchedulingResult(pod.GetPod())
+ schedulingResult := createTestSchedulingResult(endpoint.GetMetadata())
// Create and set initial SLO context
sloCtx := newSLORequestContext(request)
@@ -171,7 +171,7 @@
router.setSLOContextForRequest(request, sloCtx)
// Initialize the request priority queue
- router.runningRequestLists[pod.GetPod().NamespacedName] = newRequestPriorityQueue()
+ router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = newRequestPriorityQueue()
beforeTime := time.Now()
router.PreRequest(ctx, request, schedulingResult)
@@ -180,7 +180,7 @@
// Verify SLO context was updated
retrievedCtx, err := router.getSLOContextForRequest(request)
require.NoError(t, err)
- assert.Equal(t, pod.GetPod(), retrievedCtx.targetPod)
+ assert.Equal(t, endpoint.GetMetadata(), retrievedCtx.targetMetadata)
assert.Equal(t,
schedulingResult, retrievedCtx.schedulingResult) assert.True(t, retrievedCtx.requestReceivedTimestamp.After(beforeTime) || retrievedCtx.requestReceivedTimestamp.Equal(beforeTime)) @@ -194,9 +194,9 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod()) + schedulingResult := createTestSchedulingResult(endpoint.GetMetadata()) // Create and set initial SLO context sloCtx := newSLORequestContext(request) @@ -207,8 +207,8 @@ func TestSLOAwareRouter_PreRequest_AddsToQueue(t *testing.T) { router.PreRequest(ctx, request, schedulingResult) // Verify queue was created and request was added - queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] - assert.True(t, exists, "Queue should be created for pod") + queue, exists := router.runningRequestLists[endpoint.GetMetadata().NamespacedName] + assert.True(t, exists, "Queue should be created for endpoint") assert.NotNil(t, queue) } @@ -218,10 +218,10 @@ func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request1 := createTestLLMRequest("test-id-1", 100, 50, true) request2 := createTestLLMRequest("test-id-2", 100, 50, true) - schedulingResult := createTestSchedulingResult(pod.GetPod()) + schedulingResult := createTestSchedulingResult(endpoint.GetMetadata()) // Create and set initial SLO contexts sloCtx1 := newSLORequestContext(request1) @@ -239,7 +239,7 @@ func TestSLOAwareRouter_PreRequest_QueueAlreadyExists(t *testing.T) { router.PreRequest(ctx, request2, schedulingResult) // Verify both are in the same queue - queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName] + queue, exists := router.runningRequestLists[endpoint.GetMetadata().NamespacedName] assert.True(t, exists) assert.NotNil(t, queue) } @@ -249,7 +249,7 @@ func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { router.latencypredictor = nil ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} @@ -257,7 +257,7 @@ func TestSLOAwareRouter_ResponseReceived_NilPredictor(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Should not panic and should return early - router.ResponseReceived(ctx, request, response, pod.GetPod()) + router.ResponseReceived(ctx, request, response, endpoint.GetMetadata()) // Context should still exist _, err := router.getSLOContextForRequest(request) @@ -289,12 +289,12 @@ func TestSLOAwareRouter_ResponseReceived_NoContext(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} // Don't set SLO context - router.ResponseReceived(ctx, request, response, pod.GetPod()) + router.ResponseReceived(ctx, request, response, endpoint.GetMetadata()) // Should handle missing context gracefully @@ -305,7 +305,7 @@ func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t 
*testing.T) { router.latencypredictor = nil ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} @@ -313,7 +313,7 @@ func TestSLOAwareRouter_ResponseStreaming_NilPredictor(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Should not panic and should return early - router.ResponseStreaming(ctx, request, response, pod.GetPod()) + router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata()) // Context should still exist _, err := router.getSLOContextForRequest(request) @@ -325,10 +325,10 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - schedulingResult := createTestSchedulingResult(pod.GetPod()) + schedulingResult := createTestSchedulingResult(endpoint.GetMetadata()) sloCtx := newSLORequestContext(request) sloCtx.requestReceivedTimestamp = time.Now() @@ -355,10 +355,10 @@ func TestSLOAwareRouter_ResponseStreaming_FirstToken(t *testing.T) { // Initialize the queue and add the request queue := newRequestPriorityQueue() queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) - router.runningRequestLists[pod.GetPod().NamespacedName] = queue + router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = queue beforeTime := time.Now() - router.ResponseStreaming(ctx, request, response, pod.GetPod()) + router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata()) afterTime := time.Now() // Verify first token timestamp was set @@ -376,10 +376,10 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} - schedulingResult := createTestSchedulingResult(pod.GetPod()) + schedulingResult := createTestSchedulingResult(endpoint.GetMetadata()) sloCtx := newSLORequestContext(request) sloCtx.requestReceivedTimestamp = time.Now() @@ -408,9 +408,9 @@ func TestSLOAwareRouter_ResponseStreaming_SubsequentTokens(t *testing.T) { // Initialize the queue and add the request queue := newRequestPriorityQueue() queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) - router.runningRequestLists[pod.GetPod().NamespacedName] = queue + router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = queue - router.ResponseStreaming(ctx, request, response, pod.GetPod()) + router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata()) // Verify token timestamp was updated retrievedCtx, err := router.getSLOContextForRequest(request) @@ -424,20 +424,20 @@ func TestSLOAwareRouter_ResponseComplete_QueueNotFound(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + pod := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} sloCtx := newSLORequestContext(request) sloCtx.incomingModelName = testModelName - sloCtx.targetPod = pod.GetPod() // ADD THIS to avoid other issues + 
sloCtx.targetMetadata = pod.GetMetadata() // ADD THIS to avoid other issues router.setSLOContextForRequest(request, sloCtx) // Create an EMPTY queue (not nil, but empty) to test queue.Remove behavior - router.runningRequestLists[pod.GetPod().NamespacedName] = newRequestPriorityQueue() + router.runningRequestLists[pod.GetMetadata().NamespacedName] = newRequestPriorityQueue() // Should handle gracefully when request is not in queue - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, pod.GetMetadata()) // Context should be deleted _, err := router.getSLOContextForRequest(request) @@ -449,12 +449,12 @@ func TestSLOAwareRouter_ResponseStreaming_NoContext(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} // Don't set SLO context - should handle gracefully - router.ResponseStreaming(ctx, request, response, pod.GetPod()) + router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata()) // Should not panic @@ -466,13 +466,13 @@ func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} // Create queue and add request queue := newRequestPriorityQueue() - router.runningRequestLists[pod.GetPod().NamespacedName] = queue + router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) sloCtx := newSLORequestContext(request) @@ -485,7 +485,7 @@ func TestSLOAwareRouter_ResponseComplete_Success(t *testing.T) { sloCtx.incomingModelName = "incoming-model" router.setSLOContextForRequest(request, sloCtx) - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, endpoint.GetMetadata()) // Verify context was deleted _, err := router.getSLOContextForRequest(request) @@ -500,7 +500,7 @@ func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { router.latencypredictor = nil ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} @@ -508,7 +508,7 @@ func TestSLOAwareRouter_ResponseComplete_NilPredictor(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Should not panic - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, endpoint.GetMetadata()) // Context should still exist (deletion happens only with predictor) _, err := router.getSLOContextForRequest(request) @@ -541,12 +541,12 @@ func TestSLOAwareRouter_ResponseComplete_NoContext(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} // Don't set SLO context - should handle gracefully - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, endpoint.GetMetadata()) 
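The hook tests above all exercise the same per-endpoint registry that PreRequest maintains: `runningRequestLists` is keyed by the endpoint's `NamespacedName`, and a queue is created lazily on first use. A minimal sketch of that lazy-registry pattern follows; `requestQueue` and `registry` are illustrative stand-ins, not the plugin's real `newRequestPriorityQueue` types.

```go
package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/types"
)

// requestQueue is an illustrative stand-in for the plugin's request
// priority queue; it only tracks request IDs.
type requestQueue struct{ ids []string }

func (q *requestQueue) Add(id string) { q.ids = append(q.ids, id) }

// registry mirrors runningRequestLists: one queue per endpoint,
// keyed by the endpoint's NamespacedName.
type registry map[types.NamespacedName]*requestQueue

// queueFor returns the endpoint's queue, creating it lazily on first
// use, just as PreRequest does before adding the request ID.
func (r registry) queueFor(name types.NamespacedName) *requestQueue {
	q, ok := r[name]
	if !ok {
		q = &requestQueue{}
		r[name] = q
	}
	return q
}

func main() {
	reg := registry{}
	ep := types.NamespacedName{Namespace: "default", Name: "pod1"}
	reg.queueFor(ep).Add("req-1")
	reg.queueFor(ep).Add("req-2") // reuses the existing queue
	fmt.Println(len(reg.queueFor(ep).ids)) // prints 2
}
```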
// Should not panic @@ -558,13 +558,13 @@ func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test", 100, 50, true) response := &requestcontrol.Response{} // Create queue queue := newRequestPriorityQueue() - router.runningRequestLists[pod.GetPod().NamespacedName] = queue + router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 50.0) sloCtx := newSLORequestContext(request) @@ -578,7 +578,7 @@ func TestSLOAwareRouter_ResponseComplete_WithMetrics(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Should record metrics without panicking - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, endpoint.GetMetadata()) // Verify cleanup _, err := router.getSLOContextForRequest(request) @@ -591,13 +591,13 @@ func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { router.latencypredictor = mockPredictor ctx := context.Background() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) request := createTestLLMRequest("test-id", 0, 0, true) // No SLOs response := &requestcontrol.Response{} // Create queue queue := newRequestPriorityQueue() - router.runningRequestLists[pod.GetPod().NamespacedName] = queue + router.runningRequestLists[endpoint.GetMetadata().NamespacedName] = queue queue.Add(request.Headers[requtil.RequestIdHeaderKey], 0) sloCtx := newSLORequestContext(request) @@ -607,7 +607,7 @@ func TestSLOAwareRouter_ResponseComplete_NoSLOs(t *testing.T) { router.setSLOContextForRequest(request, sloCtx) // Should handle missing SLOs gracefully - router.ResponseComplete(ctx, request, response, pod.GetPod()) + router.ResponseComplete(ctx, request, response, endpoint.GetMetadata()) // Verify cleanup _, err := router.getSLOContextForRequest(request) @@ -627,9 +627,9 @@ func TestSLOAwareRouter_CheckPredictor_NilPredictor(t *testing.T) { router := createTestRouter() router.latencypredictor = nil logger := logr.Discard() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) - result := router.checkPredictor(logger, pod.GetPod()) + result := router.checkPredictor(logger, endpoint.GetMetadata()) assert.False(t, result) } @@ -639,9 +639,9 @@ func TestSLOAwareRouter_CheckPredictor_Success(t *testing.T) { mockPredictor := new(mockPredictor) router.latencypredictor = mockPredictor logger := logr.Discard() - pod := createTestPod("test-pod", 1, 1, 1) + endpoint := createTestEndpoint("test-pod", 1, 1, 1) - result := router.checkPredictor(logger, pod.GetPod()) + result := router.checkPredictor(logger, endpoint.GetMetadata()) assert.True(t, result) } @@ -652,7 +652,7 @@ func TestSLORequestContext_Fields(t *testing.T) { // Test all field initialization assert.NotNil(t, ctx.lastSeenMetrics) - assert.NotNil(t, ctx.prefixCacheScoresForPods) + assert.NotNil(t, ctx.prefixCacheScoresForEndpoints) assert.NotNil(t, ctx.predictedTTFTForScheduling) assert.NotNil(t, ctx.predictedTPOTForScheduling) assert.Empty(t, ctx.tpotObservations) @@ -660,7 +660,7 @@ func TestSLORequestContext_Fields(t *testing.T) { assert.Zero(t, ctx.generatedTokenCount) assert.Zero(t, ctx.ttft) assert.Zero(t, ctx.avgTPOT) - assert.Nil(t, ctx.targetPod) + assert.Nil(t, ctx.targetMetadata) assert.Nil(t, 
ctx.schedulingResult)
 	assert.Nil(t, ctx.tokenSampler)
 }
 
@@ -702,13 +702,13 @@ func TestSLORequestContext_PrefixCacheScores(t *testing.T) {
 	ctx := newSLORequestContext(request)
 
 	// Set prefix cache scores
-	ctx.prefixCacheScoresForPods["pod1"] = 0.8
-	ctx.prefixCacheScoresForPods["pod2"] = 0.6
-	ctx.prefixCacheScoresForPods["pod3"] = 0.9
+	ctx.prefixCacheScoresForEndpoints["pod1"] = 0.8
+	ctx.prefixCacheScoresForEndpoints["pod2"] = 0.6
+	ctx.prefixCacheScoresForEndpoints["pod3"] = 0.9
 
-	assert.Len(t, ctx.prefixCacheScoresForPods, 3)
-	assert.Equal(t, 0.8, ctx.prefixCacheScoresForPods["pod1"])
-	assert.Equal(t, 0.9, ctx.prefixCacheScoresForPods["pod3"])
+	assert.Len(t, ctx.prefixCacheScoresForEndpoints, 3)
+	assert.Equal(t, 0.8, ctx.prefixCacheScoresForEndpoints["pod1"])
+	assert.Equal(t, 0.9, ctx.prefixCacheScoresForEndpoints["pod3"])
 }
 
 func TestSLOAwareRouter_ConcurrentContextAccess(t *testing.T) {
@@ -749,13 +749,13 @@ func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) {
 	router.latencypredictor = mockPredictor
 
 	ctx := context.Background()
-	pod := createTestPod("test-pod", 1, 1, 1)
+	endpoint := createTestEndpoint("test-pod", 1, 1, 1)
 
 	request1 := createTestLLMRequest("test-id-1", 100, 50, true)
 	request2 := createTestLLMRequest("test-id-2", 100, 50, true)
 	request3 := createTestLLMRequest("test-id-3", 100, 50, true)
 
-	schedulingResult := createTestSchedulingResult(pod.GetPod())
+	schedulingResult := createTestSchedulingResult(endpoint.GetMetadata())
 
 	// Create and set SLO contexts
 	for _, req := range []*schedulingtypes.LLMRequest{request1, request2, request3} {
@@ -770,7 +770,7 @@ func TestSLOAwareRouter_MultipleRequests_SamePod(t *testing.T) {
 	router.PreRequest(ctx, request3, schedulingResult)
 
 	// Verify queue has all requests
-	queue, exists := router.runningRequestLists[pod.GetPod().NamespacedName]
+	queue, exists := router.runningRequestLists[endpoint.GetMetadata().NamespacedName]
 	assert.True(t, exists)
 	assert.NotNil(t, queue)
 }
@@ -781,10 +781,10 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) {
 	router.latencypredictor = mockPredictor
 
 	ctx := context.Background()
-	pod := createTestPod("test-pod", 1, 1, 1)
+	endpoint := createTestEndpoint("test-pod", 1, 1, 1)
 	request := createTestLLMRequest("test", 100, 50, true)
 	response := &requestcontrol.Response{}
-	schedulingResult := createTestSchedulingResult(pod.GetPod())
+	schedulingResult := createTestSchedulingResult(endpoint.GetMetadata())
 
 	// Create initial context
 	sloCtx := newSLORequestContext(request)
@@ -798,26 +798,26 @@ func TestSLOAwareRouter_RequestLifecycle_Complete(t *testing.T) {
 	// Verify context exists
 	retrievedCtx, err := router.getSLOContextForRequest(request)
 	require.NoError(t, err)
-	assert.NotNil(t, retrievedCtx.targetPod)
+	assert.NotNil(t, retrievedCtx.targetMetadata)
 
 	// 2. ResponseReceived
-	router.ResponseReceived(ctx, request, response, pod.GetPod())
+	router.ResponseReceived(ctx, request, response, endpoint.GetMetadata())
 
 	// 3. ResponseStreaming (first token)
-	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+	router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata())
 
 	// 4. ResponseStreaming (subsequent tokens)
 	retrievedCtx, _ = router.getSLOContextForRequest(request)
 	retrievedCtx.ttft = 100 // Mark first token received
 	router.setSLOContextForRequest(request, retrievedCtx)
-	router.ResponseStreaming(ctx, request, response, pod.GetPod())
+	router.ResponseStreaming(ctx, request, response, endpoint.GetMetadata())
 
 	// 5. ResponseComplete
 	retrievedCtx, _ = router.getSLOContextForRequest(request)
 	retrievedCtx.ttft = 80
 	retrievedCtx.avgTPOT = 30
 	router.setSLOContextForRequest(request, retrievedCtx)
-	router.ResponseComplete(ctx, request, response, pod.GetPod())
+	router.ResponseComplete(ctx, request, response, endpoint.GetMetadata())
 
 	// Verify context was cleaned up
 	_, err = router.getSLOContextForRequest(request)
@@ -831,14 +831,14 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) {
 
 	ctx := context.Background()
 
-	pod1 := createTestPod("test-pod-1", 1, 1, 1)
-	pod2 := createTestPod("test-pod-2", 1, 1, 1)
+	endpoint1 := createTestEndpoint("test-pod-1", 1, 1, 1)
+	endpoint2 := createTestEndpoint("test-pod-2", 1, 1, 1)
 
 	request1 := createTestLLMRequest("test-id-1", 100, 50, true)
 	request2 := createTestLLMRequest("test-id-2", 100, 50, true)
 
-	schedulingResult1 := createTestSchedulingResult(pod1.GetPod())
-	schedulingResult2 := createTestSchedulingResult(pod2.GetPod())
+	schedulingResult1 := createTestSchedulingResult(endpoint1.GetMetadata())
+	schedulingResult2 := createTestSchedulingResult(endpoint2.GetMetadata())
 
 	// Create and set SLO contexts
 	sloCtx1 := newSLORequestContext(request1)
@@ -854,8 +854,8 @@ func TestSLOAwareRouter_MultipleRequests_DifferentPods(t *testing.T) {
 	router.PreRequest(ctx, request2, schedulingResult2)
 
 	// Verify separate queues were created
-	queue1, exists1 := router.runningRequestLists[pod1.GetPod().NamespacedName]
-	queue2, exists2 := router.runningRequestLists[pod2.GetPod().NamespacedName]
+	queue1, exists1 := router.runningRequestLists[endpoint1.GetMetadata().NamespacedName]
+	queue2, exists2 := router.runningRequestLists[endpoint2.GetMetadata().NamespacedName]
 
 	assert.True(t, exists1)
 	assert.True(t, exists2)
@@ -915,8 +915,8 @@ func TestSLORequestContext_SLOValidation(t *testing.T) {
 func BenchmarkSLOAwareRouter_PreRequest(b *testing.B) {
 	router := createTestRouter()
 	ctx := context.Background()
-	pod := createTestPod("test-pod", 1, 1, 1)
-	schedulingResult := createTestSchedulingResult(pod.GetPod())
+	endpoint := createTestEndpoint("test-pod", 1, 1, 1)
+	schedulingResult := createTestSchedulingResult(endpoint.GetMetadata())
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
index 25bb1e8ed..5279a1d90 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go
@@ -191,18 +191,18 @@ func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter {
 
 func (s *SLOAwareRouter) epsilonGreedyAffinityGate(
 	ctx context.Context,
-	candidates []podPredictionResult,
+	candidates []endpointPredictionResult,
 	r *rand.Rand,
 	label string, // e.g. "positive" or "negative"
 	prefixStickyThreshold float64,
-) ([]podPredictionResult, bool) {
+) ([]endpointPredictionResult, bool) {
 	logger := log.FromContext(ctx)
 	if prefixStickyThreshold <= 0 {
 		// Affinity gating disabled
 		logger.V(logutil.DEBUG).Info("Affinity gating disabled (threshold <= 0)", "path", label)
 		return candidates, false
 	}
-	eligible := make([]podPredictionResult, 0, len(candidates))
+	eligible := make([]endpointPredictionResult, 0, len(candidates))
 	for _, p := range candidates {
 		if p.PrefixCacheScore >= prefixStickyThreshold {
 			eligible = append(eligible, p)
@@ -231,44 +231,44 @@ func (s *SLOAwareRouter) epsilonGreedyAffinityGate(
 func (s *SLOAwareRouter) scoreWithoutPredictions(
 	ctx context.Context,
 	state *schedulingtypes.CycleState,
-	pods []schedulingtypes.Pod,
+	endpoints []schedulingtypes.Endpoint,
 	r *rand.Rand,
-) map[schedulingtypes.Pod]float64 {
+) map[schedulingtypes.Endpoint]float64 {
 	logger := log.FromContext(ctx)
 	logger.V(logutil.TRACE).Info("Using composite-only scoring without predictions")
 
-	scores := make(map[schedulingtypes.Pod]float64, len(pods))
-	for _, pod := range pods {
-		scores[pod] = 0
+	scores := make(map[schedulingtypes.Endpoint]float64, len(endpoints))
+	for _, endpoint := range endpoints {
+		scores[endpoint] = 0
 	}
 
-	if len(pods) == 0 {
+	if len(endpoints) == 0 {
 		return scores
 	}
 
 	// Build prediction results with only prefix cache scores
-	podResults := make([]podPredictionResult, 0, len(pods))
-	for _, pod := range pods {
-		prefixScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
-		podResults = append(podResults, podPredictionResult{
-			Pod:              pod,
+	endpointResults := make([]endpointPredictionResult, 0, len(endpoints))
+	for _, endpoint := range endpoints {
+		prefixScore := s.getPrefixCacheScoreForPod(ctx, state, endpoint)
+		endpointResults = append(endpointResults, endpointPredictionResult{
+			Endpoint:         endpoint,
 			PrefixCacheScore: prefixScore,
-			IsValid:          true, // All pods are valid when we don't check predictions
+			IsValid:          true, // All endpoints are valid when we don't check predictions
 		})
 	}
 
 	// Select based on composite scores (prefix cache + other non-prediction metrics)
-	selectedPod := s.selectFromCompositeScores(ctx, podResults, r, headroomStrategyCompositeOnly)
+	selectedEndpoint := s.selectFromCompositeScores(ctx, endpointResults, r, headroomStrategyCompositeOnly)
 
-	if selectedPod != nil {
-		scores[selectedPod] = 1
-		logger.V(logutil.TRACE).Info("Selected pod using composite-only scoring", "pod", selectedPod.GetPod().String())
+	if selectedEndpoint != nil {
+		scores[selectedEndpoint] = 1
+		logger.V(logutil.TRACE).Info("Selected endpoint using composite-only scoring", "endpoint", selectedEndpoint.GetMetadata().String())
 	}
 
 	return scores
 }
 
-func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
+func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, endpoints []schedulingtypes.Endpoint) map[schedulingtypes.Endpoint]float64 {
 	logger := log.FromContext(ctx)
 	if s.latencypredictor == nil {
 		logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores")
@@ -279,9 +279,9 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
 
 	s.parseSLOHeaders(ctx, request, sloCtx)
 
-	for _, pod := range pods {
-		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
-		sloCtx.prefixCacheScoresForPods[pod.GetPod().String()] = prefixCacheScore
+	for _, endpoint := range endpoints {
+		prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, endpoint)
+		sloCtx.prefixCacheScoresForEndpoints[endpoint.GetMetadata().String()] = prefixCacheScore
 	}
 
 	// Check if SLOs are provided
@@ -292,59 +292,59 @@ func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.Cycle
 	}
 
 	// Initialize scores map with all pods having score 0
-	scores := make(map[schedulingtypes.Pod]float64, len(pods))
-	for _, pod := range pods {
-		scores[pod] = 0
+	scores := make(map[schedulingtypes.Endpoint]float64, len(endpoints))
+	for _, endpoint := range endpoints {
+		scores[endpoint] = 0
 	}
 
 	source := rand.NewSource(time.Now().UnixNano())
 	r := rand.New(source)
 
-	predictions, err := s.generatePredictions(ctx, state, request, sloCtx, pods)
+	predictions, err := s.generatePredictions(ctx, state, request, sloCtx, endpoints)
 	if err != nil {
 		logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Error generating predictions, falling back to composite-only scoring")
 		// Fall back to composite-only scoring using prefix cache scores
 		s.setSLOContextForRequest(request, sloCtx)
-		return s.scoreWithoutPredictions(ctx, state, pods, r)
+		return s.scoreWithoutPredictions(ctx, state, endpoints, r)
 	}
 	s.updateRequestContextWithPredictions(sloCtx, predictions)
 
-	allPreds := append([]podPredictionResult(nil), predictions...)
+	allPreds := append([]endpointPredictionResult(nil), predictions...)
 	allPreds, sticky := s.epsilonGreedyAffinityGate(ctx, allPreds, r, "overall", AffinityGateTauGlobal)
 
-	// Check if all pods are invalid and all have running requests
-	allPodsInvalid := (sloCtx.ttftSLO > 0 && sloCtx.avgTPOTSLO > 0)
-	allPodsHaveRunningRequests := true
+	// Check if all endpoints are invalid and all have running requests
+	allEndpointsInvalid := (sloCtx.ttftSLO > 0 && sloCtx.avgTPOTSLO > 0)
+	allEndpointsHaveRunningRequests := true
 
 	for _, pred := range allPreds {
 		if pred.IsValid {
-			allPodsInvalid = false
+			allEndpointsInvalid = false
 		}
-		runningRequestCount := s.getPodRunningRequestCount(pred.Pod)
+		runningRequestCount := s.getEndpointRunningRequestCount(pred.Endpoint)
 		if runningRequestCount == 0 {
-			allPodsHaveRunningRequests = false
+			allEndpointsHaveRunningRequests = false
 		}
 	}
 
-	// Set HasValidPod to false if all pods are invalid and all have running requests
-	if allPodsInvalid && allPodsHaveRunningRequests && !sticky {
-		sloCtx.hasValidPod = false
-		logger.V(logutil.DEBUG).Info("All pods are invalid and have running requests, setting HasValidPod to false")
+	// Set HasValidEndpoint to false if all endpoints are invalid and all have running requests
+	if allEndpointsInvalid && allEndpointsHaveRunningRequests && !sticky {
+		sloCtx.hasValidEndpoint = false
+		logger.V(logutil.DEBUG).Info("All endpoints are invalid and have running requests, setting HasValidEndpoint to false")
 	}
 
 	// 2) Tiered selection: positive headroom pods get 99% probability, negative get 1%
-	posHeadroomPods, negHeadroomPods := s.classifyPodsByHeadroom(allPreds)
+	posHeadroomEndpoints, negHeadroomEndpoints := s.classifyEndpointsByHeadroom(allPreds)
 
-	logger.V(logutil.DEBUG).Info("Pod headroom distribution",
-		"positivePods", len(posHeadroomPods),
-		"negativePods", len(negHeadroomPods))
+	logger.V(logutil.DEBUG).Info("Endpoint headroom distribution",
+		"positiveEndpoints", len(posHeadroomEndpoints),
+		"negativeEndpoints", len(negHeadroomEndpoints))
 
-	selectedPod := s.selectPodBasedOnStrategy(ctx, r, allPreds, posHeadroomPods, negHeadroomPods)
+	selectedEndpoint := s.selectEndpointBasedOnStrategy(ctx, r, allPreds, posHeadroomEndpoints, negHeadroomEndpoints)
 
-	// Set score = 1 for selected pod, 0 for all others
-	if selectedPod != nil {
-		scores[selectedPod] = 1
-		logger.V(logutil.DEBUG).Info("Selected pod for scheduling", "pod", selectedPod.GetPod().String())
+	// Set score = 1 for selected endpoint, 0 for all others
+	if selectedEndpoint != nil {
+		scores[selectedEndpoint] = 1
+		logger.V(logutil.DEBUG).Info("Selected endpoint for scheduling", "endpoint", selectedEndpoint.GetMetadata().String())
 	}
 
 	s.setSLOContextForRequest(request, sloCtx)
@@ -360,8 +360,8 @@ func (t *SLOAwareRouter) getOrMakeSLORequestContext(request *schedulingtypes.LLM
 	return sloCtx
 }
 
-func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
-	log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String())
+func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, endpoint schedulingtypes.Endpoint) float64 {
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for endpoint", "endpoint", endpoint.GetMetadata().String())
 	plugintype := prefix.PrefixCachePluginType
 	pluginname := prefix.PrefixCachePluginType
 	cycleStateKey := (plugins.TypedName{Type: plugintype, Name: pluginname}).String()
@@ -371,7 +371,7 @@ func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleSta
 
 	if err != nil {
 		// The prefix cache plugin might not be enabled, which is a valid scenario.
-		log.FromContext(ctx).V(logutil.DEBUG).Info("prefix cache state not found in cycle state, returning prefix cache score of 0.0", "pod", pod.GetPod().String())
+		log.FromContext(ctx).V(logutil.DEBUG).Info("prefix cache state not found in cycle state, returning prefix cache score of 0.0", "endpoint", endpoint.GetMetadata().String())
 		return 0.0
 	}
@@ -389,7 +389,7 @@ func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleSta
 		return 0.0
 	}
 
-	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)]
-	log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", pod.GetPod().String(), "matchLen", matchLen, "totalPrefixes", total)
+	matchLen := prefixCacheState.PrefixCacheServers[prefix.ServerID(endpoint.GetMetadata().NamespacedName)]
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Prefix cache score for endpoint", "endpoint", endpoint.GetMetadata().String(), "matchLen", matchLen, "totalPrefixes", total)
 
 	return float64(matchLen) / float64(total)
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go
index bb97ba346..2bbe2b058 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_helpers.go
@@ -44,56 +44,56 @@ func (s *SLOAwareRouter) parseSLOHeaders(ctx context.Context, request *schedulin
 	sloCtx.predictorBasedScheduling = !hasHeader(*request, "x-prediction-based-scheduling-off")
 }
 
-func (s *SLOAwareRouter) classifyPodsByHeadroom(allPreds []podPredictionResult) (posHeadroomPods, negHeadroomPods []podPredictionResult) {
+func (s *SLOAwareRouter) classifyEndpointsByHeadroom(allPreds []endpointPredictionResult) (posHeadroomEndpoints, negHeadroomEndpoints []endpointPredictionResult) {
 	for _, p := range allPreds {
-		// A pod has positive headroom only if BOTH TTFT and TPOT have positive headroom
+		// An endpoint has positive headroom only if BOTH TTFT and TPOT have positive headroom
 		if p.Headroom > 0 && p.TTFTHeadroom > 0 {
-			posHeadroomPods = append(posHeadroomPods, p)
+			posHeadroomEndpoints = append(posHeadroomEndpoints, p)
 		} else {
-			// A pod has negative headroom if EITHER TTFT or TPOT has negative/zero headroom
-			negHeadroomPods = append(negHeadroomPods, p)
+			// An endpoint has negative headroom if EITHER TTFT or TPOT has negative/zero headroom
+			negHeadroomEndpoints = append(negHeadroomEndpoints, p)
 		}
 	}
 	return
 }
 
-func (s *SLOAwareRouter) selectPodBasedOnStrategy(
+func (s *SLOAwareRouter) selectEndpointBasedOnStrategy(
 	ctx context.Context,
 	r *rand.Rand,
-	allPreds, posHeadroomPods, negHeadroomPods []podPredictionResult,
-) schedulingtypes.Pod {
+	allPreds, posHeadroomEndpoints, negHeadroomEndpoints []endpointPredictionResult,
+) schedulingtypes.Endpoint {
 	logger := log.FromContext(ctx)
-	var selectedPod schedulingtypes.Pod
+	var selectedEndpoint schedulingtypes.Endpoint
 
 	switch {
 	case s.headroomStrategy == headroomStrategyCompositeOnly:
 		logger.V(logutil.DEBUG).Info("Selecting from composite scores only")
-		selectedPod = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly)
-	case len(posHeadroomPods) > 0 && len(negHeadroomPods) > 0:
-		// 99% chance to select from positive headroom pods, 1% from negative
+		selectedEndpoint = s.selectFromCompositeScores(ctx, allPreds, r, headroomStrategyCompositeOnly)
+	case len(posHeadroomEndpoints) > 0 && len(negHeadroomEndpoints) > 0:
+		// 99% chance to select from positive headroom endpoints, 1% from negative
		if r.Float64() < s.config.EpsilonExploreNeg {
-			logger.V(logutil.DEBUG).Info("Selecting from negative headroom pods (1% chance)")
-			selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+			logger.V(logutil.DEBUG).Info("Selecting from negative headroom endpoints (1% chance)")
+			selectedEndpoint = s.selectFromNegativeHeadroomEndpoints(ctx, negHeadroomEndpoints, r)
		} else {
-			logger.V(logutil.DEBUG).Info("Selecting from positive headroom pods (99% chance)")
-			selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
+			logger.V(logutil.DEBUG).Info("Selecting from positive headroom endpoints (99% chance)")
+			selectedEndpoint = s.selectFromPositiveHeadroomEndpoints(ctx, posHeadroomEndpoints, r)
		}
-	case len(posHeadroomPods) > 0:
-		// If only positive headroom pods exist, select from them
-		logger.V(logutil.DEBUG).Info("Only positive headroom pods available")
-		selectedPod = s.selectFromPositiveHeadroomPods(ctx, posHeadroomPods, r)
-	case len(negHeadroomPods) > 0:
-		// If only negative headroom pods exist, select from them
-		logger.V(logutil.DEBUG).Info("Only negative headroom pods available")
-		selectedPod = s.selectFromNegativeHeadroomPods(ctx, negHeadroomPods, r)
+	case len(posHeadroomEndpoints) > 0:
+		// If only positive headroom endpoints exist, select from them
+		logger.V(logutil.DEBUG).Info("Only positive headroom endpoints available")
+		selectedEndpoint = s.selectFromPositiveHeadroomEndpoints(ctx, posHeadroomEndpoints, r)
+	case len(negHeadroomEndpoints) > 0:
+		// If only negative headroom endpoints exist, select from them
+		logger.V(logutil.DEBUG).Info("Only negative headroom endpoints available")
+		selectedEndpoint = s.selectFromNegativeHeadroomEndpoints(ctx, negHeadroomEndpoints, r)
	case len(allPreds) > 0:
-		// fallback - select randomly from valid pods
-		logger.V(logutil.DEBUG).Info("No headroom pods available, selecting randomly from valid pods")
-		selectedPod = allPreds[r.Intn(len(allPreds))].Pod
+		// fallback - select randomly from valid endpoints
+		logger.V(logutil.DEBUG).Info("No headroom endpoints available, selecting randomly from valid endpoints")
+		selectedEndpoint = allPreds[r.Intn(len(allPreds))].Endpoint
	default:
-		// No valid pods - return nil (caller handles this)
-		logger.V(logutil.DEBUG).Info("No valid pods available")
+		// No valid endpoints - return nil (caller handles this)
+		logger.V(logutil.DEBUG).Info("No valid endpoints available")
		return nil
	}
 
-	return selectedPod
+	return selectedEndpoint
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go
index 8d8f68393..a7a78cd7a 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go
@@ -26,8 +26,8 @@ import (
 	"github.com/stretchr/testify/assert"
 	"k8s.io/apimachinery/pkg/types"
 
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 	latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
@@ -103,9 +103,9 @@ func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.
 	return &latencypredictor.ServerStatusResponse{}, nil
 }
 
-func createTestPod(name string, kvCacheUsage float64, runningRequestsSize, waitingQueueSize int) schedulingtypes.Pod {
+func createTestEndpoint(name string, kvCacheUsage float64, runningRequestsSize, waitingQueueSize int) schedulingtypes.Endpoint {
 	return &schedulingtypes.PodMetrics{
-		Pod: &backend.Pod{
+		EndpointMetadata: &datalayer.EndpointMetadata{
 			NamespacedName: types.NamespacedName{
 				Name:      name,
 				Namespace: "default",
@@ -148,7 +148,7 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 		predictor      *mockPredictor
 		strategy       headroomStrategy
 		request        *schedulingtypes.LLMRequest
-		pods           []schedulingtypes.Pod
+		endpoints      []schedulingtypes.Endpoint
 		expectedScores map[string]float64 // Map of pod name to expected score
 		expectNil      bool
 	}{
 		{
 			name:      "No SLO headers - returns nil",
 			predictor: &mockPredictor{},
 			strategy:  headroomStrategyLeast,
 			request:   createTestLLMRequest("test", 1.0, 0.05, false), // predictionBased = false
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting
-				createTestPod("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.5, 2, 1), // 50% KV cache, 2 running, 1 waiting
+				createTestEndpoint("pod2", 0.7, 3, 2), // 70% KV cache, 3 running, 2 waiting
 			},
 			expectNil: true,
 		},
@@ -168,8 +168,8 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			predictor: nil,
 			strategy:  headroomStrategyLeast,
 			request:   createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod1", 0.5, 2, 1),
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.5, 2, 1),
 			},
 			expectNil: true,
 		},
@@ -184,10 +184,10 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			},
 			strategy: headroomStrategyLeast,
 			request:  createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod1", 0.5, 2, 1), // 50% KV cache
-				createTestPod("pod2", 0.6, 3, 2), // 60% KV cache
-				createTestPod("pod3", 0.3, 1, 0), // 30% KV cache
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.5, 2, 1), // 50% KV cache
+				createTestEndpoint("pod2", 0.6, 3, 2), // 60% KV cache
+				createTestEndpoint("pod3", 0.3, 1, 0), // 30% KV cache
 			},
 			// One pod should be selected with score 1, others 0
 			expectedScores: map[string]float64{
@@ -204,9 +204,9 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			},
 			strategy: headroomStrategyLeast,
 			request:  createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod1", 0.8, 5, 3), // 80% KV cache, high load
-				createTestPod("pod2", 0.9, 6, 4), // 90% KV cache, very high load
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.8, 5, 3), // 80% KV cache, high load
+				createTestEndpoint("pod2", 0.9, 6, 4), // 90% KV cache, very high load
 			},
 			// One pod should still be selected even with negative headroom
 			expectedScores: map[string]float64{},
@@ -221,9 +221,9 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			},
 			strategy: headroomStrategyLeast,
 			request:  createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom
-				createTestPod("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod-positive", 0.3, 1, 0), // Low KV cache, positive headroom
+				createTestEndpoint("pod-negative", 0.9, 6, 4), // High KV cache, negative headroom
 			},
 			// With 99% probability, positive headroom pod should be selected
 			expectedScores: map[string]float64{},
@@ -235,9 +235,9 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			},
 			strategy: headroomStrategyLeast,
 			request:  createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{
-				createTestPod("pod1", 0.5, 2, 1),
-				createTestPod("pod2", 0.6, 3, 2),
+			endpoints: []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.5, 2, 1),
+				createTestEndpoint("pod2", 0.6, 3, 2),
 			},
 			// Should fall back to composite-only scoring and select one pod
 			expectedScores: map[string]float64{
@@ -249,7 +249,7 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			predictor: &mockPredictor{},
 			strategy:  headroomStrategyLeast,
 			request:   createTestLLMRequest("test", 1.0, 0.05, true),
-			pods: []schedulingtypes.Pod{},
+			endpoints: []schedulingtypes.Endpoint{},
 			// Should return empty scores map
 			expectedScores: map[string]float64{},
 		},
@@ -272,7 +272,7 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 			router = NewSLOAwareRouter(cfg, predictor)
 
-			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.pods)
+			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), tt.request, tt.endpoints)
 
 			if tt.expectNil {
 				assert.Nil(t, scores, "Expected nil scores")
@@ -283,17 +283,17 @@ func TestSLOAwareRouter_Score(t *testing.T) {
 
 			// If we have specific expected scores, verify them
 			if len(tt.expectedScores) > 0 {
-				for _, pod := range tt.pods {
-					podName := pod.GetPod().NamespacedName.Name
-					if expectedScore, ok := tt.expectedScores[podName]; ok {
-						assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", podName, expectedScore)
+				for _, endpoint := range tt.endpoints {
+					endpointName := endpoint.GetMetadata().NamespacedName.Name
+					if expectedScore, ok := tt.expectedScores[endpointName]; ok {
+						assert.InDelta(t, expectedScore, scores[endpoint], 0.0001, "Endpoint %s should have score %f", endpointName, expectedScore)
					}
				}
			}
 
			// General validation: exactly one pod should have score 1 (selected), others should have score 0
			// This applies even when predictions fail because we fall back to composite scoring
-			if !tt.expectNil && len(tt.pods) > 0 && tt.predictor != nil {
+			if !tt.expectNil && len(tt.endpoints) > 0 && tt.predictor != nil {
				selectedCount := 0
				for _, score := range scores {
					if score == 1.0 {
@@ -349,13 +349,13 @@ func TestSLOAwareRouter_Strategies(t *testing.T) {
			router := NewSLOAwareRouter(cfg, predictor)
			request := createTestLLMRequest("test", 1.0, 0.05, true)
 
-			pods := []schedulingtypes.Pod{
-				createTestPod("pod1", 0.5, 2, 1),
-				createTestPod("pod2", 0.6, 3, 2),
-				createTestPod("pod3", 0.3, 1, 0),
+			endpoints := []schedulingtypes.Endpoint{
+				createTestEndpoint("pod1", 0.5, 2, 1),
+				createTestEndpoint("pod2", 0.6, 3, 2),
+				createTestEndpoint("pod3", 0.3, 1, 0),
			}
 
-			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), request, pods)
+			scores := router.Score(context.Background(), schedulingtypes.NewCycleState(), request, endpoints)
 
			assert.NotNil(t, scores, "Expected non-nil scores for strategy %s", tt.strategy)
@@ -399,20 +399,20 @@ func TestSLOAwareRouter_WithName(t *testing.T) {
 func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) {
 	tests := []struct {
 		name          string
-		setupRequests func(*SLOAwareRouter, schedulingtypes.Pod)
+		setupRequests func(*SLOAwareRouter, schedulingtypes.Endpoint)
 		expectedCount int
 	}{
 		{
 			name:          "No running requests",
-			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {},
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Endpoint) {},
 			expectedCount: 0,
 		},
 		{
 			name: "One running request",
-			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Endpoint) {
 				podName := types.NamespacedName{
-					Name:      p.GetPod().NamespacedName.Name,
-					Namespace: p.GetPod().NamespacedName.Namespace,
+					Name:      p.GetMetadata().NamespacedName.Name,
+					Namespace: p.GetMetadata().NamespacedName.Namespace,
 				}
 				r.runningRequestLists[podName] = newRequestPriorityQueue()
 				r.runningRequestLists[podName].Add("req1", 0.04)
@@ -421,15 +421,15 @@ func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) {
 		},
 		{
 			name: "Multiple running requests",
-			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Pod) {
-				podName := types.NamespacedName{
-					Name:      p.GetPod().NamespacedName.Name,
-					Namespace: p.GetPod().NamespacedName.Namespace,
+			setupRequests: func(r *SLOAwareRouter, p schedulingtypes.Endpoint) {
+				endpointName := types.NamespacedName{
+					Name:      p.GetMetadata().NamespacedName.Name,
+					Namespace: p.GetMetadata().NamespacedName.Namespace,
 				}
-				r.runningRequestLists[podName] = newRequestPriorityQueue()
-				r.runningRequestLists[podName].Add("req1", 0.04)
-				r.runningRequestLists[podName].Add("req2", 0.03)
-				r.runningRequestLists[podName].Add("req3", 0.05)
+				r.runningRequestLists[endpointName] = newRequestPriorityQueue()
+				r.runningRequestLists[endpointName].Add("req1", 0.04)
+				r.runningRequestLists[endpointName].Add("req2", 0.03)
+				r.runningRequestLists[endpointName].Add("req3", 0.05)
 			},
 			expectedCount: 3,
 		},
@@ -441,11 +441,11 @@ func TestSLOAwareRouter_GetPodRunningRequestCount(t *testing.T) {
 			cfg := DefaultConfig
 			cfg.HeadroomSelectionStrategy = string(headroomStrategyLeast)
 			router := NewSLOAwareRouter(cfg, predictor)
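
Reviewer note: the subtests above (and the getEndpointMinTPOTSLO hunks in selection.go below) lean on runningRequestLists, a map from an endpoint's types.NamespacedName to a TPOT-ordered priority queue. The queue type itself is untouched by this rename, so it never appears in the diff. The following is only a sketch of the contract the call sites assume — names and layout are inferred from usage (Add, Peek, GetSize), not copied from the implementation:

// Sketch only, not the package's actual queue. A real implementation would
// use container/heap; keeping the slice sorted ascending by tpot gives the
// same observable behavior the tests rely on.
type sketchEntry struct {
	id   string
	tpot float64 // TPOT SLO of the running request
}

type sketchQueue struct {
	items []sketchEntry // ascending by tpot, so items[0] is the minimum
}

// Add registers a running request together with its TPOT SLO.
func (q *sketchQueue) Add(id string, tpot float64) {
	i := 0
	for i < len(q.items) && q.items[i].tpot <= tpot {
		i++
	}
	q.items = append(q.items, sketchEntry{})
	copy(q.items[i+1:], q.items[i:])
	q.items[i] = sketchEntry{id: id, tpot: tpot}
}

// Peek returns the request with the tightest (minimum) TPOT SLO, matching the
// "Multiple running requests - should return minimum" subtest above.
func (q *sketchQueue) Peek() *sketchEntry {
	if len(q.items) == 0 {
		return nil
	}
	return &q.items[0]
}

// GetSize reports how many running requests are currently tracked.
func (q *sketchQueue) GetSize() int { return len(q.items) }

Under this contract, getEndpointMinTPOTSLO below can read Peek().tpot in O(1) rather than scanning every running request.
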
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
index 0e1c9eca7..b7ce83ab7 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection.go
@@ -28,20 +28,20 @@ import (
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy
+// selectFromPositiveHeadroomEndpoints selects an endpoint from positive headroom endpoints using headroom strategy
 // Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom.
-func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+func (s *SLOAwareRouter) selectFromPositiveHeadroomEndpoints(ctx context.Context, posHeadroomEndpoints []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint {
 
-	if len(posHeadroomPods) == 1 {
-		return posHeadroomPods[0].Pod
+	if len(posHeadroomEndpoints) == 1 {
+		return posHeadroomEndpoints[0].Endpoint
 	}
 
 	// Apply perfect stickiness (with exploration)
-	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomPods, r, "positive", s.config.AffinityGateTau)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, posHeadroomEndpoints, r, "positive", s.config.AffinityGateTau)
 
-	// If perfect stickiness collapsed us to a single pod, short-circuit
+	// If perfect stickiness collapsed us to a single endpoint, short-circuit
 	if sticky && len(candidates) == 1 {
-		return candidates[0].Pod
+		return candidates[0].Endpoint
 	}
 	switch s.headroomStrategy {
 	case headroomStrategyCompositeMost:
@@ -50,7 +50,7 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos
 		return s.selectFromCompositeScores(ctx, candidates, r, headroomStrategyCompositeLeast)
 	}
 
-	// Find min/max for TPOT (Headroom) and TTFTHeadroom across positive pods to normalize to [0,1]
+	// Find min/max for TPOT (Headroom) and TTFTHeadroom across positive endpoints to normalize to [0,1]
 	minTPOTH, maxTPOTH, minTTFTH, maxTTFTH := s.calculateHeadroomRanges(candidates)
 
 	// Calculate weights for weighted random selection
@@ -59,54 +59,54 @@ func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, pos
 	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
 }
 
-// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic
-// Modified to strictly prefer pods with 0 running requests
-func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+// selectFromNegativeHeadroomEndpoints selects an endpoint from negative headroom endpoints using hierarchical TTFT/TPOT logic
+// Modified to strictly prefer endpoints with 0 running requests
+func (s *SLOAwareRouter) selectFromNegativeHeadroomEndpoints(ctx context.Context, negHeadroomEndpoints []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint {
 	logger := log.FromContext(ctx)
 
-	if len(negHeadroomPods) == 1 {
-		return negHeadroomPods[0].Pod
+	if len(negHeadroomEndpoints) == 1 {
+		return negHeadroomEndpoints[0].Endpoint
 	}
 
-	// First, separate pods by running request count
-	var zeroRunningRequestPods, nonZeroRunningRequestPods []podPredictionResult
+	// First, separate endpoints by running request count
+	var zeroRunningRequestEndpoints, nonZeroRunningRequestEndpoints []endpointPredictionResult
 
-	for _, p := range negHeadroomPods {
-		runningRequestCount := s.getPodRunningRequestCount(p.Pod)
+	for _, e := range negHeadroomEndpoints {
+		runningRequestCount := s.getEndpointRunningRequestCount(e.Endpoint)
 		if runningRequestCount == 0 {
-			zeroRunningRequestPods = append(zeroRunningRequestPods, p)
+			zeroRunningRequestEndpoints = append(zeroRunningRequestEndpoints, e)
 		} else {
-			nonZeroRunningRequestPods = append(nonZeroRunningRequestPods, p)
+			nonZeroRunningRequestEndpoints = append(nonZeroRunningRequestEndpoints, e)
 		}
 	}
 
-	logger.V(logutil.DEBUG).Info("Negative headroom pods by running request count",
-		"zeroRunningRequests", len(zeroRunningRequestPods),
-		"nonZeroRunningRequests", len(nonZeroRunningRequestPods))
+	logger.V(logutil.DEBUG).Info("Negative headroom endpoints by running request count",
+		"zeroRunningRequests", len(zeroRunningRequestEndpoints),
+		"nonZeroRunningRequests", len(nonZeroRunningRequestEndpoints))
 
-	// If we have pods with 0 running requests, strictly prefer them
-	if len(zeroRunningRequestPods) > 0 {
-		logger.V(logutil.DEBUG).Info("Selecting from pods with zero running requests")
-		return s.selectFromNegativeHeadroomPodsInternal(ctx, zeroRunningRequestPods, r)
+	// If we have endpoints with 0 running requests, strictly prefer them
+	if len(zeroRunningRequestEndpoints) > 0 {
+		logger.V(logutil.DEBUG).Info("Selecting from endpoints with zero running requests")
+		return s.selectFromNegativeHeadroomEndpointsInternal(ctx, zeroRunningRequestEndpoints, r)
 	}
 
-	// Otherwise, fall back to pods with running requests
-	logger.V(logutil.DEBUG).Info("No pods with zero running requests, selecting from pods with running requests")
-	return s.selectFromNegativeHeadroomPodsInternal(ctx, nonZeroRunningRequestPods, r)
+	// Otherwise, fall back to endpoints with running requests
+	logger.V(logutil.DEBUG).Info("No endpoints with zero running requests, selecting from endpoints with running requests")
+	return s.selectFromNegativeHeadroomEndpointsInternal(ctx, nonZeroRunningRequestEndpoints, r)
 }
 
-// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
-func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []podPredictionResult, r *rand.Rand) schedulingtypes.Pod {
-	if len(negHeadroomPods) == 1 {
-		return negHeadroomPods[0].Pod
+// selectFromNegativeHeadroomEndpointsInternal handles the actual selection logic for negative headroom endpoints
+func (s *SLOAwareRouter) selectFromNegativeHeadroomEndpointsInternal(ctx context.Context, negHeadroomEndpoints []endpointPredictionResult, r *rand.Rand) schedulingtypes.Endpoint {
+	if len(negHeadroomEndpoints) == 1 {
+		return negHeadroomEndpoints[0].Endpoint
 	}
 
 	// Apply perfect stickiness (with exploration)
-	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomPods, r, "negative", s.config.AffinityGateTau)
+	candidates, sticky := s.epsilonGreedyAffinityGate(ctx, negHeadroomEndpoints, r, "negative", s.config.AffinityGateTau)
 
-	// If perfect stickiness collapsed us to a single pod, short-circuit
+	// If perfect stickiness collapsed us to a single endpoint, short-circuit
 	if sticky && len(candidates) == 1 {
-		return candidates[0].Pod
+		return candidates[0].Endpoint
 	}
 
 	switch s.headroomStrategy {
@@ -120,17 +120,17 @@ func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Cont
 	weightedChoices := make([]choice, 0, len(candidates))
 	total := 0
 
-	s.handleNegativeHeadroomPodsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
+	s.handleNegativeHeadroomEndpointsHierarchical(ctx, candidates, &weightedChoices, &total, minWeight)
 
 	// Perform weighted random selection
 	return s.performWeightedRandomSelection(weightedChoices, total, candidates, r)
 }
 
-// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
+// weightEndpointsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
 // Lower blended deficit => higher weight.
-func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
+func (ps *SLOAwareRouter) weightEndpointsByBlendedDeficit(
 	ctx context.Context,
-	pods []podPredictionResult,
+	endpoints []endpointPredictionResult,
 	choices *[]choice,
 	total *int,
 	minWeight int,
@@ -138,7 +138,7 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
 	category string,
 ) {
 	logger := log.FromContext(ctx)
-	if len(pods) == 0 {
+	if len(endpoints) == 0 {
 		return
 	}
 
@@ -147,25 +147,25 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
 	// Compute raw deficits (only when headroom is negative)
 	type deficits struct {
-		pod     podPredictionResult
-		ttftDef float64
-		tpotDef float64
+		endpoint endpointPredictionResult
+		ttftDef  float64
+		tpotDef  float64
 	}
-	defs := make([]deficits, 0, len(pods))
+	defs := make([]deficits, 0, len(endpoints))
 
 	minTTFT, maxTTFT := math.MaxFloat64, -math.MaxFloat64
 	minTPOT, maxTPOT := math.MaxFloat64, -math.MaxFloat64
 
-	for _, p := range pods {
+	for _, e := range endpoints {
 		ttftDef := 0.0
-		if p.TTFTHeadroom < 0 {
-			ttftDef = -p.TTFTHeadroom
+		if e.TTFTHeadroom < 0 {
+			ttftDef = -e.TTFTHeadroom
 		}
 		tpotDef := 0.0
-		if p.Headroom < 0 {
-			tpotDef = -p.Headroom
+		if e.Headroom < 0 {
+			tpotDef = -e.Headroom
 		}
-		defs = append(defs, deficits{pod: p, ttftDef: ttftDef, tpotDef: tpotDef})
+		defs = append(defs, deficits{endpoint: e, ttftDef: ttftDef, tpotDef: tpotDef})
 
 		if ttftDef < minTTFT {
 			minTTFT = ttftDef
@@ -197,7 +197,7 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
 		"category", category,
 		"minTTFTDef", minTTFT, "maxTTFTDef", maxTTFT,
 		"minTPOTDef", minTPOT, "maxTPOTDef", maxTPOT,
-		"alphaTTFT", alpha, "betaTPOT", beta, "podCount", len(pods))
+		"alphaTTFT", alpha, "betaTPOT", beta, "endpointCount", len(endpoints))
 
 	for _, d := range defs {
 		// Normalize deficits to [0,1] within this bucket (0 = best / least violation)
@@ -214,33 +214,33 @@ func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
 		blended := alpha*nTTFT + beta*nTPOT
 
 		// Convert to selection weight: lower badness -> higher weight
-		// Ensure a floor so no pod is completely excluded within the bucket.
+		// Ensure a floor so no endpoint is completely excluded within the bucket.
 		w := int((1.0-blended)*float64(Wrange)) + minWeight + 1
 
-		*choices = append(*choices, choice{podName: d.pod.Pod, weight: w})
+		*choices = append(*choices, choice{endpointName: d.endpoint.Endpoint, weight: w})
 		*total += w
 
 		logger.V(logutil.TRACE).Info("Negative bucket blended weighting",
-			"pod", d.pod.Pod.GetPod().String(),
+			"endpoint", d.endpoint.Endpoint.GetMetadata().String(),
 			"ttftDef", d.ttftDef, "tpotDef", d.tpotDef,
 			"normTTFT", nTTFT, "normTPOT", nTPOT,
 			"blendedBadness", blended, "weight", w)
 	}
 }
 
-func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
+func (s *SLOAwareRouter) handleNegativeHeadroomEndpointsHierarchical(
 	ctx context.Context,
-	negHeadroomPods []podPredictionResult,
+	negHeadroomEndpoints []endpointPredictionResult,
 	choices *[]choice,
 	total *int,
 	minWeightForNegative int,
 ) {
 	logger := log.FromContext(ctx)
 
-	// Categorize pods by their headroom status
-	var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []podPredictionResult
+	// Categorize endpoints by their headroom status
+	var negTTFTNegTPOT, negTTFTNonNegTPOT, nonNegTTFTNegTPOT, nonNegTTFTNonNegTPOT []endpointPredictionResult
 
-	for _, p := range negHeadroomPods {
+	for _, p := range negHeadroomEndpoints {
 		switch {
 		case p.TTFTHeadroom < 0 && p.Headroom < 0:
 			negTTFTNegTPOT = append(negTTFTNegTPOT, p)
@@ -253,8 +253,8 @@ func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
 		}
 	}
 
-	logger.V(logutil.DEBUG).Info("Hierarchical negative headroom pod distribution",
-		"totalNegative", len(negHeadroomPods),
+	logger.V(logutil.DEBUG).Info("Hierarchical negative headroom endpoint distribution",
+		"totalNegative", len(negHeadroomEndpoints),
 		"negTTFT_negTPOT", len(negTTFTNegTPOT),
 		"negTTFT_nonNegTPOT", len(negTTFTNonNegTPOT),
 		"nonNegTTFT_negTPOT", len(nonNegTTFTNegTPOT),
@@ -262,35 +262,35 @@ func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
 
 	// Priority 1: both TTFT and TPOT negative -> blended deficits (both active)
 	if len(negTTFTNegTPOT) > 0 {
-		s.weightPodsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative,
+		s.weightEndpointsByBlendedDeficit(ctx, negTTFTNegTPOT, choices, total, minWeightForNegative,
 			s.config.NegHeadroomTTFTWeight, s.config.NegHeadroomTPOTWeight, "both_negative")
 	}
 
 	// Priority 2: TTFT negative, TPOT non-negative -> blended still works (TPOT deficit=0)
 	if len(negTTFTNonNegTPOT) > 0 {
-		s.weightPodsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative,
+		s.weightEndpointsByBlendedDeficit(ctx, negTTFTNonNegTPOT, choices, total, minWeightForNegative,
 			s.config.NegHeadroomTTFTWeight, s.config.NegHeadroomTPOTWeight, "ttft_negative")
 	}
 
 	// Priority 3: TTFT non-negative, TPOT negative -> blended (TTFT deficit=0)
 	if len(nonNegTTFTNegTPOT) > 0 {
-		s.weightPodsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative,
+		s.weightEndpointsByBlendedDeficit(ctx, nonNegTTFTNegTPOT, choices, total, minWeightForNegative,
 			s.config.NegHeadroomTTFTWeight, s.config.NegHeadroomTPOTWeight, "tpot_negative")
 	}
 
 	// Priority 4: edge-case bucket -> minimal weight
-	for _, p := range nonNegTTFTNonNegTPOT {
-		*choices = append(*choices, choice{podName: p.Pod, weight: minWeightForNegative})
+	for _, e := range nonNegTTFTNonNegTPOT {
+		*choices = append(*choices, choice{endpointName: e.Endpoint, weight: minWeightForNegative})
 		*total += minWeightForNegative
 	}
 }
 
-func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
-	podName := types.NamespacedName{
-		Name:      pod.GetPod().NamespacedName.Name,
-		Namespace: pod.GetPod().NamespacedName.Namespace,
+func (s *SLOAwareRouter) getEndpointMinTPOTSLO(endpoint schedulingtypes.Endpoint) float64 {
+	endpointName := types.NamespacedName{
+		Name:      endpoint.GetMetadata().NamespacedName.Name,
+		Namespace: endpoint.GetMetadata().NamespacedName.Namespace,
 	}
-	if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 {
+	if runningReqs, ok := s.runningRequestLists[endpointName]; ok && runningReqs.GetSize() > 0 {
 		if topReq := runningReqs.Peek(); topReq != nil {
 			return topReq.tpot
 		}
@@ -298,12 +298,12 @@ func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
 	return 0 // no running requests or no TPOT SLOs
 }
 
-func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int {
-	podName := types.NamespacedName{
-		Name:      pod.GetPod().NamespacedName.Name,
-		Namespace: pod.GetPod().NamespacedName.Namespace,
+func (s *SLOAwareRouter) getEndpointRunningRequestCount(endpoint schedulingtypes.Endpoint) int {
+	endpointName := types.NamespacedName{
+		Name:      endpoint.GetMetadata().NamespacedName.Name,
+		Namespace: endpoint.GetMetadata().NamespacedName.Namespace,
 	}
-	if runningReqs, ok := s.runningRequestLists[podName]; ok {
+	if runningReqs, ok := s.runningRequestLists[endpointName]; ok {
 		return runningReqs.GetSize()
 	}
 	return 0 // no running requests
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go
index 012cd70ad..ed54b9f31 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/selection_helpers.go
@@ -24,7 +24,7 @@ import (
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func (s *SLOAwareRouter) calculateHeadroomRanges(candidates []podPredictionResult) (minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64) {
+func (s *SLOAwareRouter) calculateHeadroomRanges(candidates []endpointPredictionResult) (minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64) {
 	minTPOTH, maxTPOTH = math.MaxFloat64, -math.MaxFloat64
 	minTTFTH, maxTTFTH = math.MaxFloat64, -math.MaxFloat64
 
@@ -47,7 +47,7 @@ func (s *SLOAwareRouter) calculateHeadroomRanges(candidates []podPredictionResul
 
 func (s *SLOAwareRouter) calculateWeightedChoices(
 	ctx context.Context,
-	candidates []podPredictionResult,
+	candidates []endpointPredictionResult,
 	minTPOTH, maxTPOTH, minTTFTH, maxTTFTH float64,
 ) ([]choice, int) {
 	logger := log.FromContext(ctx)
@@ -73,15 +73,15 @@ func (s *SLOAwareRouter) calculateWeightedChoices(
 	weightedChoices := make([]choice, 0, len(candidates))
 	total := 0
 
-	for _, p := range candidates {
+	for _, e := range candidates {
 		// Normalize to [0,1] within the cohort
 		nTPOTH := 0.5
 		if tpotRange > eps {
-			nTPOTH = (p.Headroom - minTPOTH) / (tpotRange + eps)
+			nTPOTH = (e.Headroom - minTPOTH) / (tpotRange + eps)
 		}
 		nTTFTH := 0.5
 		if ttftRange > eps {
-			nTTFTH = (p.TTFTHeadroom - minTTFTH) / (ttftRange + eps)
+			nTTFTH = (e.TTFTHeadroom - minTTFTH) / (ttftRange + eps)
 		}
 
 		// Blend: larger combined -> "safer"; smaller -> "tighter packing"
@@ -101,13 +101,13 @@ func (s *SLOAwareRouter) calculateWeightedChoices(
 			w = int((1.0-combined)*float64(wMax-minWeight)) + minWeight + 1
 		}
 
-		weightedChoices = append(weightedChoices, choice{podName: p.Pod, weight: w})
+		weightedChoices = append(weightedChoices, choice{endpointName: e.Endpoint, weight: w})
 		total += w
 
 		logger.V(logutil.TRACE).Info("Positive headroom blended weight",
-			"pod", p.Pod.GetPod().String(),
-			"ttftHeadroom", p.TTFTHeadroom, "normTTFTHeadroom", nTTFTH,
-			"tpotHeadroom", p.Headroom, "normTPOTHeadroom", nTPOTH,
+			"endpoint", e.Endpoint.GetMetadata().String(),
+			"ttftHeadroom", e.TTFTHeadroom, "normTTFTHeadroom", nTTFTH,
+			"tpotHeadroom", e.Headroom, "normTPOTHeadroom", nTPOTH,
 			"combined", combined, "weight", w)
 	}
 	return weightedChoices, total
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
index eec531a30..949799ad7 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/types.go
@@ -22,8 +22,8 @@ import schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/sche
 type headroomStrategy string
 
 type choice struct {
-	podName schedulingtypes.Pod
-	weight  int
+	endpointName schedulingtypes.Endpoint
+	weight       int
 }
 
 const (
diff --git a/pkg/epp/scheduling/framework/plugins/picker/common.go b/pkg/epp/scheduling/framework/plugins/picker/common.go
index c8655840f..4484f7b63 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/common.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/common.go
@@ -32,13 +32,13 @@ type pickerParameters struct {
 	MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
 }
 
-func shuffleScoredPods(scoredPods []*types.ScoredPod) {
+func shuffleScoredEndpoints(scoredEndpoints []*types.ScoredEndpoint) {
 	// Rand package is not safe for concurrent use, so we create a new instance.
 	// Source: https://pkg.go.dev/math/rand/v2#pkg-overview
 	randomGenerator := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0))
 
 	// Shuffle in-place
-	randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
-		scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
+	randomGenerator.Shuffle(len(scoredEndpoints), func(i, j int) {
+		scoredEndpoints[i], scoredEndpoints[j] = scoredEndpoints[j], scoredEndpoints[i]
 	})
 }
diff --git a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
index 33e99bd06..7f18c86e3 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
@@ -78,15 +78,15 @@ func (p *MaxScorePicker) TypedName() plugins.TypedName {
 	return p.typedName
 }
 
-// Pick selects the pod with the maximum score from the list of candidates.
-func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
-	log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates sorted by max score", "max-num-of-endpoints", p.maxNumOfEndpoints,
-		"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
+// Pick selects the endpoint with the maximum score from the list of candidates.
+func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredEndpoints []*types.ScoredEndpoint) *types.ProfileRunResult {
+	log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting endpoints from candidates sorted by max score", "max-num-of-endpoints", p.maxNumOfEndpoints,
+		"num-of-candidates", len(scoredEndpoints), "scored-endpoints", scoredEndpoints)
 
 	// Shuffle in-place - needed for random tie break when scores are equal
-	shuffleScoredPods(scoredPods)
+	shuffleScoredEndpoints(scoredEndpoints)
 
-	slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
+	slices.SortStableFunc(scoredEndpoints, func(i, j *types.ScoredEndpoint) int { // highest score first
 		if i.Score > j.Score {
 			return -1
 		}
@@ -96,15 +96,15 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
 		return 0
 	})
 
-	// if we have enough pods to return keep only the "maxNumOfEndpoints" highest scored pods
-	if p.maxNumOfEndpoints < len(scoredPods) {
-		scoredPods = scoredPods[:p.maxNumOfEndpoints]
+	// if we have more endpoints than needed, keep only the "maxNumOfEndpoints" highest scored endpoints
+	if p.maxNumOfEndpoints < len(scoredEndpoints) {
+		scoredEndpoints = scoredEndpoints[:p.maxNumOfEndpoints]
 	}
 
-	targetPods := make([]types.Pod, len(scoredPods))
-	for i, scoredPod := range scoredPods {
-		targetPods[i] = scoredPod
+	targetEndpoints := make([]types.Endpoint, len(scoredEndpoints))
+	for i, scoredEndpoint := range scoredEndpoints {
+		targetEndpoints[i] = scoredEndpoint
 	}
 
-	return &types.ProfileRunResult{TargetPods: targetPods}
+	return &types.ProfileRunResult{TargetEndpoints: targetEndpoints}
 }
diff --git a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
index 022328efd..07fca36f0 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
@@ -25,88 +25,88 @@ import (
 	"github.com/google/go-cmp/cmp/cmpopts"
 	k8stypes "k8s.io/apimachinery/pkg/types"
 
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
 func TestPickMaxScorePicker(t *testing.T) {
-	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
-	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
-	pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
+	endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	endpoint2 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	endpoint3 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
 
 	tests := []struct {
 		name               string
 		picker             framework.Picker
-		input              []*types.ScoredPod
-		output             []types.Pod
+		input              []*types.ScoredEndpoint
+		output             []types.Endpoint
 		tieBreakCandidates int // tie break is random, specify how many candidate with max score
 	}{
 		{
 			name:   "Single max score",
 			picker: NewMaxScorePicker(1),
-			input: []*types.ScoredPod{
-				{Pod: pod1, Score: 10},
-				{Pod: pod2, Score: 25},
-				{Pod: pod3, Score: 15},
+			input: []*types.ScoredEndpoint{
+				{Endpoint: endpoint1, Score: 10},
+				{Endpoint: endpoint2, Score: 25},
+				{Endpoint: endpoint3, Score: 15},
 			},
-			output: []types.Pod{
-				&types.ScoredPod{Pod: pod2, Score: 25},
+			output: []types.Endpoint{
+				&types.ScoredEndpoint{Endpoint: endpoint2, Score: 25},
 			},
 		},
 		{
 			name:   "Multiple max scores, all are equally scored",
 			picker: NewMaxScorePicker(2),
-			input: []*types.ScoredPod{
-				{Pod: pod1, Score: 50},
-				{Pod: pod2, Score: 50},
-				{Pod: pod3, Score: 30},
+			input: []*types.ScoredEndpoint{
+				{Endpoint: endpoint1, Score: 50},
+				{Endpoint: endpoint2, Score: 50},
+				{Endpoint: endpoint3, Score: 30},
 			},
-			output: []types.Pod{
-				&types.ScoredPod{Pod: pod1, Score: 50},
-				&types.ScoredPod{Pod: pod2, Score: 50},
+			output: []types.Endpoint{
+				&types.ScoredEndpoint{Endpoint: endpoint1, Score: 50},
+				&types.ScoredEndpoint{Endpoint: endpoint2, Score: 50},
 			},
 			tieBreakCandidates: 2,
 		},
 		{
 			name:   "Multiple results sorted by highest score, more pods than needed",
 			picker: NewMaxScorePicker(2),
-			input: []*types.ScoredPod{
-				{Pod: pod1, Score: 20},
-				{Pod: pod2, Score: 25},
-				{Pod: pod3, Score: 30},
+			input: []*types.ScoredEndpoint{
+				{Endpoint: endpoint1, Score: 20},
+				{Endpoint: endpoint2, Score: 25},
+				{Endpoint: endpoint3, Score: 30},
 			},
-			output: []types.Pod{
-				&types.ScoredPod{Pod: pod3, Score: 30},
-				&types.ScoredPod{Pod: pod2, Score: 25},
+			output: []types.Endpoint{
+				&types.ScoredEndpoint{Endpoint: endpoint3, Score: 30},
+				&types.ScoredEndpoint{Endpoint: endpoint2, Score: 25},
 			},
 		},
 		{
 			name:   "Multiple results sorted by highest score, less pods than needed",
 			picker: NewMaxScorePicker(4), // picker is required to return 4 pods at most, but we have only 3.
-			input: []*types.ScoredPod{
-				{Pod: pod1, Score: 20},
-				{Pod: pod2, Score: 25},
-				{Pod: pod3, Score: 30},
+			input: []*types.ScoredEndpoint{
+				{Endpoint: endpoint1, Score: 20},
+				{Endpoint: endpoint2, Score: 25},
+				{Endpoint: endpoint3, Score: 30},
 			},
-			output: []types.Pod{
-				&types.ScoredPod{Pod: pod3, Score: 30},
-				&types.ScoredPod{Pod: pod2, Score: 25},
-				&types.ScoredPod{Pod: pod1, Score: 20},
+			output: []types.Endpoint{
+				&types.ScoredEndpoint{Endpoint: endpoint3, Score: 30},
+				&types.ScoredEndpoint{Endpoint: endpoint2, Score: 25},
+				&types.ScoredEndpoint{Endpoint: endpoint1, Score: 20},
 			},
 		},
 		{
 			name:   "Multiple results sorted by highest score, num of pods exactly needed",
 			picker: NewMaxScorePicker(3), // picker is required to return 3 pods at most, we have only 3.
-			input: []*types.ScoredPod{
-				{Pod: pod1, Score: 30},
-				{Pod: pod2, Score: 25},
-				{Pod: pod3, Score: 30},
+			input: []*types.ScoredEndpoint{
+				{Endpoint: endpoint1, Score: 30},
+				{Endpoint: endpoint2, Score: 25},
+				{Endpoint: endpoint3, Score: 30},
 			},
-			output: []types.Pod{
-				&types.ScoredPod{Pod: pod1, Score: 30},
-				&types.ScoredPod{Pod: pod3, Score: 30},
-				&types.ScoredPod{Pod: pod2, Score: 25},
+			output: []types.Endpoint{
+				&types.ScoredEndpoint{Endpoint: endpoint1, Score: 30},
+				&types.ScoredEndpoint{Endpoint: endpoint3, Score: 30},
+				&types.ScoredEndpoint{Endpoint: endpoint2, Score: 25},
 			},
 			tieBreakCandidates: 2,
 		},
@@ -115,13 +115,13 @@ func TestPickMaxScorePicker(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			result := test.picker.Pick(context.Background(), types.NewCycleState(), test.input)
-			got := result.TargetPods
+			got := result.TargetEndpoints
 
 			if test.tieBreakCandidates > 0 {
-				testMaxScoredPods := test.output[:test.tieBreakCandidates]
-				gotMaxScoredPods := got[:test.tieBreakCandidates]
-				diff := cmp.Diff(testMaxScoredPods, gotMaxScoredPods, cmpopts.SortSlices(func(a, b types.Pod) bool {
-					return a.String() < b.String() // predictable order within the pods with equal scores
+				testMaxScoredEndpoints := test.output[:test.tieBreakCandidates]
+				gotMaxScoredEndpoints := got[:test.tieBreakCandidates]
+				diff := cmp.Diff(testMaxScoredEndpoints, gotMaxScoredEndpoints, cmpopts.SortSlices(func(a, b types.Endpoint) bool {
+					return a.String() < b.String() // predictable order within the endpoints with equal scores
 				}))
 				if diff != "" {
 					t.Errorf("Unexpected output (-want +got): %v", diff)
@@ -143,53 +143,53 @@ func TestPickWeightedRandomPicker(t *testing.T) {
 		tolerance = 0.05 // Verify within tolerance ±5%
 	)
 
-	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
-	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
-	pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
-	pod4 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}}
-	pod5 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}}
+	endpoint1 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	endpoint2 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	endpoint3 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
+	endpoint4 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}}
+	endpoint5 := &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}}
 
 	// A-Res algorithm uses U^(1/w) transformation which introduces statistical variance
 	// beyond simple proportional sampling. Generous tolerance is required to prevent
 	// flaky tests in CI environments, especially for multi-tier weights.
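
Reviewer note: for readers who have not met A-Res — it is weighted reservoir sampling: each candidate draws u from Uniform(0,1) and is ranked by the key u^(1/w), and the k largest keys are kept. For k = 1 this selects an item with probability exactly w_i / Σw, which is what the tolerance check below verifies; a zero weight yields key 0 and can never win. A minimal standalone sketch under those definitions (this is not the package's WeightedRandomPicker implementation, which is outside this diff):

package sketch

import (
	"math"
	"math/rand/v2"
	"sort"
)

type weightedItem struct {
	name   string
	weight float64
}

// aResSample draws up to k items without replacement, with probability
// proportional to weight, by ranking candidates on the key u^(1/w).
func aResSample(items []weightedItem, k int, r *rand.Rand) []weightedItem {
	keys := make([]float64, len(items))
	for i, it := range items {
		if it.weight > 0 {
			keys[i] = math.Pow(r.Float64(), 1.0/it.weight) // higher weight pushes the key toward 1
		}
	}
	idx := make([]int, len(items))
	for i := range idx {
		idx[i] = i
	}
	// Keep the k largest keys.
	sort.Slice(idx, func(a, b int) bool { return keys[idx[a]] > keys[idx[b]] })
	if k > len(items) {
		k = len(items)
	}
	picked := make([]weightedItem, 0, k)
	for _, i := range idx[:k] {
		picked = append(picked, items[i])
	}
	return picked
}

The U^(1/w) transform is also the source of the extra variance the comment above warns about: the marginal probabilities are exactly proportional, but individual runs fluctuate more than naive roulette-wheel counting, hence the generous ±5% tolerance.
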
tests := []struct { name string - input []*types.ScoredPod + input []*types.ScoredEndpoint maxPods int // maxNumOfEndpoints for this test }{ { name: "High weight dominance test", - input: []*types.ScoredPod{ - {Pod: pod1, Score: 10}, // Lower weight - {Pod: pod2, Score: 90}, // Higher weight (should dominate) + input: []*types.ScoredEndpoint{ + {Endpoint: endpoint1, Score: 10}, // Lower weight + {Endpoint: endpoint2, Score: 90}, // Higher weight (should dominate) }, maxPods: 1, }, { name: "Equal weights test - A-Res uniform distribution", - input: []*types.ScoredPod{ - {Pod: pod1, Score: 100}, // Equal weights (higher values for better numerical precision) - {Pod: pod2, Score: 100}, // Equal weights should yield uniform distribution - {Pod: pod3, Score: 100}, // Equal weights in A-Res + input: []*types.ScoredEndpoint{ + {Endpoint: endpoint1, Score: 100}, // Equal weights (higher values for better numerical precision) + {Endpoint: endpoint2, Score: 100}, // Equal weights should yield uniform distribution + {Endpoint: endpoint3, Score: 100}, // Equal weights in A-Res }, maxPods: 1, }, { name: "Zero weight exclusion test - A-Res edge case", - input: []*types.ScoredPod{ - {Pod: pod1, Score: 30}, // Normal weight, should be selected - {Pod: pod2, Score: 0}, // Zero weight, never selected in A-Res + input: []*types.ScoredEndpoint{ + {Endpoint: endpoint1, Score: 30}, // Normal weight, should be selected + {Endpoint: endpoint2, Score: 0}, // Zero weight, never selected in A-Res }, maxPods: 1, }, { name: "Multi-tier weighted test - A-Res complex distribution", - input: []*types.ScoredPod{ - {Pod: pod1, Score: 100}, // Highest weight - {Pod: pod2, Score: 90}, // High weight - {Pod: pod3, Score: 50}, // Medium weight - {Pod: pod4, Score: 30}, // Low weight - {Pod: pod5, Score: 20}, // Lowest weight + input: []*types.ScoredEndpoint{ + {Endpoint: endpoint1, Score: 100}, // Highest weight + {Endpoint: endpoint2, Score: 90}, // High weight + {Endpoint: endpoint3, Score: 50}, // Medium weight + {Endpoint: endpoint4, Score: 30}, // Low weight + {Endpoint: endpoint5, Score: 20}, // Lowest weight }, maxPods: 1, }, @@ -207,10 +207,10 @@ func TestPickWeightedRandomPicker(t *testing.T) { // Calculate expected probabilities based on scores expectedProbabilities := make(map[string]float64) - for _, pod := range test.input { - podName := pod.GetPod().NamespacedName.Name + for _, endpoint := range test.input { + podName := endpoint.GetMetadata().NamespacedName.Name if totalScore > 0 { - expectedProbabilities[podName] = pod.Score / totalScore + expectedProbabilities[podName] = endpoint.Score / totalScore } else { expectedProbabilities[podName] = 0.0 } @@ -218,9 +218,9 @@ func TestPickWeightedRandomPicker(t *testing.T) { // Initialize selection counters for each pod selectionCounts := make(map[string]int) - for _, pod := range test.input { - podName := pod.GetPod().NamespacedName.Name - selectionCounts[podName] = 0 + for _, endpoint := range test.input { + endpointName := endpoint.GetMetadata().NamespacedName.Name + selectionCounts[endpointName] = 0 } // Run multiple iterations to gather statistical data @@ -228,21 +228,21 @@ func TestPickWeightedRandomPicker(t *testing.T) { result := picker.Pick(context.Background(), types.NewCycleState(), test.input) // Count selections for probability analysis - selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name - selectionCounts[selectedPodName]++ + selectedEndpointName := result.TargetEndpoints[0].GetMetadata().NamespacedName.Name + 
selectionCounts[selectedEndpointName]++ } // Verify probability distribution - for podName, expectedProb := range expectedProbabilities { - actualCount := selectionCounts[podName] + for endpointName, expectedProb := range expectedProbabilities { + actualCount := selectionCounts[endpointName] actualProb := float64(actualCount) / float64(testIterations) if math.Abs(actualProb-expectedProb) > tolerance { - t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)", - podName, expectedProb, tolerance*100, actualProb, actualCount, testIterations) + t.Errorf("Endpoint %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)", + endpointName, expectedProb, tolerance*100, actualProb, actualCount, testIterations) } else { - t.Logf("Pod %s: expected %.3f, got %.3f (count: %d/%d) ✓", - podName, expectedProb, actualProb, actualCount, testIterations) + t.Logf("Endpoint %s: expected %.3f, got %.3f (count: %d/%d) ✓", + endpointName, expectedProb, actualProb, actualCount, testIterations) } } }) diff --git a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go index 10ad68469..447e9ef61 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go @@ -60,7 +60,7 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker { } } -// RandomPicker picks random pod(s) from the list of candidates. +// RandomPicker picks random endpoint(s) from the list of candidates. type RandomPicker struct { typedName plugins.TypedName maxNumOfEndpoints int @@ -77,23 +77,23 @@ func (p *RandomPicker) TypedName() plugins.TypedName { return p.typedName } -// Pick selects random pod(s) from the list of candidates. -func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { - log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates randomly", "max-num-of-endpoints", p.maxNumOfEndpoints, - "num-of-candidates", len(scoredPods), "scored-pods", scoredPods) +// Pick selects random endpoint(s) from the list of candidates. 
+func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredEndpoints []*types.ScoredEndpoint) *types.ProfileRunResult { + log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting endpoints from candidates randomly", "max-num-of-endpoints", p.maxNumOfEndpoints, + "num-of-candidates", len(scoredEndpoints), "scored-endpoints", scoredEndpoints) // Shuffle in-place - shuffleScoredPods(scoredPods) + shuffleScoredEndpoints(scoredEndpoints) - // if we have enough pods to return keep only the relevant subset - if p.maxNumOfEndpoints < len(scoredPods) { - scoredPods = scoredPods[:p.maxNumOfEndpoints] + // if we have enough endpoints to return keep only the relevant subset + if p.maxNumOfEndpoints < len(scoredEndpoints) { + scoredEndpoints = scoredEndpoints[:p.maxNumOfEndpoints] } - targetPods := make([]types.Pod, len(scoredPods)) - for i, scoredPod := range scoredPods { - targetPods[i] = scoredPod + targetEndpoints := make([]types.Endpoint, len(scoredEndpoints)) + for i, scoredEndpoint := range scoredEndpoints { + targetEndpoints[i] = scoredEndpoint } - return &types.ProfileRunResult{TargetPods: targetPods} + return &types.ProfileRunResult{TargetEndpoints: targetEndpoints} } diff --git a/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go index 540ede43c..0ddd8c87b 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go @@ -38,9 +38,9 @@ const ( WeightedRandomPickerType = "weighted-random-picker" ) -// weightedScoredPod represents a scored pod with its A-Res sampling key -type weightedScoredPod struct { - *types.ScoredPod +// weightedScoredEndpoint represents a scored endpoint with its A-Res sampling key +type weightedScoredEndpoint struct { + *types.ScoredEndpoint key float64 } @@ -72,10 +72,10 @@ func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker { } } -// WeightedRandomPicker picks pod(s) from the list of candidates based on weighted random sampling using A-Res algorithm. +// WeightedRandomPicker picks endpoint(s) from the list of candidates based on weighted random sampling using A-Res algorithm. // Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf. // -// The picker at its core is picking pods randomly, where the probability of the pod to get picked is derived +// The picker at its core is picking endpoints randomly, where the probability of the endpoint to get picked is derived // from its weighted score. // Algorithm: // - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ) @@ -102,52 +102,52 @@ func (p *WeightedRandomPicker) TypedName() plugins.TypedName { return p.typedName } -// Pick selects the pod(s) randomly from the list of candidates, where the probability of the pod to get picked is derived +// Pick selects the endpoint(s) randomly from the list of candidates, where the probability of the endpoint to get picked is derived // from its weighted score. 
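To make the key construction described above concrete: each candidate draws U ∈ (0,1) and keeps key = U^(1/weight), so a larger weight pulls the key toward 1 and wins the descending sort more often. A minimal sketch under those assumptions (helper name hypothetical, not from this patch):

package sketch

import (
	"math"
	"math/rand"
)

// aResKeys computes one A-Res sampling key per weight: key_i = U_i^(1/w_i).
// Non-positive weights keep key 0 and are effectively never selected.
func aResKeys(weights []float64, rng *rand.Rand) []float64 {
	keys := make([]float64, len(weights))
	for i, w := range weights {
		if w <= 0 {
			continue // key stays 0
		}
		u := rng.Float64()
		if u == 0 {
			u = 1e-10 // avoid Pow(0, x), mirroring the guard in the patch
		}
		keys[i] = math.Pow(u, 1.0/w)
	}
	return keys
}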
-func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { - // Check if there is at least one pod with Score > 0, if not let random picker run - if slices.IndexFunc(scoredPods, func(scoredPod *types.ScoredPod) bool { return scoredPod.Score > 0 }) == -1 { +func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredEndpoints []*types.ScoredEndpoint) *types.ProfileRunResult { + // Check if there is at least one endpoint with Score > 0, if not let random picker run + if slices.IndexFunc(scoredEndpoints, func(scoredEndpoint *types.ScoredEndpoint) bool { return scoredEndpoint.Score > 0 }) == -1 { log.FromContext(ctx).V(logutil.DEBUG).Info("All scores are zero, delegating to RandomPicker for uniform selection") - return p.randomPicker.Pick(ctx, cycleState, scoredPods) + return p.randomPicker.Pick(ctx, cycleState, scoredEndpoints) } - log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates by random weighted picker", "max-num-of-endpoints", p.maxNumOfEndpoints, - "num-of-candidates", len(scoredPods), "scored-pods", scoredPods) + log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting endpoints from candidates by random weighted picker", "max-num-of-endpoints", p.maxNumOfEndpoints, + "num-of-candidates", len(scoredEndpoints), "scored-endpoints", scoredEndpoints) randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) // A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ) - weightedPods := make([]weightedScoredPod, len(scoredPods)) + weightedEndpoints := make([]weightedScoredEndpoint, len(scoredEndpoints)) - for i, scoredPod := range scoredPods { + for i, scoredEndpoint := range scoredEndpoints { // Handle zero score - if scoredPod.Score <= 0 { - // Assign key=0 for zero-score pods (effectively excludes them from selection) - weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: 0} + if scoredEndpoint.Score <= 0 { + // Assign key=0 for zero-score endpoints (effectively excludes them from selection) + weightedEndpoints[i] = weightedScoredEndpoint{ScoredEndpoint: scoredEndpoint, key: 0} continue } - // If we're here the scoredPod.Score > 0. Generate a random number U in (0,1) + // If we're here the scoredEndpoint.Score > 0. 
Generate a random number U in (0,1) u := randomGenerator.Float64() if u == 0 { u = 1e-10 // Avoid log(0) } - weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: math.Pow(u, 1.0/scoredPod.Score)} // key = U^(1/weight) + weightedEndpoints[i] = weightedScoredEndpoint{ScoredEndpoint: scoredEndpoint, key: math.Pow(u, 1.0/scoredEndpoint.Score)} // key = U^(1/weight) } // Sort by key in descending order (largest keys first) - sort.Slice(weightedPods, func(i, j int) bool { - return weightedPods[i].key > weightedPods[j].key + sort.Slice(weightedEndpoints, func(i, j int) bool { + return weightedEndpoints[i].key > weightedEndpoints[j].key }) - // Select top k pods - selectedCount := min(p.maxNumOfEndpoints, len(weightedPods)) + // Select top k endpoints + selectedCount := min(p.maxNumOfEndpoints, len(weightedEndpoints)) - targetPods := make([]types.Pod, selectedCount) + targetEndpoints := make([]types.Endpoint, selectedCount) for i := range selectedCount { - targetPods[i] = weightedPods[i].ScoredPod + targetEndpoints[i] = weightedEndpoints[i].ScoredEndpoint } - return &types.ProfileRunResult{TargetPods: targetPods} + return &types.ProfileRunResult{TargetEndpoints: targetEndpoints} } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go index d78bcc9ed..ebe44420d 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go @@ -45,7 +45,7 @@ func NewKVCacheUtilizationScorer() *KVCacheUtilizationScorer { } } -// KVCacheUtilizationScorer scores list of candidate pods based on KV cache utilization. +// KVCacheUtilizationScorer scores list of candidate endpoints based on KV cache utilization. type KVCacheUtilizationScorer struct { typedName plugins.TypedName } @@ -68,11 +68,11 @@ func (s *KVCacheUtilizationScorer) WithName(name string) *KVCacheUtilizationScor return s } -// Score returns the scoring result for the given list of pods based on context. -func (s *KVCacheUtilizationScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scores := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scores[pod] = 1 - pod.GetMetrics().KVCacheUsagePercent +// Score returns the scoring result for the given list of endpoints based on context. 
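The rule implemented just below is a plain inversion of cache utilization: an endpoint at 0.8 utilization scores 0.2, an idle cache scores 1.0. As a one-line sketch of the arithmetic (helper name hypothetical):

package sketch

// kvCacheScore is the unused fraction of the KV cache: 1 - utilization.
// e.g. 0.8 utilization -> 0.2, 0.5 -> 0.5, 0.0 -> 1.0.
func kvCacheScore(usagePercent float64) float64 {
	return 1 - usagePercent
}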
+func (s *KVCacheUtilizationScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { + scores := make(map[types.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scores[endpoint] = 1 - endpoint.GetMetrics().KVCacheUsagePercent } return scores } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization_test.go b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization_test.go index 76aaeee31..02ad28e83 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization_test.go @@ -22,25 +22,25 @@ import ( "github.com/stretchr/testify/assert" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestKvCacheUtilizationScorer(t *testing.T) { tests := []struct { - name string - pods []types.Pod - expectedScoresPod map[int]float64 // Map of pod index to expected score + name string + endpoints []types.Endpoint + expectedScoresEndpoint map[int]float64 // Map of endpoint index to expected score }{ { name: "Different KV cache utilization", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.8}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.5}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.8}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.5}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 0.2, // Highest KV cache usage (0.8) gets lowest score (1-0.8=0.2) 1: 0.5, // Medium KV cache usage (0.5) gets medium score (1-0.5=0.5) 2: 1.0, // No KV cache usage (0.0) gets highest score (1-0=1.0) @@ -48,33 +48,33 @@ func TestKvCacheUtilizationScorer(t *testing.T) { }, { name: "Same KV cache utilization", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.6}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.6}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.6}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.6}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 0.4, // Both get same score (1-0.6=0.4) 1: 0.4, }, }, { name: "Zero KV cache utilization", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: 
&backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.0}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 1.0, // No KV cache usage gets highest score 1: 1.0, }, }, { name: "Full KV cache utilization", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 1.0}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.5}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 1.0}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{KVCacheUsagePercent: 0.5}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 0.0, // Full KV cache (1.0) gets lowest score (1-1=0) 1: 0.5, // Half KV cache (0.5) gets medium score (1-0.5=0.5) }, @@ -83,11 +83,11 @@ func TestKvCacheUtilizationScorer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scores := NewKVCacheUtilizationScorer().Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods) + scores := NewKVCacheUtilizationScorer().Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.endpoints) - for i, pod := range test.pods { - expectedScore := test.expectedScoresPod[i] - assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %d should have score %f", i, expectedScore) + for i, endpoint := range test.endpoints { + expectedScore := test.expectedScoresEndpoint[i] + assert.InDelta(t, expectedScore, scores[endpoint], 0.0001, "Endpoint %d should have score %f", i, expectedScore) } }) } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go index ab4fd9874..688cbb199 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go @@ -69,28 +69,28 @@ func (s *LoraAffinityScorer) WithName(name string) *LoraAffinityScorer { return s } -func (s *LoraAffinityScorer) Score(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scores := make(map[types.Pod]float64, len(pods)) +func (s *LoraAffinityScorer) Score(_ context.Context, _ *types.CycleState, request *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { + scores := make(map[types.Endpoint]float64, len(endpoints)) - // Assign a score to each pod for loading the target adapter. - for _, pod := range pods { - _, active := pod.GetMetrics().ActiveModels[request.TargetModel] - _, waiting := pod.GetMetrics().WaitingModels[request.TargetModel] + // Assign a score to each endpoint for loading the target adapter. + for _, endpoint := range endpoints { + _, active := endpoint.GetMetrics().ActiveModels[request.TargetModel] + _, waiting := endpoint.GetMetrics().WaitingModels[request.TargetModel] // Determine the model server's suitability score based on adapter load status and capacity. 
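The tiering the switch below applies can be read off in isolation: adapter already active 1.0, spare adapter capacity 0.8, adapter queued for loading 0.6, otherwise 0.0. A standalone sketch with simplified inputs (helper and parameters hypothetical, not part of this patch):

package sketch

// loraAffinity mirrors the tiering in the scorer: an active adapter is ideal,
// free capacity is good, a queued adapter is moderate, a full server scores 0.
// The capacity check deliberately precedes the waiting check, as in the patch.
func loraAffinity(active, waiting bool, activeCount, waitingCount, maxActive int) float64 {
	switch {
	case active:
		return 1.0
	case activeCount+waitingCount < maxActive:
		return 0.8
	case waiting:
		return 0.6
	default:
		return 0.0
	}
}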
switch { // Ideal: The adapter is already active on this model server. case active: - scores[pod] = 1.0 + scores[endpoint] = 1.0 // Good: The model server has capacity to load at least one more adapter. - case len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels: - scores[pod] = 0.8 + case len(endpoint.GetMetrics().ActiveModels)+len(endpoint.GetMetrics().WaitingModels) < endpoint.GetMetrics().MaxActiveModels: + scores[endpoint] = 0.8 // Moderate: The adapter is already in the queue to be loaded on this model server. case waiting: - scores[pod] = 0.6 + scores[endpoint] = 0.6 // Unsuitable: The model server has reached its maximum capacity and cannot load the adapter. default: - scores[pod] = 0.0 + scores[endpoint] = 0.0 } } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity_test.go b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity_test.go index 418bafde0..c65b7bbc2 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity_test.go @@ -23,24 +23,24 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestLoraAffinityScorer(t *testing.T) { tests := []struct { - name string - request *types.LLMRequest - pods []types.Pod - expectedScoresPod map[string]float64 // Map of pod name to expected score + name string + request *types.LLMRequest + endpoints []types.Endpoint + expectedScoresEndpoint map[string]float64 // Map of endpoint name to expected score }{ { name: "Target model is active", request: &types.LLMRequest{TargetModel: "active-model-1"}, - pods: []types.Pod{ + endpoints: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-1": 1}, WaitingModels: map[string]int{}, @@ -48,16 +48,16 @@ func TestLoraAffinityScorer(t *testing.T) { }, }, }, - expectedScoresPod: map[string]float64{ + expectedScoresEndpoint: map[string]float64{ "pod1": 1.0, }, }, { name: "Target model is waiting", request: &types.LLMRequest{TargetModel: "active-model-1"}, - pods: []types.Pod{ + endpoints: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-2": 2}, WaitingModels: map[string]int{"active-model-1": 1}, @@ -65,16 +65,16 @@ func TestLoraAffinityScorer(t *testing.T) { }, }, }, - expectedScoresPod: map[string]float64{ + expectedScoresEndpoint: map[string]float64{ "pod1": 0.6, }, }, { - name: "Pods have no space for new model", + name: "Endpoints have no space for new model", request: &types.LLMRequest{TargetModel: "active-model-1"}, - pods: []types.Pod{ + endpoints: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: 
k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-2": 2}, WaitingModels: map[string]int{"active-model-3": 1}, @@ -82,7 +82,7 @@ }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{}, WaitingModels: map[string]int{}, @@ -90,17 +90,17 @@ }, }, }, - expectedScoresPod: map[string]float64{ + expectedScoresEndpoint: map[string]float64{ "pod1": 0.0, "pod2": 0.0, }, }, { - name: "Multiple pods with mixed active and waiting models", + name: "Multiple endpoints with mixed active and waiting models", request: &types.LLMRequest{TargetModel: "active-model-1"}, - pods: []types.Pod{ + endpoints: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-1": 1}, WaitingModels: map[string]int{}, @@ -108,7 +108,7 @@ }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-2": 4}, WaitingModels: map[string]int{"active-model-1": 1}, @@ -116,7 +116,7 @@ }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-2": 1}, WaitingModels: map[string]int{}, @@ -124,7 +124,7 @@ }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-3": 1}, WaitingModels: map[string]int{"active-model-1": 1}, @@ -132,7 +132,7 @@ }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}, MetricsState: &backendmetrics.MetricsState{ ActiveModels: map[string]int{"active-model-4": 1, "active-model-5": 1}, WaitingModels: map[string]int{}, @@ -140,7 +140,7 @@ }, }, }, - expectedScoresPod: map[string]float64{ + expectedScoresEndpoint: map[string]float64{ "pod1": 1.0, "pod2": 0.8, "pod3": 0.8, @@ -149,26 +149,26 @@ }, }, { - name: "Empty pods slice", - request: &types.LLMRequest{TargetModel: "modelA"}, - pods: []types.Pod{}, - expectedScoresPod: map[string]float64{}, // No pods, no scores + name: "Empty pods slice", + request: &types.LLMRequest{TargetModel: "modelA"}, + endpoints: []types.Endpoint{}, + expectedScoresEndpoint:
map[string]float64{}, // No pods, no scores }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { scorer := &LoraAffinityScorer{} - scores := scorer.Score(context.Background(), types.NewCycleState(), test.request, test.pods) + scores := scorer.Score(context.Background(), types.NewCycleState(), test.request, test.endpoints) - for _, pod := range test.pods { - expectedScore, ok := test.expectedScoresPod[pod.GetPod().NamespacedName.Name] + for _, endpoint := range test.endpoints { + expectedScore, ok := test.expectedScoresEndpoint[endpoint.GetMetadata().NamespacedName.Name] if !ok { - t.Fatalf("Expected score not found for pod %s in test %s", pod.GetPod().NamespacedName, test.name) + t.Fatalf("Expected score not found for endpoint %s in test %s", endpoint.GetMetadata().NamespacedName, test.name) } - assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %s should have score %f", pod.GetPod().NamespacedName.Name, expectedScore) + assert.InDelta(t, expectedScore, scores[endpoint], 0.0001, "Endpoint %s should have score %f", endpoint.GetMetadata().NamespacedName.Name, expectedScore) } - assert.Len(t, scores, len(test.expectedScoresPod), "Number of scored pods should match expected") + assert.Len(t, scores, len(test.expectedScoresEndpoint), "Number of scored endpoints should match expected") }) } } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/queue.go b/pkg/epp/scheduling/framework/plugins/scorer/queue.go index e2a07b0bd..4b9ca66b9 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/queue.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/queue.go @@ -70,14 +70,14 @@ func (s *QueueScorer) WithName(name string) *QueueScorer { return s } -// Score returns the scoring result for the given list of pods based on context. -func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +// Score returns the scoring result for the given list of endpoints based on context. +func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { minQueueSize := math.MaxInt maxQueueSize := math.MinInt - // Iterate through the remaining pods to find min and max - for _, pod := range pods { - queueSize := pod.GetMetrics().WaitingQueueSize + // Iterate through the remaining endpoints to find min and max + for _, endpoint := range endpoints { + queueSize := endpoint.GetMetrics().WaitingQueueSize if queueSize < minQueueSize { minQueueSize = queueSize } @@ -86,19 +86,19 @@ func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLM } } - // podScoreFunc calculates the score based on the queue size of each pod. Longer queue gets a lower score. - podScoreFunc := func(pod types.Pod) float64 { + // endpointScoreFunc calculates the score based on the queue size of each endpoint. Longer queue gets a lower score. 
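The closure defined next applies min-max normalization over the candidate set: the shortest queue maps to 1.0, the longest to 0.0, and a degenerate range (all queues equal) yields a neutral 1.0. A standalone sketch of the formula (helper name hypothetical, not part of this patch):

package sketch

// queueScore normalizes a queue length into [0,1]: shorter queues score
// higher, and equal queues everywhere return a neutral 1.0.
func queueScore(queueSize, minQueue, maxQueue int) float64 {
	if maxQueue == minQueue {
		return 1.0
	}
	return float64(maxQueue-queueSize) / float64(maxQueue-minQueue)
}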
+ endpointScoreFunc := func(endpoint types.Endpoint) float64 { if maxQueueSize == minQueueSize { // If all pods have the same queue size, return a neutral score return 1.0 } - return float64(maxQueueSize-pod.GetMetrics().WaitingQueueSize) / float64(maxQueueSize-minQueueSize) + return float64(maxQueueSize-endpoint.GetMetrics().WaitingQueueSize) / float64(maxQueueSize-minQueueSize) } - // Create a map to hold the scores for each pod - scores := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scores[pod] = podScoreFunc(pod) + // Create a map to hold the scores for each endpoint + scores := make(map[types.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scores[endpoint] = endpointScoreFunc(endpoint) } return scores } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go b/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go index ce8193679..3959d3ba6 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go @@ -22,25 +22,25 @@ import ( "github.com/stretchr/testify/assert" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestQueueScorer(t *testing.T) { tests := []struct { - name string - pods []types.Pod - expectedScoresPod map[int]float64 // Map of pod index to expected score + name string + endpoints []types.Endpoint + expectedScoresEndpoint map[int]float64 // Map of endpoint index to expected score }{ { name: "Different queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 10}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 10}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 0.0, // Longest queue (10) gets lowest score 1: 0.5, // Medium queue (5) gets medium score 2: 1.0, // Shortest queue (0) gets highest score @@ -48,22 +48,22 @@ func TestQueueScorer(t *testing.T) { }, { name: "Same queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 5}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 1.0, // When all pods have the same queue size, they get the same neutral score 1: 1.0, }, 
}, { name: "Zero queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}}, }, - expectedScoresPod: map[int]float64{ + expectedScoresEndpoint: map[int]float64{ 0: 1.0, 1: 1.0, }, @@ -74,11 +74,11 @@ func TestQueueScorer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods) + scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.endpoints) - for i, pod := range test.pods { - expectedScore := test.expectedScoresPod[i] - assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %d should have score %f", i, expectedScore) + for i, endpoint := range test.endpoints { + expectedScore := test.expectedScoresEndpoint[i] + assert.InDelta(t, expectedScore, scores[endpoint], 0.0001, "Pod %d should have score %f", i, expectedScore) } }) } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/running.go b/pkg/epp/scheduling/framework/plugins/scorer/running.go index c446ac0ca..f816d4768 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/running.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/running.go @@ -71,13 +71,13 @@ func (s *RunningRequestsSizeScorer) WithName(name string) *RunningRequestsSizeSc } // Score returns the scoring result for the given list of pods based on context. -func (s *RunningRequestsSizeScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *RunningRequestsSizeScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { minQueueSize := math.MaxInt maxQueueSize := math.MinInt // Iterate through the remaining pods to find min and max - for _, pod := range pods { - queueSize := pod.GetMetrics().RunningRequestsSize + for _, endpoint := range endpoints { + queueSize := endpoint.GetMetrics().RunningRequestsSize if queueSize < minQueueSize { minQueueSize = queueSize } @@ -86,19 +86,19 @@ func (s *RunningRequestsSizeScorer) Score(_ context.Context, _ *types.CycleState } } - // podScoreFunc calculates the score based on the queue size of each pod. Longer queue gets a lower score. - podScoreFunc := func(pod types.Pod) float64 { + // endpointScoreFunc calculates the score based on the queue size of each pod. Longer queue gets a lower score. 
+ endpointScoreFunc := func(endpoint types.Endpoint) float64 { if maxQueueSize == minQueueSize { // If all pods have the same queue size, return a neutral score return 1.0 } - return float64(maxQueueSize-pod.GetMetrics().RunningRequestsSize) / float64(maxQueueSize-minQueueSize) + return float64(maxQueueSize-endpoint.GetMetrics().RunningRequestsSize) / float64(maxQueueSize-minQueueSize) } // Create a map to hold the scores for each pod - scores := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scores[pod] = podScoreFunc(pod) + scores := make(map[types.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scores[endpoint] = endpointScoreFunc(endpoint) } return scores } diff --git a/pkg/epp/scheduling/framework/plugins/scorer/running_test.go b/pkg/epp/scheduling/framework/plugins/scorer/running_test.go index 76864d480..0f04ca1ee 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/running_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/running_test.go @@ -22,23 +22,23 @@ import ( "github.com/stretchr/testify/assert" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) func TestRunningRequestsSizeScorer(t *testing.T) { tests := []struct { name string - pods []types.Pod + endpoints []types.Endpoint expectedScoresPod map[int]float64 // Map of pod index to expected score }{ { name: "Different running queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 10}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 10}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, }, expectedScoresPod: map[int]float64{ 0: 0.0, // Longest queue (10) gets lowest score @@ -48,9 +48,9 @@ func TestRunningRequestsSizeScorer(t *testing.T) { }, { name: "Same running queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 5}}, }, expectedScoresPod: map[int]float64{ 0: 1.0, // When all pods have the same queue size, they get the same neutral score @@ -59,9 +59,9 @@ func TestRunningRequestsSizeScorer(t *testing.T) { }, { name: "Zero running queue sizes", - pods: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, - &types.PodMetrics{Pod: 
&backend.Pod{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, + endpoints: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{}, MetricsState: &backendmetrics.MetricsState{RunningRequestsSize: 0}}, }, expectedScoresPod: map[int]float64{ 0: 1.0, @@ -74,9 +74,9 @@ func TestRunningRequestsSizeScorer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods) + scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.endpoints) - for i, pod := range test.pods { + for i, pod := range test.endpoints { expectedScore := test.expectedScoresPod[i] assert.InDelta(t, expectedScore, scores[pod], 0.0001, "Pod %d should have score %f", i, expectedScore) } diff --git a/pkg/epp/scheduling/framework/plugins/test/filter/filter_test.go b/pkg/epp/scheduling/framework/plugins/test/filter/filter_test.go index 3cd8a6197..72be97f6f 100644 --- a/pkg/epp/scheduling/framework/plugins/test/filter/filter_test.go +++ b/pkg/epp/scheduling/framework/plugins/test/filter/filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/google/go-cmp/cmp" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/test" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -31,88 +31,88 @@ func TestFilter(t *testing.T) { tests := []struct { name string req *types.LLMRequest - input []types.Pod - output []types.Pod + input []types.Endpoint + output []types.Endpoint }{ { name: "header unset in request", req: &types.LLMRequest{}, // Deliberately unset - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3000"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3000"}}, }, - output: []types.Pod{}, + output: []types.Endpoint{}, }, { name: "header set but no IP match", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "10.0.0.99"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3000"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3000"}}, }, - output: []types.Pod{}, + output: []types.Endpoint{}, }, { name: "IP-only header matches pod (port-agnostic)", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "10.0.0.1"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3002"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3002"}}, }, - output: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3002"}}, + output: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3002"}}, }, }, { name: "IP:port header matches exact port", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "10.0.0.1:3002"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3000"}}, - 
&types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3002"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.2", Port: "3002"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3000"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3002"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.2", Port: "3002"}}, }, - output: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3002"}}, + output: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3002"}}, }, }, { name: "IP:port header with non-matching port produces no match", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "10.0.0.1:9999"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3002"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3002"}}, }, - output: []types.Pod{}, + output: []types.Endpoint{}, }, { name: "multiple header values (IP and IP:port) produce multiple matches in order and deduped", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "10.0.0.3:3004, 10.0.0.2, 10.0.0.3"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.1", Port: "3000"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.2", Port: "3002"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.3", Port: "3004"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.1", Port: "3000"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.2", Port: "3002"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.3", Port: "3004"}}, }, - output: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.3", Port: "3004"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "10.0.0.2", Port: "3002"}}, + output: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.3", Port: "3004"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "10.0.0.2", Port: "3002"}}, }, }, { name: "IPv6 with brackets and port", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "[fd00::1]:3002"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::1", Port: "3002"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::2", Port: "3002"}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::1", Port: "3002"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::2", Port: "3002"}}, }, - output: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::1", Port: "3002"}}, + output: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::1", Port: "3002"}}, }, }, { name: "IPv6 bare address (no port)", req: &types.LLMRequest{Headers: map[string]string{test.HeaderTestEppEndPointSelectionKey: "fd00::2"}}, - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::1", Port: "3002"}}, - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::2", Port: "3004"}}, + input: []types.Endpoint{ + 
&types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::1", Port: "3002"}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::2", Port: "3004"}}, }, - output: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{Address: "fd00::2", Port: "3004"}}, + output: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{Address: "fd00::2", Port: "3004"}}, }, }, } diff --git a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go index bce3be11a..75145d160 100644 --- a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go +++ b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go @@ -49,7 +49,7 @@ func NewHeaderBasedTestingFilter() *HeaderBasedTestingFilter { } } -// HeaderBasedTestingFilter filters Pods based on an address specified in the "test-epp-endpoint-selection" request header. +// HeaderBasedTestingFilter filters Endpoints based on an address specified in the "test-epp-endpoint-selection" request header. type HeaderBasedTestingFilter struct { typedName plugins.TypedName } @@ -65,39 +65,39 @@ func (f *HeaderBasedTestingFilter) WithName(name string) *HeaderBasedTestingFilt return f } -// Filter selects pods whose IP or IP:port matches any value in the +// Filter selects endpoints whose IP or IP:port matches any value in the // "test-epp-endpoint-selection" header. Values may be "IP" or "IP:port". // If a port is provided, only an exact IP:port match is accepted. -func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, endpoints []types.Endpoint) []types.Endpoint { hv, ok := request.Headers[test.HeaderTestEppEndPointSelectionKey] if !ok || strings.TrimSpace(hv) == "" { - return []types.Pod{} + return []types.Endpoint{} } normalizeIP := func(s string) string { return strings.Trim(s, "[]") } // Build lookup maps: - // ip -> pod - // ip:port -> pod (only when pod GetPort() is non-empty) - ipToPod := make(map[string]types.Pod, len(pods)) - hpToPod := make(map[string]types.Pod, len(pods)) - for _, p := range pods { - if p == nil || p.GetPod() == nil { + // ip -> endpoint + // ip:port -> endpoint (only when endpoint GetPort() is non-empty) + ipToEndpoint := make(map[string]types.Endpoint, len(endpoints)) + hpToPod := make(map[string]types.Endpoint, len(endpoints)) + for _, e := range endpoints { + if e == nil || e.GetMetadata() == nil { continue } - ip := normalizeIP(strings.TrimSpace(p.GetPod().GetIPAddress())) + ip := normalizeIP(strings.TrimSpace(e.GetMetadata().GetIPAddress())) if ip == "" { continue } - ipToPod[ip] = p - if port := strings.TrimSpace(p.GetPod().GetPort()); port != "" { - hpToPod[ip+":"+port] = p + ipToEndpoint[ip] = e + if port := strings.TrimSpace(e.GetMetadata().GetPort()); port != "" { + hpToPod[ip+":"+port] = e } } headerVals := strings.Split(hv, ",") - filteredPods := make([]types.Pod, 0, len(headerVals)) - seen := make(map[string]struct{}, len(headerVals)) // de-dupe by pod IP + filteredEndpoints := make([]types.Endpoint, 0, len(headerVals)) + seen := make(map[string]struct{}, len(headerVals)) // de-dupe by endpoint IP for _, raw := range headerVals { item := strings.TrimSpace(raw) @@ -114,26 +114,26 @@ func (f *HeaderBasedTestingFilter) 
Filter(_ context.Context, _ *types.CycleState } host = normalizeIP(host) - var pod types.Pod + var endpoint types.Endpoint if port != "" { // Require an exact ip:port match if p, ok := hpToPod[host+":"+port]; ok { - pod = p + endpoint = p } } else { // IP-only selection - if p, ok := ipToPod[host]; ok { - pod = p + if p, ok := ipToEndpoint[host]; ok { + endpoint = p } } - if pod != nil { - ip := normalizeIP(pod.GetPod().GetIPAddress()) + if endpoint != nil { + ip := normalizeIP(endpoint.GetMetadata().GetIPAddress()) if _, dup := seen[ip]; !dup { seen[ip] = struct{}{} - filteredPods = append(filteredPods, pod) + filteredEndpoints = append(filteredEndpoints, endpoint) } } } - return filteredPods + return filteredEndpoints } diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index db547088a..35ef9af6a 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -114,76 +114,76 @@ func (p *SchedulerProfile) String() string { // Run runs a SchedulerProfile. It invokes all the SchedulerProfile plugins for the given request in this // order - Filters, Scorers, Picker. After completing all, it returns the result. -func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) { - pods := p.runFilterPlugins(ctx, request, cycleState, candidatePods) - if len(pods) == 0 { - return nil, errutil.Error{Code: errutil.Internal, Msg: "no pods available for the given request"} +func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidateEndpoints []types.Endpoint) (*types.ProfileRunResult, error) { + endpoints := p.runFilterPlugins(ctx, request, cycleState, candidateEndpoints) + if len(endpoints) == 0 { + return nil, errutil.Error{Code: errutil.Internal, Msg: "no endpoints available for the given request"} } - // if we got here, there is at least one pod to score - weightedScorePerPod := p.runScorerPlugins(ctx, request, cycleState, pods) + // if we got here, there is at least one endpoint to score + weightedScorePerEndpoint := p.runScorerPlugins(ctx, request, cycleState, endpoints) - result := p.runPickerPlugin(ctx, cycleState, weightedScorePerPod) + result := p.runPickerPlugin(ctx, cycleState, weightedScorePerEndpoint) return result, nil } -func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod { +func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, endpoints []types.Endpoint) []types.Endpoint { logger := log.FromContext(ctx) - filteredPods := pods - logger.V(logutil.DEBUG).Info("Before running filter plugins", "pods", filteredPods) + filteredEndpoints := endpoints + logger.V(logutil.DEBUG).Info("Before running filter plugins", "endpoints", filteredEndpoints) for _, filter := range p.filters { logger.V(logutil.VERBOSE).Info("Running filter plugin", "plugin", filter.TypedName()) before := time.Now() - filteredPods = filter.Filter(ctx, cycleState, request, filteredPods) + filteredEndpoints = filter.Filter(ctx, cycleState, request, filteredEndpoints) metrics.RecordPluginProcessingLatency(FilterExtensionPoint, filter.TypedName().Type, filter.TypedName().Name, time.Since(before)) - logger.V(logutil.DEBUG).Info("Completed running filter plugin successfully", 
"plugin", filter.TypedName(), "pods", filteredPods) - if len(filteredPods) == 0 { + logger.V(logutil.DEBUG).Info("Completed running filter plugin successfully", "plugin", filter.TypedName(), "endpoints", filteredEndpoints) + if len(filteredEndpoints) == 0 { break } } logger.V(logutil.VERBOSE).Info("Completed running filter plugins successfully") - return filteredPods + return filteredEndpoints } -func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 { +func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, endpoints []types.Endpoint) map[types.Endpoint]float64 { logger := log.FromContext(ctx) - logger.V(logutil.DEBUG).Info("Before running scorer plugins", "pods", pods) + logger.V(logutil.DEBUG).Info("Before running scorer plugins", "endpoints", endpoints) - weightedScorePerPod := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - weightedScorePerPod[pod] = float64(0) // initialize weighted score per pod with 0 value + weightedScorePerEndpoint := make(map[types.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + weightedScorePerEndpoint[endpoint] = float64(0) // initialize weighted score per endpoint with 0 value } // Iterate through each scorer in the chain and accumulate the weighted scores. for _, scorer := range p.scorers { logger.V(logutil.VERBOSE).Info("Running scorer plugin", "plugin", scorer.TypedName()) before := time.Now() - scores := scorer.Score(ctx, cycleState, request, pods) + scores := scorer.Score(ctx, cycleState, request, endpoints) metrics.RecordPluginProcessingLatency(ScorerExtensionPoint, scorer.TypedName().Type, scorer.TypedName().Name, time.Since(before)) - for pod, score := range scores { // weight is relative to the sum of weights - logger.V(logutil.DEBUG).Info("Calculated score", "plugin", scorer.TypedName(), "endpoint", pod.GetPod().NamespacedName, "score", score) - weightedScorePerPod[pod] += enforceScoreRange(score) * float64(scorer.Weight()) + for endpoint, score := range scores { // weight is relative to the sum of weights + logger.V(logutil.DEBUG).Info("Calculated score", "plugin", scorer.TypedName(), "endpoint", endpoint.GetMetadata().NamespacedName, "score", score) + weightedScorePerEndpoint[endpoint] += enforceScoreRange(score) * float64(scorer.Weight()) } logger.V(logutil.DEBUG).Info("Completed running scorer plugin successfully", "plugin", scorer.TypedName()) } logger.V(logutil.VERBOSE).Info("Completed running scorer plugins successfully") - return weightedScorePerPod + return weightedScorePerEndpoint } -func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *types.CycleState, weightedScorePerPod map[types.Pod]float64) *types.ProfileRunResult { +func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *types.CycleState, weightedScorePerEndpoint map[types.Endpoint]float64) *types.ProfileRunResult { logger := log.FromContext(ctx) - scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod)) + scoredEndpoints := make([]*types.ScoredEndpoint, len(weightedScorePerEndpoint)) i := 0 - for pod, score := range weightedScorePerPod { - scoredPods[i] = &types.ScoredPod{Pod: pod, Score: score} + for endpoint, score := range weightedScorePerEndpoint { + scoredEndpoints[i] = &types.ScoredEndpoint{Endpoint: endpoint, Score: score} i++ } logger.V(logutil.VERBOSE).Info("Running picker plugin", "plugin", 
p.picker.TypedName()) - logger.V(logutil.DEBUG).Info("Candidate pods for picking", "pods-weighted-score", scoredPods) + logger.V(logutil.DEBUG).Info("Candidate endpoints for picking", "endpoints-weighted-score", scoredEndpoints) before := time.Now() - result := p.picker.Pick(ctx, cycleState, scoredPods) + result := p.picker.Pick(ctx, cycleState, scoredEndpoints) metrics.RecordPluginProcessingLatency(PickerExtensionPoint, p.picker.TypedName().Type, p.picker.TypedName().Name, time.Since(before)) logger.V(logutil.DEBUG).Info("Completed running picker plugin successfully", "plugin", p.picker.TypedName(), "result", result) diff --git a/pkg/epp/scheduling/framework/scheduler_profile_test.go b/pkg/epp/scheduling/framework/scheduler_profile_test.go index f79b48de7..0a18ab0b8 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile_test.go +++ b/pkg/epp/scheduling/framework/scheduler_profile_test.go @@ -24,7 +24,7 @@ import ( "github.com/google/uuid" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -50,14 +50,14 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - profile *SchedulerProfile - input []types.Pod - wantTargetPod k8stypes.NamespacedName - targetPodScore float64 - // Number of expected pods to score (after filter) - numPodsToScore int - err bool + name string + profile *SchedulerProfile + input []types.Endpoint + wantTargetEndpoint k8stypes.NamespacedName + targetEndpointScore float64 + // Number of expected endpoints to score (after filter) + numEndpointsToScore int + err bool }{ { name: "all plugins executed successfully, all scorers with same weight", profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). WithPicker(pickerPlugin), - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, + wantTargetEndpoint: k8stypes.NamespacedName{Name: "pod1"}, + targetEndpointScore: 1.1, + numEndpointsToScore: 2, + err: false, }, { name: "all plugins executed successfully, different scorers weights", profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 60), NewWeightedScorer(tp2, 40)).
WithPicker(pickerPlugin), - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 50, - numPodsToScore: 2, - err: false, + wantTargetEndpoint: k8stypes.NamespacedName{Name: "pod1"}, + targetEndpointScore: 50, + numEndpointsToScore: 2, + err: false, }, { name: "filter all", profile: NewSchedulerProfile(). WithFilters(tp1, tp_filterAll). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). WithPicker(pickerPlugin), - input: []types.Pod{ - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + input: []types.Endpoint{ + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + &types.PodMetrics{EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - numPodsToScore: 0, - err: true, // no available pods to server after filter all + numEndpointsToScore: 0, + err: true, // no available endpoints to serve after filter all }, } @@ -137,9 +137,9 @@ func TestSchedulePlugins(t *testing.T) { // Validate output wantRes := &types.ProfileRunResult{ - TargetPods: []types.Pod{ + TargetEndpoints: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: test.wantTargetPod}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: test.wantTargetEndpoint}, }, }, } @@ -159,19 +159,19 @@ func TestSchedulePlugins(t *testing.T) { if tp.ScoreCallCount != 1 { t.Errorf("Plugin '%s' Score() called %d times, expected 1", plugin.TypedName(), tp.ScoreCallCount) } - if test.numPodsToScore != tp.NumOfScoredPods { - t.Errorf("Plugin '%s' Score() called with %d pods, expected %d", plugin.TypedName(), tp.NumOfScoredPods, test.numPodsToScore) + if test.numEndpointsToScore != tp.NumOfScoredEndpoints { + t.Errorf("Plugin '%s' Score() called with %d endpoints, expected %d", plugin.TypedName(), tp.NumOfScoredEndpoints, test.numEndpointsToScore) } } tp, _ := test.profile.picker.(*testPlugin) - if tp.NumOfPickerCandidates != test.numPodsToScore { - t.Errorf("Picker plugin '%s' Pick() called with %d candidates, expected %d", tp.TypedName(), tp.NumOfPickerCandidates, tp.NumOfScoredPods) + if tp.NumOfPickerCandidates != test.numEndpointsToScore { + t.Errorf("Picker plugin '%s' Pick() called with %d candidates, expected %d", tp.TypedName(), tp.NumOfPickerCandidates, tp.NumOfScoredEndpoints) } if tp.PickCallCount != 1 { t.Errorf("Picker plugin '%s' Pick() called %d times, expected 1", tp.TypedName(),
tp.PickCallCount) } - if tp.WinnerPodScore != test.targetPodScore { - t.Errorf("winner pod score %v, expected %v", tp.WinnerPodScore, test.targetPodScore) + if tp.WinnerEndpointScore != test.targetEndpointScore { + t.Errorf("winner endpoint score %v, expected %v", tp.WinnerEndpointScore, test.targetEndpointScore) } }) } @@ -187,65 +187,65 @@ type testPlugin struct { typedName plugins.TypedName TypeRes string ScoreCallCount int - NumOfScoredPods int + NumOfScoredEndpoints int ScoreRes float64 FilterCallCount int FilterRes []k8stypes.NamespacedName PickCallCount int NumOfPickerCandidates int PickRes k8stypes.NamespacedName - WinnerPodScore float64 + WinnerEndpointScore float64 } func (tp *testPlugin) TypedName() plugins.TypedName { return tp.typedName } -func (tp *testPlugin) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (tp *testPlugin) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) []types.Endpoint { tp.FilterCallCount++ - return findPods(pods, tp.FilterRes...) + return findEndpoints(endpoints, tp.FilterRes...) } -func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 { tp.ScoreCallCount++ - scoredPods := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scoredPods[pod] += tp.ScoreRes + scoredEndpoints := make(map[types.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scoredEndpoints[endpoint] += tp.ScoreRes } - tp.NumOfScoredPods = len(scoredPods) - return scoredPods + tp.NumOfScoredEndpoints = len(scoredEndpoints) + return scoredEndpoints } -func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { +func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredEndpoints []*types.ScoredEndpoint) *types.ProfileRunResult { tp.PickCallCount++ - tp.NumOfPickerCandidates = len(scoredPods) + tp.NumOfPickerCandidates = len(scoredEndpoints) - winnerPods := []types.Pod{} - for _, scoredPod := range scoredPods { - if scoredPod.GetPod().NamespacedName.String() == tp.PickRes.String() { - winnerPods = append(winnerPods, scoredPod.Pod) - tp.WinnerPodScore = scoredPod.Score + winnerEndpoints := []types.Endpoint{} + for _, scoredEndpoint := range scoredEndpoints { + if scoredEndpoint.GetMetadata().NamespacedName.String() == tp.PickRes.String() { + winnerEndpoints = append(winnerEndpoints, scoredEndpoint.Endpoint) + tp.WinnerEndpointScore = scoredEndpoint.Score } } - return &types.ProfileRunResult{TargetPods: winnerPods} + return &types.ProfileRunResult{TargetEndpoints: winnerEndpoints} } func (tp *testPlugin) reset() { tp.FilterCallCount = 0 tp.ScoreCallCount = 0 - tp.NumOfScoredPods = 0 + tp.NumOfScoredEndpoints = 0 tp.PickCallCount = 0 tp.NumOfPickerCandidates = 0 } -func findPods(pods []types.Pod, names ...k8stypes.NamespacedName) []types.Pod { - res := []types.Pod{} - for _, pod := range pods { +func findEndpoints(endpoints []types.Endpoint, names ...k8stypes.NamespacedName) []types.Endpoint { - res := []types.Pod{} + res := []types.Endpoint{} + for _, endpoint := range endpoints { for _, name := range names { - if pod.GetPod().NamespacedName.String() == name.String() { - res = append(res, pod) + if endpoint.GetMetadata().NamespacedName.String() == name.String() { + res = append(res,
endpoint) } } } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 5b2e64d23..f05554130 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -44,7 +44,7 @@ type Scheduler struct { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest, candidatePods []types.Pod) (result *types.SchedulingResult, err error) { +func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest, candidateEndpoints []types.Endpoint) (result *types.SchedulingResult, err error) { loggerVerbose := log.FromContext(ctx).V(logutil.VERBOSE) scheduleStart := time.Now() @@ -69,7 +69,7 @@ func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest, can for name, profile := range profiles { loggerVerbose.Info("Running scheduler profile", "profile", name) // run the selected profiles and collect results (current code runs all profiles) - profileRunResult, err := profile.Run(ctx, request, cycleState, candidatePods) + profileRunResult, err := profile.Run(ctx, request, cycleState, candidateEndpoints) if err != nil { loggerVerbose.Info("failed to run scheduler profile", "profile", name, "error", err.Error()) } else { diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index c197096ba..533660414 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -24,8 +24,8 @@ import ( "github.com/google/uuid" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" @@ -56,31 +56,31 @@ func TestSchedule(t *testing.T) { tests := []struct { name string req *types.LLMRequest - input []types.Pod + input []types.Endpoint wantRes *types.SchedulingResult err bool }{ { - name: "no candidate pods", + name: "no candidate endpoints", req: &types.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "any-model", }, - input: []types.Pod{}, + input: []types.Endpoint{}, wantRes: nil, err: true, }, { - name: "finds optimal pod", + name: "finds optimal endpoint", req: &types.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", }, // pod2 will be picked because it has relatively low queue size, with the requested // model being active, and has low KV cache. 
- input: []types.Pod{ + input: []types.Endpoint{ &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -92,7 +92,7 @@ func TestSchedule(t *testing.T) { }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -104,7 +104,7 @@ func TestSchedule(t *testing.T) { }, }, &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, MetricsState: &backendmetrics.MetricsState{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.8, @@ -118,10 +118,10 @@ func TestSchedule(t *testing.T) { wantRes: &types.SchedulingResult{ ProfileResults: map[string]*types.ProfileRunResult{ "default": { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + TargetEndpoints: []types.Endpoint{ + &types.ScoredEndpoint{ + Endpoint: &types.PodMetrics{ + EndpointMetadata: &datalayer.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 8e0553fae..4401bd6ca 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -22,7 +22,6 @@ import ( "fmt" "strings" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" ) @@ -188,8 +187,8 @@ func (mc Content) PlainText() string { return sb.String() } -type Pod interface { - GetPod() *backend.Pod +type Endpoint interface { + GetMetadata() *datalayer.EndpointMetadata GetMetrics() *backendmetrics.MetricsState String() string Get(string) (datalayer.Cloneable, bool) @@ -197,8 +196,8 @@ type Pod interface { Keys() []string } -type ScoredPod struct { - Pod +type ScoredEndpoint struct { + Endpoint Score float64 } @@ -210,8 +209,8 @@ func (pm *PodMetrics) String() string { return fmt.Sprintf("%+v", *pm) } -func (pm *PodMetrics) GetPod() *backend.Pod { - return pm.Pod +func (pm *PodMetrics) GetMetadata() *datalayer.EndpointMetadata { + return pm.EndpointMetadata } func (pm *PodMetrics) GetMetrics() *backendmetrics.MetricsState { @@ -219,14 +218,14 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.MetricsState { } type PodMetrics struct { - *backend.Pod + *datalayer.EndpointMetadata *backendmetrics.MetricsState datalayer.AttributeMap } // ProfileRunResult captures the profile run result. type ProfileRunResult struct { - TargetPods []Pod + TargetEndpoints []Endpoint } // SchedulingResult captures the result of the scheduling cycle. 
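Note for out-of-tree plugin authors: the pkg/epp/scheduling/types/types.go change above renames the plugin-facing interface from Pod/GetPod() to Endpoint/GetMetadata() (and ScoredPod to ScoredEndpoint). Below is a minimal sketch of a scorer written against the renamed surface; the plugin name and the queue-depth formula are illustrative only, while the Score signature and the GetMetadata()/GetMetrics() accessors are the ones this diff introduces.

package example

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// queueDepthScorer is a hypothetical scorer, shown only to illustrate the
// renamed types; it is not part of this change.
type queueDepthScorer struct{}

func (s *queueDepthScorer) TypedName() plugins.TypedName {
	return plugins.TypedName{Type: "queue-depth-scorer", Name: "example"}
}

// Score now receives []types.Endpoint (formerly []types.Pod); endpoint
// identity, when needed, comes from GetMetadata() (formerly GetPod()).
// The formula below simply favors endpoints with shorter waiting queues.
func (s *queueDepthScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, endpoints []types.Endpoint) map[types.Endpoint]float64 {
	scores := make(map[types.Endpoint]float64, len(endpoints))
	for _, endpoint := range endpoints {
		queue := 0
		if m := endpoint.GetMetrics(); m != nil {
			queue = m.WaitingQueueSize
		}
		scores[endpoint] = 1.0 / float64(1+queue) // 1.0 for an empty queue, decaying as it grows
	}
	return scores
}

Such a scorer would be attached to a profile the same way the tests above do, e.g. NewSchedulerProfile().WithScorers(NewWeightedScorer(&queueDepthScorer{}, 1)); runScorerPlugins then clamps each score via enforceScoreRange and accumulates it by the scorer's weight.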
diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index 9032e2c9f..3b59e4edb 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -26,7 +26,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" @@ -192,6 +192,6 @@ func (ts *testDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx * return reqCtx, nil } -func (ts *testDirector) GetRandomPod() *backend.Pod { +func (ts *testDirector) GetRandomEndpoint() *datalayer.EndpointMetadata { return nil } diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 72dedf252..062c1da9a 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -60,8 +60,8 @@ import ( v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/common" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" @@ -153,7 +153,7 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { tests := []struct { name string requests []*extProcPb.ProcessingRequest - pods map[*backend.Pod]*backendmetrics.MetricsState + endpoints map[*datalayer.EndpointMetadata]*backendmetrics.MetricsState wantResponses []*extProcPb.ProcessingResponse wantMetrics map[string]string wantErr bool @@ -163,11 +163,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { { name: "select lower queue and kv cache, no active lora", requests: integrationutils.GenerateStreamedRequestSet(logger, "test1", modelMyModel, modelMyModelTarget, nil), - // Pod 1 will be picked because it has relatively low queue size and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 3, kvCacheUsage: 0.2}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2}, + // Endpoint 1 will be picked because it has relatively low queue size and low KV cache. + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 3, kvCacheUsage: 0.2}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -219,11 +219,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, }, - // Pod 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. 
- pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + // Endpoint 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, ), wantErr: false, wantResponses: integrationutils.NewImmediateErrorResponse( @@ -234,11 +234,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { { name: "select active lora, low queue", requests: integrationutils.GenerateStreamedRequestSet(logger, "test2", modelSQLLora, modelSQLLoraTarget, nil), - // Pod 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + // Endpoint 1 will be picked because it has relatively low queue size, the requested model active, and low KV cache. + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, ), wantMetrics: map[string]string{ @@ -268,13 +268,13 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { { name: "select lora despite higher kv cache usage", requests: integrationutils.GenerateStreamedRequestSet(logger, "test3", modelSQLLora, modelSQLLoraTarget, nil), - // Pod 2 will be picked despite NOT having the requested model active as it is above the affinity for queue size. + // Endpoint 2 will be picked despite NOT having the requested model active as it is above the affinity for queue size. // Also it is critical, so we should still admit the request despite all queue sizes being greater than the queue // size threshold. 
- pods: newPodStates( - podState{index: 0, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 10, kvCacheUsage: 0.4, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.3, activeModels: []string{"foo"}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 10, kvCacheUsage: 0.4, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.3, activeModels: []string{"foo"}}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -306,10 +306,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { // pod 0: excluded; above queue size threshold // pod 1: excluded; above KV cache threshold // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 6, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSQLLoraTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo"}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo"}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 6, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSQLLoraTarget}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo"}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo"}}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -374,11 +374,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, }, - // Pod 1 will be picked because it has relatively low queue size and low KV cache. - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 4, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, + // Endpoint 1 will be picked because it has relatively low queue size and low KV cache. 
+ endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpointState{index: 1, queueSize: 4, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -449,10 +449,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, // pod 0: selected due to low queue size and kv cache usage - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -510,10 +510,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { // pod 0: selected // pod 1: excluded; above KV cache threshold // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, ), wantErr: false, wantResponses: integrationutils.NewResponseBufferedResponse( @@ -558,10 +558,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { // pod 0: selected // pod 1: excluded; above KV cache threshold // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, ), wantErr: false, wantResponses: integrationutils.NewResponseBufferedResponse( @@ 
-611,10 +611,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { // pod 0: selected // pod 1: excluded; above KV cache threshold // pod 2: excluded; above queue size threshold - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.85, activeModels: []string{"foo", modelSheddableTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.9, activeModels: []string{"foo", modelSheddableTarget}}, ), wantErr: false, wantResponses: integrationutils.NewResponseBufferedResponse( @@ -805,8 +805,8 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, wantResponses: nil, - pods: newPodStates( - podState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 4, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar", modelSheddableTarget}}, ), wantMetrics: map[string]string{}, }, @@ -818,11 +818,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { modelSQLLora, modelSQLLoraTarget, []string{"192.168.1.1:8000", "192.168.1.2:8000", "192.168.1.3:8000"}), - // Pod 1 will be picked because it has relatively low queue size, the requested model active, low KV cache, and within subset. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + // Endpoint 1 will be picked because it has relatively low queue size, the requested model active, low KV cache, and within subset. + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, ), wantMetrics: map[string]string{ @@ -857,11 +857,11 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { modelSQLLora, modelSQLLoraTarget, []string{"192.168.1.3:8000"}), - // Pod 3 has high queue and kv cache utilization, but it will still be picked because it is the only one matching subsetting target. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + // Endpoint 3 has high queue and kv cache utilization, but it will still be picked because it is the only one matching subsetting target. 
+ endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, ), wantMetrics: map[string]string{ @@ -897,10 +897,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { modelSQLLoraTarget, []string{"192.168.1.4:8000", "192.168.1.5:8000", "192.168.1.6:8000"}), // No pods will be picked as none are within the subset. - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, - podState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, - podState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, + endpointState{index: 1, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", modelSQLLoraTarget}}, + endpointState{index: 2, queueSize: 10, kvCacheUsage: 0.2, activeModels: []string{"foo", "bar"}}, ), wantMetrics: map[string]string{}, @@ -941,7 +941,7 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { }, }, }, - pods: nil, + endpoints: nil, wantMetrics: map[string]string{}, wantErr: true, wantResponses: []*extProcPb.ProcessingResponse{ @@ -986,13 +986,13 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { { name: "rewrite request model", requests: integrationutils.GenerateStreamedRequestSet(logger, "test-rewrite", modelToBeWritten, modelToBeWritten, nil), - // Pod 0 will be picked. + // Endpoint 0 will be picked. // Expected flow: // 1. Request asks for "model-to-be-rewritten" // 2. Rewrite rule transforms "model-to-be-rewritten" -> "rewritten-model" // 3. EPP sends request to backend with model "rewritten-model" - pods: newPodStates( - podState{index: 0, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", "rewritten-model"}}, + endpoints: newEndpointStates( + endpointState{index: 0, queueSize: 0, kvCacheUsage: 0.1, activeModels: []string{"foo", "rewritten-model"}}, ), wantMetrics: map[string]string{ "inference_objective_request_total": inferenceObjectiveRequestTotal([]label{ @@ -1023,7 +1023,7 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer(t, test.pods) + client, cleanup := setUpHermeticServer(t, test.endpoints) t.Cleanup(cleanup) responses, err := integrationutils.StreamedRequest(t, client, test.requests, len(test.wantResponses)) @@ -1051,7 +1051,7 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { } } -func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendmetrics.MetricsState) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { +func setUpHermeticServer(t *testing.T, podAndMetrics map[*datalayer.EndpointMetadata]*backendmetrics.MetricsState) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { // Reconfigure the TestPodMetricsClient. 
res := map[types.NamespacedName]*backendmetrics.MetricsState{} for pod, metrics := range podAndMetrics { @@ -1121,14 +1121,14 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[*backend.Pod]*backendme Namespace(pod.NamespacedName.Namespace).Complete().ObjRef() if err := k8sClient.Delete(context.Background(), pod); err != nil { - logutil.Fatal(logger, err, "Failed to delete pod", "pod", fakePod) + logutil.Fatal(logger, err, "Failed to delete pod", "pod", fakeMetadata) } } } } -func fakePod(index int) *backend.Pod { - return &backend.Pod{ +func fakeMetadata(index int) *datalayer.EndpointMetadata { + return &datalayer.EndpointMetadata{ NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v-rank-0", index), Namespace: testNamespace}, Address: fmt.Sprintf("192.168.1.%d", index+1), PodName: fmt.Sprintf("pod-%v", index), @@ -1136,19 +1136,19 @@ func fakePod(index int) *backend.Pod { } } -// podState is a descriptor for a pod's simulated metrics. -type podState struct { +// endpointState is a descriptor for an endpoint's simulated metrics. +type endpointState struct { index int queueSize int kvCacheUsage float64 activeModels []string } -// newPodStates generates the backend metrics map required by the test setup. -func newPodStates(states ...podState) map[*backend.Pod]*backendmetrics.MetricsState { - res := make(map[*backend.Pod]*backendmetrics.MetricsState) +// newEndpointStates generates the backend metrics map required by the test setup. +func newEndpointStates(states ...endpointState) map[*datalayer.EndpointMetadata]*backendmetrics.MetricsState { + res := make(map[*datalayer.EndpointMetadata]*backendmetrics.MetricsState) for _, s := range states { - pod := fakePod(s.index) + pod := fakeMetadata(s.index) activeModelsMap := make(map[string]int) for _, model := range s.activeModels { activeModelsMap[model] = 1