diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index cfcd82ecc..bebacd8f0 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -70,6 +70,7 @@ func (s *StreamingServer) HandleRequestBody( Model: model, ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, + Headers: reqCtx.RequestHeaders, Prompt: prompt, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) @@ -109,7 +110,7 @@ func (s *StreamingServer) HandleRequestBody( reqCtx.TargetPod = targetPod.NamespacedName.String() reqCtx.TargetEndpoint = endpoint - s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes)) + s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes), res.MutatedHeaders) reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header @@ -151,7 +152,12 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ return err } endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - s.populateRequestHeaderResponse(reqCtx, endpoint, 0) + s.populateRequestHeaderResponse(reqCtx, endpoint, 0, nil) } + + for _, header := range req.RequestHeaders.Headers.Headers { + reqCtx.RequestHeaders[header.Key] = header.Value + } + return nil } diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 630baef31..0dd113a38 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -86,6 +86,8 @@ type RequestContext struct { RequestState StreamRequestState modelServerStreaming bool + RequestHeaders map[string]string + reqHeaderResp *extProcPb.ProcessingResponse reqBodyResp *extProcPb.ProcessingResponse reqTrailerResp *extProcPb.ProcessingResponse @@ -117,7 +119,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) // Create request context to share states during life time of an HTTP request. // See https://github.com/envoyproxy/envoy/issues/17540. reqCtx := &RequestContext{ - RequestState: RequestReceived, + RequestState: RequestReceived, + RequestHeaders: make(map[string]string), } var body []byte @@ -358,7 +361,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces return nil } -func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) { +func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int, mutatedHeaders map[string]string) { headers := []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ @@ -377,6 +380,15 @@ func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, }, }) } + // Add headers added by filters/scorers + for key, value := range mutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } targetEndpointValue := &structpb.Struct{ Fields: map[string]*structpb.Value{ diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 1a1d67b5c..db73cc7b1 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -119,6 +119,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types s.runPostSchedulePlugins(sCtx, result) + result.MutatedHeaders = sCtx.MutatedHeaders return result, nil } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index b44c7ac2e..fab17f08e 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -109,6 +109,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -172,6 +173,7 @@ func TestSchedule(t *testing.T) { }, }, }, + MutatedHeaders: make(map[string]string), }, }, { @@ -242,18 +244,27 @@ func TestSchedule(t *testing.T) { func TestSchedulePlugins(t *testing.T) { tp1 := &TestPlugin{ - NameRes: "test1", - ScoreRes: 0.3, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + ReceivedRequestHeaders: make(map[string]string), } tp2 := &TestPlugin{ - NameRes: "test2", - ScoreRes: 0.8, - FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ReceivedRequestHeaders: make(map[string]string), } tp_filterAll := &TestPlugin{ - NameRes: "filter all", - FilterRes: []k8stypes.NamespacedName{}, + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, + ReceivedRequestHeaders: make(map[string]string), + } + tp_headers := &TestPlugin{ + NameRes: "headers", + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + ExtraHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + ReceivedRequestHeaders: make(map[string]string), } pickerPlugin := &TestPlugin{ NameRes: "picker", @@ -261,11 +272,13 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - config SchedulerConfig - input []*backendmetrics.FakePodMetrics - wantTargetPod k8stypes.NamespacedName - targetPodScore float64 + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + requestHeaders map[string]string + wantTargetPod k8stypes.NamespacedName + wantMutatedHeaders map[string]string + targetPodScore float64 // Number of expected pods to score (after filter) numPodsToScore int err bool @@ -287,10 +300,12 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 1.1, - numPodsToScore: 2, - err: false, + requestHeaders: make(map[string]string), + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, }, { name: "all plugins executed successfully, different scorers weights", @@ -309,10 +324,12 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, - targetPodScore: 50, - numPodsToScore: 2, - err: false, + requestHeaders: make(map[string]string), + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: make(map[string]string), + targetPodScore: 50, + numPodsToScore: 2, + err: false, }, { name: "filter all", @@ -331,9 +348,37 @@ func TestSchedulePlugins(t *testing.T) { {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, + requestHeaders: make(map[string]string), numPodsToScore: 0, err: true, // no available pods to server after filter all }, + { + name: "Mutate a header", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp_headers}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, + picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + requestHeaders: map[string]string{ + "Content-type": "application/json", + "x-session-id": "qazw-edcr-tgby-nhyu", + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + wantMutatedHeaders: map[string]string{"x-unit-test": "test 1 2 3"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, // no available pods to server after filter all + }, } for _, test := range tests { @@ -356,7 +401,10 @@ func TestSchedulePlugins(t *testing.T) { // Initialize the scheduler scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) - req := &types.LLMRequest{Model: "test-model"} + req := &types.LLMRequest{ + Model: "test-model", + Headers: test.requestHeaders, + } got, err := scheduler.Schedule(context.Background(), req) // Validate error state @@ -372,7 +420,10 @@ func TestSchedulePlugins(t *testing.T) { wantPod := &types.PodMetrics{ Pod: &backend.Pod{NamespacedName: test.wantTargetPod}, } - wantRes := &types.Result{TargetPod: wantPod} + wantRes := &types.Result{ + TargetPod: wantPod, + MutatedHeaders: test.wantMutatedHeaders, + } if diff := cmp.Diff(wantRes, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -390,6 +441,9 @@ func TestSchedulePlugins(t *testing.T) { if tp.FilterCallCount != 1 { t.Errorf("Plugin %s Filter() called %d times, expected 1", plugin.Name(), tp.FilterCallCount) } + if len(test.requestHeaders) != len(tp.ReceivedRequestHeaders) { + t.Errorf("Count of received request headers is %d, expected %d", len(tp.ReceivedRequestHeaders), len(test.requestHeaders)) + } } for plugin := range test.config.scorers { @@ -419,6 +473,10 @@ func TestSchedulePlugins(t *testing.T) { t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", plugin.Name(), tp.PostScheduleCallCount) } } + + if len(test.wantMutatedHeaders) != len(got.MutatedHeaders) { + t.Errorf("Count of mutated headers is %d, expected %d", len(got.MutatedHeaders), len(test.wantMutatedHeaders)) + } }) } } @@ -437,18 +495,20 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { // TestPlugin is an implementation useful in unit tests. type TestPlugin struct { - NameRes string - ScoreCallCount int - NumOfScoredPods int - ScoreRes float64 - FilterCallCount int - FilterRes []k8stypes.NamespacedName - PreScheduleCallCount int - PostScheduleCallCount int - PickCallCount int - NumOfPickerCandidates int - PickRes k8stypes.NamespacedName - WinnderPodScore float64 + NameRes string + ScoreCallCount int + NumOfScoredPods int + ScoreRes float64 + FilterCallCount int + FilterRes []k8stypes.NamespacedName + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + NumOfPickerCandidates int + PickRes k8stypes.NamespacedName + WinnderPodScore float64 + ExtraHeaders map[string]string + ReceivedRequestHeaders map[string]string } func (tp *TestPlugin) Name() string { return tp.NameRes } @@ -459,6 +519,12 @@ func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ + for key, value := range tp.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } + for key, value := range ctx.Req.Headers { + tp.ReceivedRequestHeaders[key] = value + } return findPods(ctx, tp.FilterRes...) } diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 4f69fae0a..f0e49452d 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -32,6 +32,9 @@ type LLMRequest struct { // Target models is a map of target model name to weight. TargetModels map[string]int Prompt string + // Headers during request processing contains all of the request headers. + // During response processing it contains all of the response headers. + Headers map[string]string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string Critical bool @@ -58,6 +61,8 @@ type SchedulingContext struct { Logger logr.Logger Req *LLMRequest PodsSnapshot []Pod + // MutatedHeaders is used by the plugins to add/modify headers + MutatedHeaders map[string]string } func (pm *PodMetrics) String() string { @@ -83,10 +88,11 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + MutatedHeaders: make(map[string]string), } } @@ -100,5 +106,6 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { // Result captures the scheduler result. type Result struct { - TargetPod Pod + TargetPod Pod + MutatedHeaders map[string]string }