@@ -17,6 +17,7 @@ limitations under the License.
1717package prefix
1818
1919import (
20+ "bytes"
2021 "context"
2122 "encoding/binary"
2223 "encoding/json"
@@ -28,6 +29,7 @@ import (
2829 k8stypes "k8s.io/apimachinery/pkg/types"
2930 "sigs.k8s.io/controller-runtime/pkg/log"
3031
32+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3133 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3234 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3335 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -117,6 +119,10 @@ var _ plugins.StateData = &SchedulingContextState{}
117119type SchedulingContextState struct {
118120 // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
119121 PrefixHashes []BlockHash
122+ // RestBytes is the trailing bytes that not able to fill in a full block and left over.
123+ // If not empty, this will be used as the starting block for the following response that will
124+ // be added to the response as well. This happens especially at the multi-turn scenario.
125+ RestBytes []byte
120126 // A map of server to its longest prefix cache match length.
121127 PrefixCacheServers map [ServerID ]int
122128}
@@ -193,9 +199,10 @@ func (p *Plugin) WithName(name string) *Plugin {
193199// Score returns the scoring result for the given list of pods based on context.
194200func (p * Plugin ) Score (ctx context.Context , cycleState * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
195201 // pre score step, hashing prompt and find longest prefix match.
196- hashes := hashPrompt (ctx , request , getBlockSize (pods , p .config .DefaultBlockSize ), p .config .MaxPrefixBlocksToMatch )
202+ hashes , restBytes := hashPrompt (ctx , request , getBlockSize (pods , p .config .DefaultBlockSize ), p .config .MaxPrefixBlocksToMatch )
197203 state := & SchedulingContextState {
198204 PrefixHashes : hashes ,
205+ RestBytes : restBytes ,
199206 PrefixCacheServers : p .matchLongestPrefix (ctx , hashes ),
200207 }
201208
@@ -301,47 +308,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
301308// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
302309// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
303310// For block i, hash(i) = hash(block i content, hash(i-1)).
304- func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
311+ // Also return the extra string.
312+ func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) ([]BlockHash , []byte ) {
305313 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
306314 if request == nil || request .Body == nil {
307315 loggerDebug .Info ("Request or request data is nil, skipping hashing" )
308- return nil
316+ return nil , nil
309317 }
310318
311319 userInput , err := getUserInputBytes (request )
312320 if err != nil {
313321 loggerDebug .Error (err , "Failed to get user input bytes" )
314- return nil
322+ return nil , nil
315323 }
324+ prevBlockHash := defaultPrevBlock (request )
325+ return hashInputWithPrevBlockHash (ctx , prevBlockHash , 0 , userInput , cacheBlockSize , maxPrefixBlocks )
326+ }
316327
317- if len (userInput ) < cacheBlockSize {
318- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
319- return nil
320- }
321- if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
322- loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
323- userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
324- }
325- // Split the body into blocks of size cacheBlockSize.
326- // If the last block is smaller than cacheBlockSize, it will be ignored.
327- res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
328- // Add the model to the first block hash so that different models have different hashes even with the same body.
328+ func defaultPrevBlock (request * types.LLMRequest ) BlockHash {
329329 h := xxhash .New ()
330+ // Add the model to the first block hash so that different models have different hashes even with the same body.
330331 _ , _ = h .Write ([]byte (request .TargetModel ))
331332 if cacheSalt := request .Body .CacheSalt (); cacheSalt != "" {
332333 _ , _ = h .Write ([]byte (cacheSalt ))
333334 }
334335
335- prevBlockHash := BlockHash (h .Sum64 ())
336- for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
336+ return BlockHash (h .Sum64 ())
337+ }
338+
339+ func hashInputWithPrevBlockHash (ctx context.Context , prevBlockHash BlockHash , prevBlockLength int , input []byte , cacheBlockSize int , maxPrefixBlocks int ) ([]BlockHash , []byte ) {
340+ loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
341+ if len (input )+ prevBlockLength < cacheBlockSize {
342+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (input ), "block size" , cacheBlockSize )
343+ return nil , input
344+ }
345+ if len (input )+ prevBlockLength > cacheBlockSize * maxPrefixBlocks {
346+ loggerDebug .Info ("Truncating input" , "size" , len (input ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
347+ input = input [:(maxPrefixBlocks * cacheBlockSize - prevBlockLength )]
348+ }
349+ // Split the body into blocks of size cacheBlockSize.
350+ // If the last block is smaller than cacheBlockSize, it will be ignored.
351+ res := make ([]BlockHash , 0 , len (input )/ cacheBlockSize )
352+ lastOffSet := 0
353+ h := xxhash .New ()
354+ for i := 0 ; i + cacheBlockSize <= len (input ); i += cacheBlockSize {
337355 h .Reset ()
338- _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
356+ _ , _ = h .Write (input [i : i + cacheBlockSize ])
339357 _ , _ = h .Write (toBytes (prevBlockHash ))
340358 res = append (res , BlockHash (h .Sum64 ()))
341359
342360 prevBlockHash = res [len (res )- 1 ]
361+ lastOffSet = i + cacheBlockSize
343362 }
344- return res
363+ return res , input [ lastOffSet :]
345364}
346365
347366func toBytes (i BlockHash ) []byte {
@@ -359,6 +378,33 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
359378 return json .Marshal (request .Body .ChatCompletions .Messages )
360379}
361380
381+ func (p * Plugin ) ResponseComplete (ctx context.Context , request * types.LLMRequest , response * requestcontrol.Response , targetPod * backend.Pod ) {
382+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
383+ if err != nil {
384+ log .FromContext (ctx ).Error (err , "failed to read prefix plugin state" , "requestID" , request .RequestId )
385+ return
386+ }
387+ p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it.
388+ var input bytes.Buffer
389+ input .Write (state .RestBytes )
390+ input .Write ([]byte (response .Body ))
391+
392+ server := ServerID (targetPod .NamespacedName )
393+ prevBlockHash := defaultPrevBlock (request )
394+ prevBlockHashLength := 0
395+ if len (state .PrefixHashes ) > 0 {
396+ prevBlockHash = state .PrefixHashes [len (state .PrefixHashes )- 1 ]
397+ prevBlockHashLength = len (state .PrefixHashes )
398+ }
399+ inputBytes := input .Bytes ()
400+ hashBlocks , _ := hashInputWithPrevBlockHash (ctx , prevBlockHash , prevBlockHashLength , inputBytes , p .config .DefaultBlockSize , p .config .MaxPrefixBlocksToMatch )
401+ p .wg .Add (1 )
402+ go func () {
403+ p .indexer .Add (hashBlocks , server )
404+ p .wg .Done ()
405+ }()
406+ }
407+
362408func getBlockSize (pods []types.Pod , defaultBlockSize int ) int {
363409 if len (pods ) == 0 {
364410 return defaultBlockSize
0 commit comments