Skip to content

Commit 824af68

Browse files
committed
Make ResponseComplete accept LLMResponse and update the encoding method of Messages in ChatCompletions.
1 parent 028974c commit 824af68

File tree

10 files changed

+228
-90
lines changed

10 files changed

+228
-90
lines changed

pkg/epp/handlers/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
304304
break
305305
}
306306

307+
reqCtx.Response.Body = body
307308
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
308309
if responseErr != nil {
309310
if logger.V(logutil.DEBUG).Enabled() {

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
288288
logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
289289
return reqCtx, err
290290
}
291-
response := &Response{
292-
RequestId: requestID,
293-
Headers: reqCtx.Response.Headers,
294-
// Currently use the first choice as the response body to process.
295-
Body: llmResponse.GetFirstChoiceContent(),
296-
}
297-
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
291+
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
298292

299293
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
300294
return reqCtx, nil
@@ -344,7 +338,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch
344338
}
345339
}
346340

347-
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
341+
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
348342
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
349343
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
350344
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())

pkg/epp/requestcontrol/director_test.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,10 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
712712
"total_tokens": 3
713713
}
714714
}`
715+
wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON))
716+
if err != nil {
717+
t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err)
718+
}
715719

716720
reqCtx := &handlers.RequestContext{
717721
Request: &handlers.Request{
@@ -726,21 +730,15 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
726730
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
727731
}
728732

729-
_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
733+
_, err = director.HandleResponseBodyComplete(ctx, reqCtx)
730734
if err != nil {
731735
t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
732736
}
733737

734-
if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
735-
t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
736-
}
737-
if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
738-
t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
739-
}
740738
if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
741739
t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
742740
}
743-
if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
741+
if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
744742
t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
745743
}
746744
}
@@ -765,7 +763,7 @@ type testResponseStreaming struct {
765763

766764
type testResponseComplete struct {
767765
tn plugins.TypedName
768-
lastRespOnComplete *Response
766+
lastRespOnComplete *schedulingtypes.LLMResponse
769767
lastTargetPodOnComplete string
770768
}
771769

@@ -809,7 +807,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
809807
p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
810808
}
811809

812-
func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
810+
func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
813811
p.lastRespOnComplete = response
814812
p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
815813
}

pkg/epp/requestcontrol/plugins.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ type ResponseStreaming interface {
5555
// ResponseComplete is called by the director after the complete response is sent.
5656
type ResponseComplete interface {
5757
plugins.Plugin
58-
ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
58+
ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
5959
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ type SchedulingContextState struct {
123123
// If not empty, this will be used as the starting block for the following response that will
124124
// be added to the response as well. This happens especially at the multi-turn scenario.
125125
RestBytes []byte
126+
// BlockSize is the block size used to calculate the hash of the request/response.
127+
BlockSize int
126128
// A map of server to its longest prefix cache match length.
127129
PrefixCacheServers map[ServerID]int
128130
}
@@ -198,11 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {
198200

199201
// Score returns the scoring result for the given list of pods based on context.
200202
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
203+
blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
201204
// pre score step, hashing prompt and find longest prefix match.
202-
hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
205+
hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
203206
state := &SchedulingContextState{
204207
PrefixHashes: hashes,
205208
RestBytes: restBytes,
209+
BlockSize: blockSize,
206210
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
207211
}
208212

@@ -233,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
233237
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
234238

235239
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
236-
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
237240
if err != nil {
238241
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
239242
return
@@ -251,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
251254

252255
total := len(state.PrefixHashes)
253256
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
254-
255-
blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
256-
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
257+
metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize)
257258
}
258259

259260
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
@@ -375,19 +376,25 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
375376
}
376377

377378
// must be chat-completions request at this point, return bytes of entire messages
378-
return json.Marshal(request.Body.ChatCompletions.Messages)
379+
return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...)
379380
}
380381

381-
func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
382+
func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
382383
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
383384
if err != nil {
384385
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
385386
return
386387
}
387388
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
389+
390+
reponseForKVCache, err := response.FirstChoiceContent()
391+
if err != nil {
392+
log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
393+
return
394+
}
388395
var input bytes.Buffer
389396
input.Write(state.RestBytes)
390-
input.Write([]byte(response.Body))
397+
input.Write(reponseForKVCache)
391398

392399
server := ServerID(targetPod.NamespacedName)
393400
prevBlockHash := defaultPrevBlock(request)
@@ -396,8 +403,7 @@ func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest
396403
prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
397404
prevBlockHashLength = len(state.PrefixHashes)
398405
}
399-
inputBytes := input.Bytes()
400-
hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
406+
hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
401407
p.wg.Add(1)
402408
go func() {
403409
p.indexer.Add(hashBlocks, server)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3131
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3232
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
33-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3433
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3534
)
3635

@@ -201,8 +200,9 @@ func TestPrefixPluginCompletion(t *testing.T) {
201200
}
202201

203202
func TestPrefixPluginCompletionWithResponse(t *testing.T) {
203+
const defaultBlockSize = 4
204204
config := Config{
205-
DefaultBlockSize: 4,
205+
DefaultBlockSize: defaultBlockSize,
206206
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
207207
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
208208
}
@@ -231,6 +231,9 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
231231
// Total hashes = 1 (for the "aaaa" block) + 1 (for the model prefix).
232232
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
233233
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
234+
// The last 2 characters are recorded in restBytes of the state.
235+
assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
236+
assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
234237
assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
235238
assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
236239

@@ -251,7 +254,16 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
251254
// - Response Body: "bb"
252255
// - Cached Sequence: "aaaaaabb" (length 8)
253256
// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
254-
plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
257+
resp1 := &types.LLMResponse{
258+
Completion: &types.CompletionResponse{
259+
Choices: []types.CompletionChoice{
260+
{
261+
Text: "bb",
262+
},
263+
},
264+
},
265+
}
266+
plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
255267
plugin.wg.Wait()
256268

257269
// -- Second Request: Multi-turn Follow-up --
@@ -278,6 +290,9 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
278290
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
279291
// It should find a server (pod1) that has cached the prefixes.
280292
assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
293+
// The last 2 characters ("cc") are recorded in restBytes of the state.
294+
assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
295+
assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
281296
// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
282297
assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
283298
assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
@@ -362,6 +377,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
362377
plugin.PreRequest(context.Background(), req1, schedulingResult)
363378
plugin.wg.Wait()
364379

380+
resp1 := &types.LLMResponse{
381+
ChatCompletion: &types.ChatCompletionResponse{
382+
Choices: []types.ChatChoice{
383+
{
384+
Message: types.Message{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
385+
},
386+
},
387+
},
388+
}
389+
// Trigger to simulate that resp1 is added to the kvCache recording.
390+
plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
391+
plugin.wg.Wait()
392+
365393
// Second request adds assistant response and new user message (conversation grows)
366394
req2 := &types.LLMRequest{
367395
RequestId: uuid.NewString(),
@@ -389,13 +417,27 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
389417
cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
390418
expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
391419
assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
420+
assert.Greater(t, scores[pod1], float64(0.5), "given the response is also prefix cached the cache hit should be well above 0.5")
392421
assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
393422

394423
// Simulate pod1 was picked again
395424
plugin.PreRequest(context.Background(), req2, schedulingResult)
396425
plugin.wg.Wait()
397426

398-
// Third request continues the conversation even further
427+
resp2 := &types.LLMResponse{
428+
ChatCompletion: &types.ChatCompletionResponse{
429+
Choices: []types.ChatChoice{
430+
{
431+
Message: types.Message{Role: "assistant", Content: "Prefix caching is a technique where..."},
432+
},
433+
},
434+
},
435+
}
436+
// Trigger to simulate that resp2 is added to the kvCache recording.
437+
plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod())
438+
plugin.wg.Wait()
439+
440+
// Third request replays the whole conversation above so the cache hit ratio reaches 1.0.
399441
req3 := &types.LLMRequest{
400442
RequestId: uuid.NewString(),
401443
TargetModel: "test-model1",
@@ -424,7 +466,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
424466
cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
425467
expectedScore = float64(cachedBlocks) / float64(longHashCount)
426468
assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
427-
assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
469+
assert.Equal(t, scores[pod1], float64(1), "cache hit rate should be substantial for growing conversation")
428470
assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
429471
}
430472

pkg/epp/scheduling/types/llmresponse.go

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package types
1818

1919
import (
2020
"encoding/json"
21+
"errors"
2122
"fmt"
2223
)
2324

@@ -26,19 +27,19 @@ import (
2627
type LLMResponse struct {
2728
// ChatCompletion is the representation of the OpenAI /v1/chat/completions response body.
2829
ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"`
29-
// LegacyCompletion is the representation of the OpenAI /v1/completions response body.
30-
LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"`
30+
// Completion is the representation of the OpenAI /v1/completions response body.
31+
Completion *CompletionResponse `json:"legacy_completion,omitempty"`
3132
}
3233

33-
// GetFirstChoiceContent extracts the primary text content from the first choice
34-
// in either a ChatCompletion or a LegacyCompletion response.
35-
func (res *LLMResponse) GetFirstChoiceContent() string {
34+
// FirstChoiceContent extracts the first choice of the response.
35+
func (res *LLMResponse) FirstChoiceContent() ([]byte, error) {
3636
if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 {
37-
return res.ChatCompletion.Choices[0].Message.Content
38-
} else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 {
39-
return res.LegacyCompletion.Choices[0].Text
37+
return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message)
4038
}
41-
return ""
39+
if res.Completion != nil && len(res.Completion.Choices) > 0 {
40+
return []byte(res.Completion.Choices[0].Text), nil
41+
}
42+
return nil, errors.New("no choices found in the LLM response")
4243
}
4344

4445
// ChatCompletionResponse represents the full response body for the chat completions API.
@@ -60,8 +61,8 @@ func (r *ChatCompletionResponse) String() string {
6061

6162
// ChatChoice represents a single choice in the chat completion response.
6263
type ChatChoice struct {
63-
Message ChatMessage `json:"message"`
64-
FinishReason string `json:"finish_reason"`
64+
Message Message `json:"message"`
65+
FinishReason string `json:"finish_reason"`
6566
}
6667

6768
// ChatMessage represents the message object within a choice.
@@ -70,13 +71,13 @@ type ChatMessage struct {
7071
Content string `json:"content"`
7172
}
7273

73-
// LegacyCompletionResponse represents the full response body for the legacy completions API.
74-
type LegacyCompletionResponse struct {
75-
Choices []LegacyChoice `json:"choices"`
76-
Usage *Usage `json:"usage,omitempty"`
74+
// CompletionResponse represents the full response body for the legacy completions API.
75+
type CompletionResponse struct {
76+
Choices []CompletionChoice `json:"choices"`
77+
Usage *Usage `json:"usage,omitempty"`
7778
}
7879

79-
func (r *LegacyCompletionResponse) String() string {
80+
func (r *CompletionResponse) String() string {
8081
if r == nil {
8182
return nilString
8283
}
@@ -87,8 +88,8 @@ func (r *LegacyCompletionResponse) String() string {
8788
return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage)
8889
}
8990

90-
// LegacyChoice represents a single choice in the legacy completion response.
91-
type LegacyChoice struct {
91+
// CompletionChoice represents a single choice in the legacy completion response.
92+
type CompletionChoice struct {
9293
Text string `json:"text"`
9394
FinishReason string `json:"finish_reason"`
9495
}
@@ -111,7 +112,7 @@ func (u *Usage) String() string {
111112
// as a chat completion and then as a legacy completion response.
112113
func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
113114
if len(body) == 0 {
114-
return nil, fmt.Errorf("input bytes are empty")
115+
return nil, errors.New("input bytes are empty")
115116
}
116117

117118
// Attempt to unmarshal as a ChatCompletionResponse first.
@@ -124,12 +125,12 @@ func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
124125
}
125126

126127
// Try to unmarshal as a LegacyCompletionResponse.
127-
var legacyResp LegacyCompletionResponse
128+
var legacyResp CompletionResponse
128129
if err := json.Unmarshal(body, &legacyResp); err == nil {
129130
if len(legacyResp.Choices) > 0 {
130-
return &LLMResponse{LegacyCompletion: &legacyResp}, nil
131+
return &LLMResponse{Completion: &legacyResp}, nil
131132
}
132133
}
133134

134-
return nil, fmt.Errorf("failed to unmarshal body into any known LLM response format")
135+
return nil, errors.New("failed to unmarshal body into any known LLM response format")
135136
}

0 commit comments

Comments
 (0)