Skip to content

Commit 028974c

Browse files
committed
Add reponse to prefix cache in nonStreaming mode.
1 parent 1a7793a commit 028974c

File tree

7 files changed

+717
-24
lines changed

7 files changed

+717
-24
lines changed

pkg/epp/handlers/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ type Request struct {
115115
}
116116
type Response struct {
117117
Headers map[string]string
118+
Body []byte
118119
}
119120
type StreamRequestState int
120121

pkg/epp/requestcontrol/director.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,20 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand
280280

281281
// HandleResponseBodyComplete is called when the response body is fully received.
282282
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
283-
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
283+
requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
284+
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
284285
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
286+
llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
287+
if err != nil {
288+
logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
289+
return reqCtx, err
290+
}
285291
response := &Response{
286-
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
292+
RequestId: requestID,
287293
Headers: reqCtx.Response.Headers,
294+
// Currently use the first choice as the response body to process.
295+
Body: llmResponse.GetFirstChoiceContent(),
288296
}
289-
290297
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
291298

292299
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")

pkg/epp/requestcontrol/director_test.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,23 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
696696
mockSched := &mockScheduler{}
697697
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1))
698698

699+
chatCompletionJSON := `{
700+
"choices": [
701+
{
702+
"message": {
703+
"role": "assistant",
704+
"content": "Hello!"
705+
},
706+
"finish_reason": "stop"
707+
}
708+
],
709+
"usage": {
710+
"prompt_tokens": 1,
711+
"completion_tokens": 2,
712+
"total_tokens": 3
713+
}
714+
}`
715+
699716
reqCtx := &handlers.RequestContext{
700717
Request: &handlers.Request{
701718
Headers: map[string]string{
@@ -704,6 +721,7 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
704721
},
705722
Response: &handlers.Response{
706723
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
724+
Body: []byte(chatCompletionJSON),
707725
},
708726
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
709727
}
@@ -717,11 +735,14 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
717735
t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
718736
}
719737
if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
720-
t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
738+
t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
721739
}
722740
if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
723741
t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
724742
}
743+
if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
744+
t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
745+
}
725746
}
726747

727748
const (

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package prefix
1818

1919
import (
20+
"bytes"
2021
"context"
2122
"encoding/binary"
2223
"encoding/json"
@@ -28,6 +29,7 @@ import (
2829
k8stypes "k8s.io/apimachinery/pkg/types"
2930
"sigs.k8s.io/controller-runtime/pkg/log"
3031

32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3133
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3234
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3335
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -117,6 +119,10 @@ var _ plugins.StateData = &SchedulingContextState{}
117119
type SchedulingContextState struct {
118120
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
119121
PrefixHashes []BlockHash
122+
// RestBytes is the trailing bytes that not able to fill in a full block and left over.
123+
// If not empty, this will be used as the starting block for the following response that will
124+
// be added to the response as well. This happens especially at the multi-turn scenario.
125+
RestBytes []byte
120126
// A map of server to its longest prefix cache match length.
121127
PrefixCacheServers map[ServerID]int
122128
}
@@ -193,9 +199,10 @@ func (p *Plugin) WithName(name string) *Plugin {
193199
// Score returns the scoring result for the given list of pods based on context.
194200
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
195201
// pre score step, hashing prompt and find longest prefix match.
196-
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
202+
hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
197203
state := &SchedulingContextState{
198204
PrefixHashes: hashes,
205+
RestBytes: restBytes,
199206
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
200207
}
201208

@@ -301,47 +308,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
301308
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
302309
// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
303310
// For block i, hash(i) = hash(block i content, hash(i-1)).
304-
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
311+
// Also return the extra string.
312+
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
305313
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
306314
if request == nil || request.Body == nil {
307315
loggerDebug.Info("Request or request data is nil, skipping hashing")
308-
return nil
316+
return nil, nil
309317
}
310318

311319
userInput, err := getUserInputBytes(request)
312320
if err != nil {
313321
loggerDebug.Error(err, "Failed to get user input bytes")
314-
return nil
322+
return nil, nil
315323
}
324+
prevBlockHash := defaultPrevBlock(request)
325+
return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
326+
}
316327

317-
if len(userInput) < cacheBlockSize {
318-
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
319-
return nil
320-
}
321-
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
322-
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
323-
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
324-
}
325-
// Split the body into blocks of size cacheBlockSize.
326-
// If the last block is smaller than cacheBlockSize, it will be ignored.
327-
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
328-
// Add the model to the first block hash so that different models have different hashes even with the same body.
328+
func defaultPrevBlock(request *types.LLMRequest) BlockHash {
329329
h := xxhash.New()
330+
// Add the model to the first block hash so that different models have different hashes even with the same body.
330331
_, _ = h.Write([]byte(request.TargetModel))
331332
if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
332333
_, _ = h.Write([]byte(cacheSalt))
333334
}
334335

335-
prevBlockHash := BlockHash(h.Sum64())
336-
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
336+
return BlockHash(h.Sum64())
337+
}
338+
339+
func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
340+
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
341+
if len(input)+prevBlockLength < cacheBlockSize {
342+
loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
343+
return nil, input
344+
}
345+
if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
346+
loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
347+
input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
348+
}
349+
// Split the body into blocks of size cacheBlockSize.
350+
// If the last block is smaller than cacheBlockSize, it will be ignored.
351+
res := make([]BlockHash, 0, len(input)/cacheBlockSize)
352+
lastOffSet := 0
353+
h := xxhash.New()
354+
for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
337355
h.Reset()
338-
_, _ = h.Write(userInput[i : i+cacheBlockSize])
356+
_, _ = h.Write(input[i : i+cacheBlockSize])
339357
_, _ = h.Write(toBytes(prevBlockHash))
340358
res = append(res, BlockHash(h.Sum64()))
341359

342360
prevBlockHash = res[len(res)-1]
361+
lastOffSet = i + cacheBlockSize
343362
}
344-
return res
363+
return res, input[lastOffSet:]
345364
}
346365

347366
func toBytes(i BlockHash) []byte {
@@ -359,6 +378,33 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
359378
return json.Marshal(request.Body.ChatCompletions.Messages)
360379
}
361380

381+
func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
382+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
383+
if err != nil {
384+
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
385+
return
386+
}
387+
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
388+
var input bytes.Buffer
389+
input.Write(state.RestBytes)
390+
input.Write([]byte(response.Body))
391+
392+
server := ServerID(targetPod.NamespacedName)
393+
prevBlockHash := defaultPrevBlock(request)
394+
prevBlockHashLength := 0
395+
if len(state.PrefixHashes) > 0 {
396+
prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
397+
prevBlockHashLength = len(state.PrefixHashes)
398+
}
399+
inputBytes := input.Bytes()
400+
hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
401+
p.wg.Add(1)
402+
go func() {
403+
p.indexer.Add(hashBlocks, server)
404+
p.wg.Done()
405+
}()
406+
}
407+
362408
func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
363409
if len(pods) == 0 {
364410
return defaultBlockSize

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3131
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3232
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
33+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
3334
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3435
)
3536

@@ -199,6 +200,89 @@ func TestPrefixPluginCompletion(t *testing.T) {
199200
plugin.wg.Wait()
200201
}
201202

203+
func TestPrefixPluginCompletionWithResponse(t *testing.T) {
204+
config := Config{
205+
DefaultBlockSize: 4,
206+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
207+
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
208+
}
209+
plugin := New(context.Background(), config)
210+
211+
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
212+
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
213+
pods := []types.Pod{pod1, pod2}
214+
215+
// -- First Request --
216+
// This initial request will populate the cache.
217+
req1 := &types.LLMRequest{
218+
RequestId: uuid.NewString(),
219+
TargetModel: "test-model1",
220+
Body: &types.LLMRequestBody{
221+
Completions: &types.CompletionsRequest{
222+
Prompt: "aaaaaa",
223+
},
224+
},
225+
}
226+
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
227+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
228+
assert.NoError(t, err)
229+
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
230+
// Input size is 6, hash block size is 4, so the last 2 characters are ignored.
231+
// Total hashes = 1 (for the "aaaa" block) + 1 (for the model prefix).
232+
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
233+
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
234+
assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
235+
assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
236+
237+
// Simulate that the scheduler picked pod1 for the first request.
238+
schedulingResult := &types.SchedulingResult{
239+
PrimaryProfileName: "default",
240+
ProfileResults: map[string]*types.ProfileRunResult{
241+
"default": {TargetPods: []types.Pod{pod1}},
242+
},
243+
}
244+
plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
245+
plugin.wg.Wait()
246+
247+
// -- Simulate Response Completion --
248+
// The ResponseComplete hook is called. The plugin should update pod1's KV cache
249+
// with the full context of the completed interaction (prompt + response).
250+
// - Initial Prompt: "aaaaaa"
251+
// - Response Body: "bb"
252+
// - Cached Sequence: "aaaaaabb" (length 8)
253+
// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
254+
plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
255+
plugin.wg.Wait()
256+
257+
// -- Second Request: Multi-turn Follow-up --
258+
// This request simulates a follow-up message in a chat. The prompt contains the
259+
// entire conversation history ("aaaaaabb") plus new text ("cc").
260+
// The plugin should find that the first two blocks ("aaaa", "aabb") of this new
261+
// prompt are already cached on pod1, giving it a perfect match score of 1.0.
262+
// Pod2 has no matching cache entries and should score 0.
263+
req2 := &types.LLMRequest{
264+
RequestId: uuid.NewString(),
265+
TargetModel: "test-model1",
266+
Body: &types.LLMRequestBody{
267+
Completions: &types.CompletionsRequest{
268+
Prompt: "aaaaaabbcc",
269+
},
270+
},
271+
}
272+
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
273+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
274+
assert.NoError(t, err)
275+
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
276+
// Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes.
277+
// The last 2 characters ("cc") are ignored.
278+
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
279+
// It should find a server (pod1) that has cached the prefixes.
280+
assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
281+
// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
282+
assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
283+
assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
284+
}
285+
202286
func TestPrefixPluginChatCompletions(t *testing.T) {
203287
config := Config{
204288
DefaultBlockSize: 4,

0 commit comments

Comments
 (0)