From 415a62454cffd17190b447a638f5b5d9d5a23090 Mon Sep 17 00:00:00 2001
From: Cong Liu
Date: Tue, 29 Apr 2025 20:00:41 +0000
Subject: [PATCH 1/2] Add prefix cache aware scheduling

---
 cmd/epp/main.go                               |  23 +++
 pkg/epp/metrics/metrics.go                    |  75 ++++++++
 pkg/epp/metrics/metrics_test.go               | 103 ++++++++++
 .../testdata/prefix_indexer_hit_bytes_metric  |  19 ++
 .../testdata/prefix_indexer_hit_ratio_metric  |  16 ++
 .../testdata/prefix_indexer_size_metric       |   3 +
 pkg/epp/scheduling/plugins/filter/filter.go   |   8 +
 pkg/epp/scheduling/plugins/prefix/indexer.go  | 163 ++++++++++++++++
 .../scheduling/plugins/prefix/indexer_test.go |  46 +++++
 .../scheduling/plugins/prefix/linked_list.go  |  85 +++++++++
 pkg/epp/scheduling/plugins/prefix/plugin.go   | 178 ++++++++++++++++++
 .../scheduling/plugins/prefix/plugin_test.go  | 132 +++++++++++++
 pkg/epp/scheduling/scheduler_v2.go            |  62 ++++++
 pkg/epp/scheduling/types/types.go             |  14 ++
 14 files changed, 927 insertions(+)
 create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric
 create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric
 create mode 100644 pkg/epp/metrics/testdata/prefix_indexer_size_metric
 create mode 100644 pkg/epp/scheduling/plugins/prefix/indexer.go
 create mode 100644 pkg/epp/scheduling/plugins/prefix/indexer_test.go
 create mode 100644 pkg/epp/scheduling/plugins/prefix/linked_list.go
 create mode 100644 pkg/epp/scheduling/plugins/prefix/plugin.go
 create mode 100644 pkg/epp/scheduling/plugins/prefix/plugin_test.go
 create mode 100644 pkg/epp/scheduling/scheduler_v2.go

diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index 2bd779c55..e10572944 100644
--- a/cmd/epp/main.go
+++ b/cmd/epp/main.go
@@ -34,6 +34,7 @@ import (
 	"k8s.io/client-go/rest"
 	"k8s.io/component-base/metrics/legacyregistry"
 	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/controller-runtime/pkg/log/zap"
 	"sigs.k8s.io/controller-runtime/pkg/manager"
 	"sigs.k8s.io/controller-runtime/pkg/metrics/filters"
@@ -42,7 +43,9 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix"
 	runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
+	envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
@@ -106,8 +109,24 @@ var (
 		"Prometheus metric for the LoRA info metrics (must be in vLLM label format).")

 	setupLog = ctrl.Log.WithName("setup")
+
+	// Environment variables
+	schedulerV2       = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULE_V2", "false", setupLog)
+	prefixCacheConfig = loadPrefixCacheConfig()
 )

+func loadPrefixCacheConfig() prefix.Config {
+	// logger := zap.New(zap.RawZapOpts(uberzap.AddCaller()))
+	// log.SetLogger(logger)
+	baseLogger := log.Log.WithName("env-config")
+
+	return prefix.Config{
+		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultCacheBlockSize, baseLogger),
+		MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
+		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_MAX_CACHE_SIZE_MB", prefix.DefaultLRUIndexerCapacity, baseLogger),
+	}
+}
+
 func main() {
 	if err := run(); err != nil {
 		os.Exit(1)
@@ -171,6 +190,10 @@ func run() error {
 	datastore := datastore.NewDatastore(ctx, pmf)

 	scheduler := scheduling.NewScheduler(datastore)
+	if schedulerV2 == "true" {
+		setupLog.Info("Creating scheduler with prefixCache plugin", "prefix cache config", prefixCacheConfig)
+		scheduler = scheduling.NewSchedulerV2(datastore, prefixCacheConfig)
+	}
 	serverRunner := &runserver.ExtProcServerRunner{
 		GrpcPort:                                 *grpcPort,
 		DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace,
diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go
index 6cc0cdb83..1baa3099f 100644
--- a/pkg/epp/metrics/metrics.go
+++ b/pkg/epp/metrics/metrics.go
@@ -18,6 +18,7 @@ package metrics

 import (
 	"context"
+	"runtime/debug"
 	"sync"
 	"time"

@@ -219,6 +220,40 @@ var (
 		},
 		[]string{"commit"},
 	)
+
+	// Prefix indexer metrics
+	PrefixCacheSize = compbasemetrics.NewGaugeVec(
+		&compbasemetrics.GaugeOpts{
+			Subsystem:      InferenceExtension,
+			Name:           "prefix_indexer_size",
+			Help:           "Size of the prefix indexer.",
+			StabilityLevel: compbasemetrics.ALPHA,
+		},
+		[]string{},
+	)
+
+	PrefixCacheHitRatio = compbasemetrics.NewHistogramVec(
+		&compbasemetrics.HistogramOpts{
+			Subsystem: InferenceExtension,
+			Name:      "prefix_indexer_hit_ratio",
+			Help:      "Ratio of prefix length matched to total prefix length in the cache lookup.",
+			// Buckets from 0.0 to 1.0 in increments of 0.1
+			Buckets:        []float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+			StabilityLevel: compbasemetrics.ALPHA,
+		},
+		[]string{},
+	)
+
+	PrefixCacheHitLength = compbasemetrics.NewHistogramVec(
+		&compbasemetrics.HistogramOpts{
+			Subsystem:      InferenceExtension,
+			Name:           "prefix_indexer_hit_bytes",
+			Help:           "Length of the prefix match in number of bytes in the cache lookup.",
+			Buckets:        []float64{0, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536},
+			StabilityLevel: compbasemetrics.ALPHA,
+		},
+		[]string{},
+	)
 )

 var registerMetrics sync.Once
@@ -244,6 +279,10 @@ func Register() {
 		legacyregistry.MustRegister(SchedulerE2ELatency)

 		legacyregistry.MustRegister(InferenceExtensionInfo)
+
+		legacyregistry.MustRegister(PrefixCacheSize)
+		legacyregistry.MustRegister(PrefixCacheHitRatio)
+		legacyregistry.MustRegister(PrefixCacheHitLength)
 	})
 }

@@ -352,8 +391,44 @@ func RecordSchedulerE2ELatency(duration time.Duration) {
 	SchedulerE2ELatency.WithLabelValues().Observe(duration.Seconds())
 }

+// RecordPrefixCacheSize records the size of the prefix indexer in number of entries.
+func RecordPrefixCacheSize(size int64) {
+	PrefixCacheSize.WithLabelValues().Set(float64(size))
+}
+
+// RecordPrefixCacheMatch records both the hit ratio and hit length for a prefix indexer match.
+// matchedLength is the number of bytes that matched, and totalLength is the total prefix length.
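+// For example (illustrative values), RecordPrefixCacheMatch(64, 128) observes a 64-byte
+// entry in the hit-length histogram and a 0.5 entry in the hit-ratio histogram.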
+func RecordPrefixCacheMatch(matchedLength, totalLength int) {
+	// Record the hit length metric
+	PrefixCacheHitLength.WithLabelValues().Observe(float64(matchedLength))
+
+	// Record the hit ratio metric if totalLength is positive
+	if totalLength > 0 {
+		ratio := float64(matchedLength) / float64(totalLength)
+		PrefixCacheHitRatio.WithLabelValues().Observe(ratio)
+	}
+}
+
 func RecordInferenceExtensionInfo() {
 	if CommitSHA != "" {
 		InferenceExtensionInfo.WithLabelValues(CommitSHA).Set(1)
 	}
 }
+
+func init() {
+	info, ok := debug.ReadBuildInfo()
+	if !ok {
+		return
+	}
+
+	// Use the commit SHA recorded by the Go toolchain, if present.
+	for _, setting := range info.Settings {
+		if setting.Key == "vcs.revision" {
+			CommitSHA = setting.Value
+			return
+		}
+	}
+}
diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go
index a2311517d..a29c5d4be 100644
--- a/pkg/epp/metrics/metrics_test.go
+++ b/pkg/epp/metrics/metrics_test.go
@@ -663,3 +663,106 @@ func TestSchedulerE2ELatency(t *testing.T) {
 		})
 	}
 }
+
+func TestPrefixCacheMetrics(t *testing.T) {
+	const (
+		PrefixCacheSizeMetric      = InferenceExtension + "_prefix_indexer_size"
+		PrefixCacheHitRatioMetric  = InferenceExtension + "_prefix_indexer_hit_ratio"
+		PrefixCacheHitLengthMetric = InferenceExtension + "_prefix_indexer_hit_bytes"
+	)
+
+	type cacheMatchRecord struct {
+		matchedLength int
+		totalLength   int
+	}
+
+	scenario := struct {
+		name         string
+		cacheSizes   []int64
+		cacheMatches []cacheMatchRecord
+	}{
+		name:       "multiple cache metrics",
+		cacheSizes: []int64{1024, 2048, 4096},
+		cacheMatches: []cacheMatchRecord{
+			{
+				matchedLength: 5,
+				totalLength:   10,
+			},
+			{
+				matchedLength: 0,
+				totalLength:   10,
+			},
+			{
+				matchedLength: 10,
+				totalLength:   10,
+			},
+			{
+				matchedLength: 7,
+				totalLength:   10,
+			},
+			{
+				matchedLength: 64,
+				totalLength:   128,
+			},
+			{
+				matchedLength: 0,
+				totalLength:   128,
+			},
+		},
+	}
+
+	Register()
+	t.Run(scenario.name, func(t *testing.T) {
+		// Record cache size metrics
+		for _, size := range scenario.cacheSizes {
+			RecordPrefixCacheSize(size)
+		}
+
+		// Record cache match metrics (both hit ratio and hit length)
+		for _, match := range scenario.cacheMatches {
+			RecordPrefixCacheMatch(match.matchedLength, match.totalLength)
+		}
+
+		// Verify cache size metrics. Check the os.Open error before deferring the
+		// Close, so a missing testdata file fails fast instead of closing a nil handle.
+		wantCacheSizeMetrics, err := os.Open("testdata/prefix_indexer_size_metric")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer func() {
+			if err := wantCacheSizeMetrics.Close(); err != nil {
+				t.Error(err)
+			}
+		}()
+		if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantCacheSizeMetrics, PrefixCacheSizeMetric); err != nil {
+			t.Error(err)
+		}
+
+		// Verify hit ratio metrics
+		wantHitRatioMetrics, err := os.Open("testdata/prefix_indexer_hit_ratio_metric")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer func() {
+			if err := wantHitRatioMetrics.Close(); err != nil {
+				t.Error(err)
+			}
+		}()
+		if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitRatioMetrics, PrefixCacheHitRatioMetric); err != nil {
+			t.Error(err)
+		}
+
+		// Verify hit length metrics
+		wantHitLengthMetrics, err := os.Open("testdata/prefix_indexer_hit_bytes_metric")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer func() {
+			if err := wantHitLengthMetrics.Close(); err != nil {
+				t.Error(err)
+			}
+		}()
+		if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitLengthMetrics, PrefixCacheHitLengthMetric); err != nil {
+			t.Error(err)
+		}
+	})
+}
diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric
new file mode 100644
index 000000000..86b48724e
--- /dev/null
+++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric
@@ -0,0 +1,19 @@
+# HELP inference_extension_prefix_indexer_hit_bytes [ALPHA] Length of the prefix match in number of bytes in the cache lookup.
+# TYPE inference_extension_prefix_indexer_hit_bytes histogram
+inference_extension_prefix_indexer_hit_bytes_bucket{le="0"} 2
+inference_extension_prefix_indexer_hit_bytes_bucket{le="16"} 5
+inference_extension_prefix_indexer_hit_bytes_bucket{le="32"} 5
+inference_extension_prefix_indexer_hit_bytes_bucket{le="64"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="128"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="256"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="512"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="1024"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="2048"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="4096"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="8192"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="16384"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="32768"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="65536"} 6
+inference_extension_prefix_indexer_hit_bytes_bucket{le="+Inf"} 6
+inference_extension_prefix_indexer_hit_bytes_sum 86
+inference_extension_prefix_indexer_hit_bytes_count 6
diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric
new file mode 100644
index 000000000..e94827cb6
--- /dev/null
+++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric
@@ -0,0 +1,16 @@
+# HELP inference_extension_prefix_indexer_hit_ratio [ALPHA] Ratio of prefix length matched to total prefix length in the cache lookup.
+# TYPE inference_extension_prefix_indexer_hit_ratio histogram
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0"} 2
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.1"} 2
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.2"} 2
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.3"} 2
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.4"} 2
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.5"} 4
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.6"} 4
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.7"} 5
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.8"} 5
+inference_extension_prefix_indexer_hit_ratio_bucket{le="0.9"} 5
+inference_extension_prefix_indexer_hit_ratio_bucket{le="1"} 6
+inference_extension_prefix_indexer_hit_ratio_bucket{le="+Inf"} 6
+inference_extension_prefix_indexer_hit_ratio_sum 2.7
+inference_extension_prefix_indexer_hit_ratio_count 6
diff --git a/pkg/epp/metrics/testdata/prefix_indexer_size_metric b/pkg/epp/metrics/testdata/prefix_indexer_size_metric
new file mode 100644
index 000000000..9799b1729
--- /dev/null
+++ b/pkg/epp/metrics/testdata/prefix_indexer_size_metric
@@ -0,0 +1,3 @@
+# HELP inference_extension_prefix_indexer_size [ALPHA] Size of the prefix indexer.
+# TYPE inference_extension_prefix_indexer_size gauge
+inference_extension_prefix_indexer_size{} 4096
diff --git a/pkg/epp/scheduling/plugins/filter/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go
index 86620aa9f..67ce764dd 100644
--- a/pkg/epp/scheduling/plugins/filter/filter.go
+++ b/pkg/epp/scheduling/plugins/filter/filter.go
@@ -256,6 +256,14 @@ var HasCapacityFilter = &baseFilter{
 	filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))),
 }

+// NoopFilter is a filter that does not filter out any pods.
+var NoopFilter = &baseFilter{
+	name: "noop",
+	filter: toFilterFunc(func(req *types.LLMRequest, pod types.Pod) bool {
+		return true
+	}),
+}
+
 // podPredicate is a filter function to check whether a pod is desired.
 type podPredicate func(req *types.LLMRequest, pod types.Pod) bool

diff --git a/pkg/epp/scheduling/plugins/prefix/indexer.go b/pkg/epp/scheduling/plugins/prefix/indexer.go
new file mode 100644
index 000000000..cae7739bd
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/prefix/indexer.go
@@ -0,0 +1,163 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package prefix
+
+import (
+	"context"
+	"sync"
+	"time"
+	"unsafe"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+func newIndexer(maxCacheSize int) *indexer {
+	t := &indexer{
+		maxCacheSize: maxCacheSize,
+		table:        make(map[types.BlockHash]map[types.ServerID]*node),
+		list:         newLinkedList(),
+	}
+	go t.ReportCacheSize(time.Second)
+	return t
+}
+
+// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
+// prefix cached.
+type indexer struct {
+	mu           sync.RWMutex
+	maxCacheSize int
+	table        map[types.BlockHash]map[types.ServerID]*node // maps a prefix hash to the servers that cache it, and each server's LRU node
+	list         *linkedList                                  // LRU list to keep track of the order of entries
+}
+
+// Get returns the set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash types.BlockHash) map[types.ServerID]bool {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
+	res := map[types.ServerID]bool{}
+	for server := range i.table[hash] {
+		res[server] = true
+	}
+	return res
+}
+
+// Add adds a list of prefix hashes of a single request to the server the request was sent to.
+// The intuition is that this server is likely to have the prefix cached, so the next request
+// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
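+// For example, after Add([h0, h1, h2], serverA), a later Get(h1) will report serverA as a
+// likely holder of the prefix ending at block h1.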
+func (i *indexer) Add(hashes []types.BlockHash, server types.ServerID) {
+	i.mu.Lock()
+	defer i.mu.Unlock()
+	for _, hash := range hashes {
+		i.add(hash, server)
+	}
+}
+
+func (i *indexer) check(hash types.BlockHash, server types.ServerID) (*node, bool) {
+	servers, ok := i.table[hash]
+	if !ok {
+		return nil, false
+	}
+	n, ok := servers[server]
+	return n, ok
+}
+
+func (i *indexer) add(hash types.BlockHash, server types.ServerID) {
+	node, exists := i.check(hash, server)
+	if exists {
+		i.list.moveToTail(node)
+	} else {
+		i.create(hash, server)
+	}
+}
+
+func (i *indexer) create(hash types.BlockHash, server types.ServerID) {
+	n := &node{
+		hash:   hash,
+		server: server,
+	}
+
+	for i.list.size >= i.maxCacheSize {
+		// Evict the least recently used entry if we've exceeded the max cache size
+		i.evict()
+	}
+
+	if _, ok := i.table[hash]; !ok {
+		i.table[hash] = make(map[types.ServerID]*node)
+	}
+	i.table[hash][server] = n
+	i.list.add(n)
+}
+
+// evict removes the least recently used entry from the cache
+func (i *indexer) evict() {
+	oldestNode := i.list.dummyHead.next
+	i.list.delete(oldestNode)
+
+	hash := oldestNode.hash
+	server := oldestNode.server
+	// Remove from the hash map
+	serverMap := i.table[hash]
+	delete(serverMap, server)
+
+	// If this was the last server for this hash, remove the hash entry entirely
+	if len(serverMap) == 0 {
+		delete(i.table, hash)
+	}
+
+	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
+}
+
+// ReportCacheSize periodically reports the cache size metric. newIndexer runs it in a
+// dedicated goroutine.
+func (i *indexer) ReportCacheSize(interval time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	for range ticker.C {
+		i.mu.RLock()
+		metrics.RecordPrefixCacheSize(int64(i.list.size))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.list.size, "estimated size MB", i.list.size*i.estimateEntrySize()/1000000)
+		i.mu.RUnlock()
+	}
+}
+
+// estimateEntrySize estimates the memory size of a cache entry in bytes.
+func (i *indexer) estimateEntrySize() int {
+	size := 0
+
+	// Estimate the size of a node in the linked list.
+	// First get the size of the node struct via unsafe.Sizeof.
+	// The prev and next pointers are 8 bytes each on a 64-bit system.
+	// The BlockHash is a uint64, which is 8 bytes.
+	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
+	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
+	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 56 bytes.
+	size += int(unsafe.Sizeof(node{}))
+	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
+	size += 2 * 63
+
+	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
+	size += 8      // Size of the BlockHash (uint64).
+	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
+	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
+	size += 8      // Size of the pointer to the node in the hash map.
+
+	// Based on the above estimates, the estimated size of an entry is:
+	// (56 + 2*63) + (8 + 2*16 + 2*63 + 8) = 182 + 174 = 356 bytes.
+	return size
+}
diff --git a/pkg/epp/scheduling/plugins/prefix/indexer_test.go b/pkg/epp/scheduling/plugins/prefix/indexer_test.go
new file mode 100644
index 000000000..592b7c3e3
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/prefix/indexer_test.go
@@ -0,0 +1,46 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package prefix
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+func TestIndexer_AddAndGet(t *testing.T) {
+	cache := newIndexer(2)
+
+	hash1 := types.BlockHash(1)
+	server := types.ServerID{Namespace: "default", Name: "server1"}
+
+	// Add an entry to the cache
+	cache.Add([]types.BlockHash{hash1}, server)
+
+	// Retrieve the entry
+	assert.Equal(t, 1, cache.list.size, "Cache size should be 1 after adding an entry")
+	servers := cache.Get(hash1)
+	assert.Contains(t, servers, server, "Cache should contain the added server")
+
+	// Add another entry; the cache size should increase to 2.
+	cache.Add([]types.BlockHash{types.BlockHash(2)}, server)
+	assert.Equal(t, 2, cache.list.size, "Cache size should be 2 after adding an entry")
+
+	// Add a third entry, which should evict the first one due to the max cache size.
+	cache.Add([]types.BlockHash{types.BlockHash(3)}, server)
+	assert.Equal(t, 2, cache.list.size, "Cache size should still be 2 after adding an entry")
+}
diff --git a/pkg/epp/scheduling/plugins/prefix/linked_list.go b/pkg/epp/scheduling/plugins/prefix/linked_list.go
new file mode 100644
index 000000000..9c9b82103
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/prefix/linked_list.go
@@ -0,0 +1,85 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package prefix
+
+import (
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+type linkedList struct {
+	dummyHead *node // The head of the linked list (dummy node).
+	tail      *node // The tail of the linked list.
+	size      int   // The size of the linked list (excluding dummy head).
+}
+
+// newLinkedList initializes a new linked list with a dummy head node.
+// Using a dummy head simplifies the implementation by eliminating nil checks.
+func newLinkedList() *linkedList {
+	dummy := &node{} // Create dummy head node
+	return &linkedList{
+		dummyHead: dummy,
+		tail:      dummy,
+		size:      0,
+	}
+}
+
+type node struct {
+	prev   *node
+	next   *node
+	server types.ServerID
+	hash   types.BlockHash
+}
+
+// add adds a node to the end of the linked list.
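+// The dummy head and tail pointer make add, delete, and moveToTail O(1) operations.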
+func (ll *linkedList) add(n *node) {
+	ll.size++
+
+	n.prev = ll.tail
+	ll.tail.next = n
+	ll.tail = n
+}
+
+// delete removes a node from the linked list.
+// Note the method assumes the input node exists in the list.
+func (ll *linkedList) delete(n *node) {
+	ll.size--
+	n.prev.next = n.next
+
+	// If it's the tail node
+	if n.next == nil {
+		ll.tail = n.prev
+	} else {
+		n.next.prev = n.prev
+	}
+}
+
+// moveToTail moves an existing node to the end of the linked list (most recent).
+func (ll *linkedList) moveToTail(n *node) {
+	if n.next == nil {
+		// Already the tail, no need to move.
+		return
+	}
+
+	n.prev.next = n.next
+	n.next.prev = n.prev
+
+	// Move it to the tail position
+	n.prev = ll.tail
+	n.next = nil
+	ll.tail.next = n
+	ll.tail = n
+}
diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go
new file mode 100644
index 000000000..2e748af82
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/prefix/plugin.go
@@ -0,0 +1,178 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package prefix
+
+import (
+	"encoding/binary"
+	"fmt"
+
+	"github.com/cespare/xxhash/v2"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+const (
+	// Attempt to return DefaultNumServersToMatch servers with their longest prefix match length.
+	// Why not just return the server with the longest prefix match?
+	// It may not be the optimal choice, e.g., it may have a high queue depth.
+	// We optimistically search more than one to give more candidates for the scheduler to choose.
+	DefaultNumServersToMatch = 2
+	// vLLM default token block size is 16, and a good guess of average characters per token is 4.
+	DefaultCacheBlockSize = 64
+	DefaultMaxPrefixBlocks = 128
+	// Assume each request reaches DefaultMaxPrefixBlocks = 128, and each BlockHash is cached onto 2
+	// servers due to load balancing, then it requires 256 entries per request.
+	// According to the estimates in indexer.estimateEntrySize(), the size of each entry is 356 bytes.
+	// So each request will cost 256 * 356 = 91,136 bytes ~ 90KB.
+	// Therefore, to cache 50k requests, we need 50K * 90KB = 4.5GB. Assuming 500 requests per
+	// second, a 4.5 GB cache can hold at least the last 100 seconds of requests.
+	// Note that in practice, the size of each entry will be much smaller (shorter NamespacedNames,
+	// shorter prompts). And due to prefix cache hits, the number of unique cache entries will be
+	// much smaller per request. Therefore the actual cache size will be much smaller.
+	// TODO: Add some guidance for choosing the right size.
+	DefaultLRUIndexerCapacity = 50000
+)
+
+type Config struct {
+	// The input prompt is broken into blocks of HashBlockSize to calculate the block hashes. Requests
+	// with length shorter than the block size will be ignored.
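+	// For example, with the default 64-character block size, a 100-character prompt yields
+	// one block hash (plus the model hash); the trailing 36 characters are ignored.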
+	HashBlockSize int
+	// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
+	// be ignored.
+	MaxPrefixBlocksToMatch int
+	// Max (approximate) size of the LRU indexer in number of entries.
+	LRUIndexerCapacity int
+}
+
+var DefaultConfig = Config{
+	HashBlockSize:          DefaultCacheBlockSize,
+	MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+	LRUIndexerCapacity:     DefaultLRUIndexerCapacity,
+}
+
+type plugin struct {
+	Config
+	indexer Indexer
+}
+
+type Indexer interface {
+	Get(hash types.BlockHash) map[types.ServerID]bool
+	Add(hashes []types.BlockHash, server types.ServerID)
+}
+
+func New(config Config) *plugin {
+	m := &plugin{
+		Config:  config,
+		indexer: newIndexer(config.LRUIndexerCapacity),
+	}
+	return m
+}
+
+func (m *plugin) Name() string {
+	return "prefixCache"
+}
+
+func (m *plugin) PreSchedule(ctx *types.SchedulingContext) {
+	ctx.PrefixHashes = hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
+	ctx.PrefixCacheServers = m.matchLongestPrefix(ctx, DefaultNumServersToMatch)
+	ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", ctx.PrefixCacheServers), "hashes", ctx.PrefixHashes)
+}
+
+// PostSchedule records the prefix hashes of a routed request against the server it was sent to.
+func (m *plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
+	targetPod := res.TargetPod.GetPod()
+	m.indexer.Add(ctx.PrefixHashes, types.ServerID(targetPod.NamespacedName))
+	total := len(ctx.PrefixHashes)
+	matchLen := ctx.PrefixCacheServers[types.ServerID(targetPod.NamespacedName)]
+	metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
+}
+
+func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+	total := len(ctx.PrefixHashes)
+	podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 {
+		if total == 0 {
+			return 0
+		}
+		matchLen := ctx.PrefixCacheServers[types.ServerID(pod.GetPod().NamespacedName)]
+		return float64(matchLen) / float64(total)
+	}
+
+	scores := make(map[types.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scores[pod] = podScoreFunc(ctx, pod)
+	}
+	return scores
+}
+
+// matchLongestPrefix returns a map of servers to the length of the prefix each server caches.
+func (m *plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int {
+	if numServers > len(ctx.PodsSnapshot) {
+		numServers = len(ctx.PodsSnapshot)
+	}
+	res := make(map[types.ServerID]int)
+	// Use a greedy strategy to search from the longest prefix.
+	// NOTE: It's possible to further optimize this with a binary search.
+	for i := len(ctx.PrefixHashes) - 1; i >= 0 && len(res) < numServers; i-- {
+		hash := ctx.PrefixHashes[i]
+		cachedServers := m.indexer.Get(hash)
+		if len(cachedServers) > 0 {
+			ctx.Logger.V(logutil.VERBOSE).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(ctx.PrefixHashes), "longest prefix", i)
+			for server := range cachedServers {
+				// Update servers with their longest prefix match.
+				// If we already found this server with a longer prefix match, don't update it.
+				if _, ok := res[server]; !ok {
+					res[server] = i + 1
+				}
+			}
+		}
+	}
+	return res
+}
+
+// hashPrompt divides the prompt into blocks and calculates the prefix hash for each block.
+// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
+// For block i, hash(i) = hash(block i content, hash(i-1)).
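+// For example (illustrative values), with block size 4, model "m", and prompt "aaaabbbb":
+//
+//	hash(0) = xxhash("m")
+//	hash(1) = xxhash("aaaa" + bytes(hash(0)))
+//	hash(2) = xxhash("bbbb" + bytes(hash(1)))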
+func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []types.BlockHash {
+	prompt := []byte(ctx.Req.Prompt)
+	if len(prompt) < cacheBlockSize {
+		ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+		return nil
+	}
+	if len(prompt) > cacheBlockSize*maxPrefixBlocks {
+		ctx.Logger.V(logutil.DEBUG).Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
+	}
+	// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
+	// If the last block is smaller than cacheBlockSize, it will be ignored.
+	res := make([]types.BlockHash, 0, 1+len(prompt)/cacheBlockSize)
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
+	res = append(res, types.BlockHash(xxhash.Sum64String(ctx.Req.ResolvedTargetModel)))
+	for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
+		block := prompt[i : i+cacheBlockSize]
+		prevBlockHash := res[len(res)-1]
+		// Copy the block into a fresh buffer before appending the previous hash.
+		// Appending directly to the sub-slice would write into the prompt buffer's
+		// spare capacity and corrupt the next block before it is hashed.
+		toHash := make([]byte, 0, cacheBlockSize+8)
+		toHash = append(toHash, block...)
+		toHash = append(toHash, toBytes(prevBlockHash)...)
+		res = append(res, types.BlockHash(xxhash.Sum64(toHash)))
+	}
+	return res
+}
+
+func toBytes(i types.BlockHash) []byte {
+	bytes := make([]byte, 8)
+	binary.LittleEndian.PutUint64(bytes, uint64(i))
+	return bytes
+}
diff --git a/pkg/epp/scheduling/plugins/prefix/plugin_test.go b/pkg/epp/scheduling/plugins/prefix/plugin_test.go
new file mode 100644
index 000000000..47c6c7f18
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/prefix/plugin_test.go
@@ -0,0 +1,132 @@
+package prefix
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	k8stypes "k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+func TestPrefixPlugin(t *testing.T) {
+	config := Config{
+		HashBlockSize:          4,
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUIndexerCapacity:     DefaultLRUIndexerCapacity,
+	}
+	plugin := New(config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	pods := []types.Pod{pod1, pod2}
+
+	// First request.
+	req1 := &types.LLMRequest{
+		Model:               "test-model1",
+		ResolvedTargetModel: "test-model1",
+		Prompt:              "aaaaaa",
+	}
+	ctx := types.NewSchedulingContext(context.Background(), req1, pods)
+	plugin.PreSchedule(ctx)
+	t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers)
+	// Input size is 6, hash block size is 4, so the last 2 characters are ignored.
+	// Total hashes = 2 (the first one is for the model).
+	assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers")
+
+	scores := plugin.Score(ctx, pods)
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	// Simulate pod1 was picked.
+	plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1})
+
+	// Second request doesn't share any prefix with the first one. It should be added to the
+	// cache but the pod score should be 0.
+	req2 := &types.LLMRequest{
+		Model:               "test-model2",
+		ResolvedTargetModel: "test-model2",
+		Prompt:              "bbbbbb",
+	}
+	ctx = types.NewSchedulingContext(context.Background(), req2, pods)
+	plugin.PreSchedule(ctx)
+	t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers)
+	// Input size is 6, hash block size is 4, so the last 2 characters are ignored.
+	// Total hashes = 2 (the first one is for the model).
+	assert.Equal(t, 2, len(ctx.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers")
+
+	scores = plugin.Score(ctx, pods)
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	// Simulate pod2 was picked.
+	plugin.PostSchedule(ctx, &types.Result{TargetPod: pod2})
+
+	// Third request shares a partial prefix with the first one.
+	req3 := &types.LLMRequest{
+		Model:               "test-model1",
+		ResolvedTargetModel: "test-model1",
+		Prompt:              "aaaabbbb",
+	}
+	ctx = types.NewSchedulingContext(context.Background(), req3, pods)
+	plugin.PreSchedule(ctx)
+	t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers)
+	// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
+	// Total hashes = 3 (the first one is for the model).
+	assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
+
+	scores = plugin.Score(ctx, pods)
+	assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1})
+
+	// Fourth request is the same as req3 except the model is different, so still no match.
+	req4 := &types.LLMRequest{
+		Model:               "test-model-new",
+		ResolvedTargetModel: "test-model-new",
+		Prompt:              "aaaabbbb",
+	}
+	ctx = types.NewSchedulingContext(context.Background(), req4, pods)
+	plugin.PreSchedule(ctx)
+	t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers)
+	// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
+	// Total hashes = 3 (the first one is for the model).
+	assert.Equal(t, 3, len(ctx.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(ctx.PrefixCacheServers), "there shouldn't be any cached servers for the new model")
+
+	scores = plugin.Score(ctx, pods)
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1})
+
+	// Fifth request shares a partial prefix with the third one.
+	req5 := &types.LLMRequest{
+		Model:               "test-model1",
+		ResolvedTargetModel: "test-model1",
+		Prompt:              "aaaabbbbcccc",
+	}
+	ctx = types.NewSchedulingContext(context.Background(), req5, pods)
+	plugin.PreSchedule(ctx)
+	t.Logf("Hashes %+v, cached servers: %+v", ctx.PrefixHashes, ctx.PrefixCacheServers)
+	// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
+	// Total hashes = 4 (the first one is for the model).
+	assert.Equal(t, 4, len(ctx.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 1, len(ctx.PrefixCacheServers), "pod1 should have cached the aaaabbbb prefix")
+
+	scores = plugin.Score(ctx, pods)
+	assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1})
+}
diff --git a/pkg/epp/scheduling/scheduler_v2.go b/pkg/epp/scheduling/scheduler_v2.go
new file mode 100644
index 000000000..7a3da3b3a
--- /dev/null
+++ b/pkg/epp/scheduling/scheduler_v2.go
@@ -0,0 +1,62 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package scheduling implements request scheduling algorithms.
+package scheduling
+
+import (
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+func NewSchedulerV2(datastore Datastore, prefixConfig prefix.Config) *Scheduler {
+	prefixPlugin := prefix.New(prefixConfig)
+	queuePlugin := &scorer.QueueScorer{}
+	kvCachePlugin := &scorer.KVCacheScorer{}
+	configV2 := &SchedulerConfig{
+		PreSchedulePlugins:  []plugins.PreSchedule{prefixPlugin},
+		PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin},
+		Scorers: map[plugins.Scorer]int{
+			prefixPlugin:  3,
+			queuePlugin:   1,
+			kvCachePlugin: 1,
+		},
+		Filters: []plugins.Filter{&sheddableRequestFilterV2{}},
+		Picker:  &picker.MaxScorePicker{},
+	}
+	return NewSchedulerWithConfig(datastore, configV2)
+}
+
+type sheddableRequestFilterV2 struct {
+}
+
+func (p *sheddableRequestFilterV2) Name() string {
+	return "sheddableRequestFilterV2"
+}
+
+func (p *sheddableRequestFilterV2) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+	if ctx.Req.Critical {
+		// Allow all pods to pass through if the request is critical, even if all pods reach their capacity.
+		return pods
+	}
+
+	// Only allow pods that have enough capacity to handle the request.
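+	// HasCapacityFilter admits only pods whose queue depth and KV cache utilization
+	// are below the configured critical thresholds.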
+	return filter.HasCapacityFilter.Filter(ctx, pods)
+}
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index 4f69fae0a..f470f288f 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -21,6 +21,7 @@ import (
 	"fmt"

 	"github.com/go-logr/logr"
+	"k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -58,6 +59,10 @@ type SchedulingContext struct {
 	Logger       logr.Logger
 	Req          *LLMRequest
 	PodsSnapshot []Pod
+	// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
+	PrefixHashes []BlockHash
+	// A map of server to its longest prefix cache match length.
+	PrefixCacheServers map[ServerID]int
 }

 func (pm *PodMetrics) String() string {
@@ -102,3 +107,12 @@ func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {
 type Result struct {
 	TargetPod Pod
 }
+
+// BlockHash is a hash of a block of the request body.
+type BlockHash uint64
+
+type ServerID types.NamespacedName
+
+func (s ServerID) String() string {
+	return types.NamespacedName(s).String()
+}

From d8ba48dd9aec03f385a12a753a49df8b18ac1cbe Mon Sep 17 00:00:00 2001
From: Cong Liu
Date: Fri, 2 May 2025 21:57:50 +0000
Subject: [PATCH 2/2] Replace scheduler v2 with config v2

---
 cmd/epp/main.go                             | 24 ++++++++++++-------
 .../{scheduler_v2.go => config_v2.go}       | 19 +++++++++------
 pkg/epp/scheduling/plugins/prefix/plugin.go |  4 ++--
 3 files changed, 30 insertions(+), 17 deletions(-)
 rename pkg/epp/scheduling/{scheduler_v2.go => config_v2.go} (85%)

diff --git a/cmd/epp/main.go b/cmd/epp/main.go
index e10572944..a4256054d 100644
--- a/cmd/epp/main.go
+++ b/cmd/epp/main.go
@@ -111,19 +111,26 @@ var (
 	setupLog = ctrl.Log.WithName("setup")

 	// Environment variables
-	schedulerV2       = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULE_V2", "false", setupLog)
-	prefixCacheConfig = loadPrefixCacheConfig()
+	schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog)
 )

 func loadPrefixCacheConfig() prefix.Config {
-	// logger := zap.New(zap.RawZapOpts(uberzap.AddCaller()))
-	// log.SetLogger(logger)
 	baseLogger := log.Log.WithName("env-config")

 	return prefix.Config{
-		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultCacheBlockSize, baseLogger),
+		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
 		MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
-		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_MAX_CACHE_SIZE_MB", prefix.DefaultLRUIndexerCapacity, baseLogger),
+		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
+	}
+}
+
+func loadSchedulingScorerWeights() scheduling.ScorerWeights {
+	baseLogger := log.Log.WithName("env-config")
+
+	return scheduling.ScorerWeights{
+		Prefix:  envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", 3, baseLogger),
+		Queue:   envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", 2, baseLogger),
+		KVCache: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", 1, baseLogger),
 	}
 }

@@ -191,8 +198,9 @@ func run() error {

 	scheduler := scheduling.NewScheduler(datastore)
 	if schedulerV2 == "true" {
-		setupLog.Info("Creating scheduler with prefixCache plugin", "prefix cache config", prefixCacheConfig)
-		scheduler = scheduling.NewSchedulerV2(datastore, prefixCacheConfig)
+		schedConfig := scheduling.CreateConfig(loadSchedulingScorerWeights(), loadPrefixCacheConfig())
+		setupLog.Info("Creating scheduler", "config", *schedConfig)
+		scheduler = scheduling.NewSchedulerWithConfig(datastore, schedConfig)
 	}
 	serverRunner := &runserver.ExtProcServerRunner{
 		GrpcPort: *grpcPort,
diff --git a/pkg/epp/scheduling/scheduler_v2.go b/pkg/epp/scheduling/config_v2.go
similarity index 85%
rename from pkg/epp/scheduling/scheduler_v2.go
rename to pkg/epp/scheduling/config_v2.go
index 7a3da3b3a..4992de637 100644
--- a/pkg/epp/scheduling/scheduler_v2.go
+++ b/pkg/epp/scheduling/config_v2.go
@@ -26,27 +26,32 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )

-func NewSchedulerV2(datastore Datastore, prefixConfig prefix.Config) *Scheduler {
+func CreateConfig(weights ScorerWeights, prefixConfig prefix.Config) *SchedulerConfig {
 	prefixPlugin := prefix.New(prefixConfig)
 	queuePlugin := &scorer.QueueScorer{}
 	kvCachePlugin := &scorer.KVCacheScorer{}
-	configV2 := &SchedulerConfig{
+	config := &SchedulerConfig{
 		PreSchedulePlugins:  []plugins.PreSchedule{prefixPlugin},
 		PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin},
 		Scorers: map[plugins.Scorer]int{
-			prefixPlugin:  3,
-			queuePlugin:   1,
-			kvCachePlugin: 1,
+			prefixPlugin:  weights.Prefix,
+			queuePlugin:   weights.Queue,
+			kvCachePlugin: weights.KVCache,
 		},
 		Filters: []plugins.Filter{&sheddableRequestFilterV2{}},
 		Picker:  &picker.MaxScorePicker{},
 	}
-	return NewSchedulerWithConfig(datastore, configV2)
+	return config
 }

-type sheddableRequestFilterV2 struct {
+type ScorerWeights struct {
+	Prefix  int
+	Queue   int
+	KVCache int
 }

+type sheddableRequestFilterV2 struct{}
+
 func (p *sheddableRequestFilterV2) Name() string {
 	return "sheddableRequestFilterV2"
 }
diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go
index 2e748af82..39ccf886e 100644
--- a/pkg/epp/scheduling/plugins/prefix/plugin.go
+++ b/pkg/epp/scheduling/plugins/prefix/plugin.go
@@ -33,7 +33,7 @@ const (
 	// We optimistically search more than one to give more candidates for the scheduler to choose.
 	DefaultNumServersToMatch = 2
 	// vLLM default token block size is 16, and a good guess of average characters per token is 4.
-	DefaultCacheBlockSize = 64
+	DefaultHashBlockSize = 64
 	DefaultMaxPrefixBlocks = 128
 	// Assume each request reaches DefaultMaxPrefixBlocks = 128, and each BlockHash is cached onto 2
 	// servers due to load balancing, then it requires 256 entries per request.
@@ -60,7 +60,7 @@ type Config struct {
 }

 var DefaultConfig = Config{
-	HashBlockSize:          DefaultCacheBlockSize,
+	HashBlockSize:          DefaultHashBlockSize,
 	MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 	LRUIndexerCapacity:     DefaultLRUIndexerCapacity,
 }
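
A minimal usage sketch of the final wiring from PATCH 2, assuming a `datastore` built as in
cmd/epp/main.go; the weight values shown are just the env-var defaults above:

	weights := scheduling.ScorerWeights{Prefix: 3, Queue: 2, KVCache: 1}
	schedConfig := scheduling.CreateConfig(weights, prefix.DefaultConfig)
	scheduler := scheduling.NewSchedulerWithConfig(datastore, schedConfig)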