Commit 7476ad3

Make ResponseComplete accept LLMResponse, and update the encoding of Messages in ChatCompletions.
1 parent: bb4fef9

10 files changed: +213 additions, -83 deletions

pkg/epp/handlers/server.go

Lines changed: 1 addition & 0 deletions
@@ -303,6 +303,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 				break
 			}
 
+			reqCtx.Response.Body = body
 			reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
 			if responseErr != nil {
 				if logger.V(logutil.DEBUG).Enabled() {

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 8 deletions
@@ -297,13 +297,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
 		logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
 		return reqCtx, err
 	}
-	response := &Response{
-		RequestId: requestID,
-		Headers:   reqCtx.Response.Headers,
-		// Currently use the first choice as the response body to process.
-		Body: llmResponse.GetFirstChoiceContent(),
-	}
-	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
+	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
 
 	logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
 	return reqCtx, nil
@@ -353,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch
 	}
 }
 
-func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
 		loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())

pkg/epp/requestcontrol/director_test.go

Lines changed: 8 additions & 10 deletions
@@ -677,6 +677,10 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		"total_tokens": 3
 	  }
 	}`
+	wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON))
+	if err != nil {
+		t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err)
+	}
 
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
@@ -691,21 +695,15 @@
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
 
-	_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
+	_, err = director.HandleResponseBodyComplete(ctx, reqCtx)
 	if err != nil {
 		t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
 	}
 
-	if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
-		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
-	}
-	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
-	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
-	if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
+	if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
 	}
 }
@@ -730,7 +728,7 @@ type testResponseStreaming struct {
 
 type testResponseComplete struct {
 	tn                      plugins.TypedName
-	lastRespOnComplete      *Response
+	lastRespOnComplete      *schedulingtypes.LLMResponse
 	lastTargetPodOnComplete string
 }
 
@@ -774,7 +772,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
 	p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
 }
 
-func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	p.lastRespOnComplete = response
 	p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
 }

pkg/epp/requestcontrol/plugins.go

Lines changed: 1 addition & 1 deletion
@@ -55,5 +55,5 @@ type ResponseStreaming interface {
 // ResponseComplete is called by the director after the complete response is sent.
 type ResponseComplete interface {
 	plugins.Plugin
-	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
+	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
 }
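
For orientation, here is a minimal sketch of a plugin written against the updated interface. The ResponseComplete signature and the import paths are taken from the diffs on this page; the ResponseLogger type and its behavior are hypothetical, and a real plugin may need more of the plugins.Plugin interface than the TypedName method shown here.

// Package responselogger is a hypothetical example; only the
// ResponseComplete signature is taken from this commit.
package responselogger

import (
	"context"

	"sigs.k8s.io/controller-runtime/pkg/log"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

type ResponseLogger struct {
	tn plugins.TypedName
}

func (p *ResponseLogger) TypedName() plugins.TypedName { return p.tn }

// ResponseComplete now receives the parsed *types.LLMResponse instead of the
// former requestcontrol.Response wrapper, so a plugin can inspect all choices
// and usage data rather than a single pre-extracted body string.
func (p *ResponseLogger) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
	content, err := response.FirstChoiceContent()
	if err != nil {
		log.FromContext(ctx).Error(err, "no choices in response", "requestID", request.RequestId)
		return
	}
	log.FromContext(ctx).Info("response complete",
		"requestID", request.RequestId,
		"pod", targetPod.NamespacedName.String(),
		"bytes", len(content))
}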

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 10 additions & 5 deletions
@@ -375,19 +375,25 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 	}
 
 	// must be chat-completions request at this point, return bytes of entire messages
-	return json.Marshal(request.Body.ChatCompletions.Messages)
+	return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...)
 }
 
-func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
+func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
 	if err != nil {
 		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
 		return
 	}
 	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
+
+	responseForKVCache, err := response.FirstChoiceContent()
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
+		return
+	}
 	var input bytes.Buffer
 	input.Write(state.RestBytes)
-	input.Write([]byte(response.Body))
+	input.Write(responseForKVCache)
 
 	server := ServerID(targetPod.NamespacedName)
 	prevBlockHash := defaultPrevBlock(request)
@@ -396,8 +402,7 @@ func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest
 		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
 		prevBlockHashLength = len(state.PrefixHashes)
 	}
-	inputBytes := input.Bytes()
-	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
 	p.wg.Add(1)
 	go func() {
 		p.indexer.Add(hashBlocks, server)
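
The appended response bytes feed the same chained block hashing that scored the prompt, which is what lets a follow-up request (whose message history now contains the assistant reply) hit the cache. Below is a self-contained sketch of that chaining, using the worked example from the test comments on this page (prompt "aaaaaa" plus response "bb" with block size 4 caches the blocks "aaaa" and "aabb"). The FNV hash and the exact key derivation are assumptions for illustration; the real logic lives in hashInputWithPrevBlockHash.

// Illustrative sketch of chained block hashing: each block's key depends on
// its content and the previous block's hash, so a block is only reusable
// when its entire prefix matches.
package main

import (
	"fmt"
	"hash/fnv"
)

func hashBlocks(input []byte, blockSize int) []uint64 {
	var hashes []uint64
	var prev uint64
	for start := 0; start+blockSize <= len(input); start += blockSize {
		h := fnv.New64a()
		// Chain in the previous block hash so identical blocks at
		// different positions produce different keys.
		fmt.Fprintf(h, "%d:", prev)
		h.Write(input[start : start+blockSize])
		prev = h.Sum64()
		hashes = append(hashes, prev)
	}
	return hashes
}

func main() {
	// Prompt "aaaaaa" + response "bb" caches the sequence "aaaaaabb",
	// i.e. the two 4-byte blocks "aaaa" and "aabb".
	blocks := hashBlocks([]byte("aaaaaabb"), 4)
	fmt.Println(len(blocks), "blocks cached") // prints: 2 blocks cached
}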

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 39 additions & 5 deletions
@@ -30,7 +30,6 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -251,7 +250,16 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
 	// - Response Body: "bb"
 	// - Cached Sequence: "aaaaaabb" (length 8)
 	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
-	plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
+	resp1 := &types.LLMResponse{
+		Completion: &types.CompletionResponse{
+			Choices: []types.CompletionChoice{
+				{
+					Text: "bb",
+				},
+			},
+		},
+	}
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
 	plugin.wg.Wait()
 
 	// -- Second Request: Multi-turn Follow-up --
@@ -362,6 +370,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
 	plugin.wg.Wait()
 
+	resp1 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+				},
+			},
+		},
+	}
+	// Trigger ResponseComplete to simulate resp1 being recorded in the KV cache.
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+	plugin.wg.Wait()
+
 	// Second request adds assistant response and new user message (conversation grows)
 	req2 := &types.LLMRequest{
 		RequestId: uuid.NewString(),
@@ -389,13 +410,27 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+	assert.Greater(t, scores[pod1], float64(0.5), "given the response is also prefix cached, the cache hit should be well above 0.5")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
 
 	// Simulate pod1 was picked again
 	plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
 	plugin.wg.Wait()
 
-	// Third request continues the conversation even further
+	resp2 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "Prefix caching is a technique where..."},
+				},
+			},
+		},
+	}
+	// Trigger ResponseComplete to simulate resp2 being recorded in the KV cache.
+	plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// Third request replays the whole conversation above, driving the cache hit rate to 1.0.
 	req3 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
@@ -407,7 +442,6 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 				{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
 				{Role: "user", Content: "Can you explain how prefix caching works?"},
 				{Role: "assistant", Content: "Prefix caching is a technique where..."},
-				{Role: "user", Content: "That's very helpful, thank you!"},
 			},
 		},
 	},
@@ -424,7 +458,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore = float64(cachedBlocks) / float64(longHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
-	assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
+	assert.Equal(t, scores[pod1], float64(1), "cache hit rate should be 1.0 for a fully cached conversation")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
 }

pkg/epp/scheduling/types/llmresponse.go

Lines changed: 20 additions & 20 deletions
@@ -26,19 +26,19 @@ import (
 type LLMResponse struct {
 	// ChatCompletion is the representation of the OpenAI /v1/chat/completions response body.
 	ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"`
-	// LegacyCompletion is the representation of the OpenAI /v1/completions response body.
-	LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"`
+	// Completion is the representation of the OpenAI /v1/completions response body.
+	Completion *CompletionResponse `json:"legacy_completion,omitempty"`
 }
 
-// GetFirstChoiceContent extracts the primary text content from the first choice
-// in either a ChatCompletion or a LegacyCompletion response.
-func (res *LLMResponse) GetFirstChoiceContent() string {
+// FirstChoiceContent extracts the first choice of the response.
+func (res *LLMResponse) FirstChoiceContent() ([]byte, error) {
 	if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 {
-		return res.ChatCompletion.Choices[0].Message.Content
-	} else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 {
-		return res.LegacyCompletion.Choices[0].Text
+		return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message)
 	}
-	return ""
+	if res.Completion != nil && len(res.Completion.Choices) > 0 {
+		return []byte(res.Completion.Choices[0].Text), nil
+	}
+	return nil, fmt.Errorf("no choices found in the LLM response")
 }
 
 // ChatCompletionResponse represents the full response body for the chat completions API.
@@ -60,8 +60,8 @@ func (r *ChatCompletionResponse) String() string {
 
 // ChatChoice represents a single choice in the chat completion response.
 type ChatChoice struct {
-	Message      ChatMessage `json:"message"`
-	FinishReason string      `json:"finish_reason"`
+	Message      Message `json:"message"`
+	FinishReason string  `json:"finish_reason"`
 }
 
 // ChatMessage represents the message object within a choice.
@@ -70,13 +70,13 @@ type ChatMessage struct {
 	Content string `json:"content"`
 }
 
-// LegacyCompletionResponse represents the full response body for the legacy completions API.
-type LegacyCompletionResponse struct {
-	Choices []LegacyChoice `json:"choices"`
-	Usage   *Usage         `json:"usage,omitempty"`
+// CompletionResponse represents the full response body for the legacy completions API.
+type CompletionResponse struct {
+	Choices []CompletionChoice `json:"choices"`
+	Usage   *Usage             `json:"usage,omitempty"`
 }
 
-func (r *LegacyCompletionResponse) String() string {
+func (r *CompletionResponse) String() string {
 	if r == nil {
 		return nilString
 	}
@@ -87,8 +87,8 @@ func (r *LegacyCompletionResponse) String() string {
 	return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage)
 }
 
-// LegacyChoice represents a single choice in the legacy completion response.
-type LegacyChoice struct {
+// CompletionChoice represents a single choice in the legacy completion response.
+type CompletionChoice struct {
 	Text         string `json:"text"`
 	FinishReason string `json:"finish_reason"`
 }
@@ -124,10 +124,10 @@ func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
 	}
 
 	// Try to unmarshal as a CompletionResponse.
-	var legacyResp LegacyCompletionResponse
+	var legacyResp CompletionResponse
 	if err := json.Unmarshal(body, &legacyResp); err == nil {
 		if len(legacyResp.Choices) > 0 {
-			return &LLMResponse{LegacyCompletion: &legacyResp}, nil
+			return &LLMResponse{Completion: &legacyResp}, nil
 		}
 	}