Skip to content

Commit 05081ee

Browse files
WiP fixes
1 parent 467ac7c commit 05081ee

File tree

6 files changed

+63
-118
lines changed

6 files changed

+63
-118
lines changed

pkg/bbr/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Body-Based Routing
22
This package provides an extension that can be deployed to write the `model`
3-
HTTP body parameter as a header (X-Gateway-Model-Name) so as to enable routing capabilities on the
3+
HTTP body parameter as a header (`X-Gateway-Model-Name`) so as to enable routing capabilities on the
44
model name.
55

66
As per OpenAI spec, it is standard for the model name to be included in the

pkg/bbr/framework/interfaces.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ limitations under the License.
1717
package framework
1818

1919
import (
20-
"github.com/openai/openai-go/v3"
2120
bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins"
2221
)
2322

@@ -48,12 +47,8 @@ type PluginsChain interface {
4847
AddPluginAtInd(typeKey string, i int, r PluginRegistry) error //only affects the instance of the plugin chain
4948
GetPlugin(index int, registry PluginRegistry) (bbrplugins.BBRPlugin, error) //retrieves i-th plugin as defined in the chain from the registry
5049
Length() int
51-
ParseChatCompletion(data []byte) (openai.ChatCompletionNewParams, error)
52-
ParseCompletion(data []byte) (openai.CompletionNewParams, error)
53-
GetSharedChatCompletion() openai.ChatCompletionNewParams
54-
GetSharedCompletion() openai.CompletionNewParams
55-
GetSharedMemory(which string) interface{}
56-
Run(bodyBytes []byte, metaDataKeys []string, registry PluginRegistry) ([]byte, map[string]string, error) //return potentially mutated body and all headers map safely merged
50+
GetPlugins() []string
51+
Run(bodyBytes []byte, metaDataKeys []string, registry PluginRegistry) (map[string]string, []byte, error) //return potentially mutated body and all headers map safely merged
5752
String() string
5853
}
5954

pkg/bbr/framework/registry.go

Lines changed: 8 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,13 @@ import (
2020
"fmt"
2121
"slices"
2222

23-
"github.com/openai/openai-go/v3" //imported as openai
2423
bbrplugins "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/plugins"
2524
)
2625

2726
// -------------------- INTERFACES -----------------------------------------------------------------------
2827
// Interfaces are defined in "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/framework/interfaces.go"
2928

3029
// --------------------- PluginRegistry implementation ---------------------------------------------------
31-
// --------------------- Constructors lookup map ---------------------------------------------------------
32-
// This lookup map should be extended when a new concrete plugin implementation is added
33-
/* var pluginConstructors = map[string]PluginFactoryFunc{
34-
"simple-model-extractor": NewSimpleModelExtractor,
35-
"semantic-model-selector": NewSemanticModelSelector,
36-
"bad-words-blocker": NewBadWordsBlocker,
37-
"pid-disclosure-blocker": NewPidDisclosureBlocker,
38-
"extract-metadata-bykeys": NewExtractMetadataByKeys,
39-
} */
40-
// Registration: of constructors for implementations is done in the main's initPlugins based on the ConfigMap
4130

4231
// pluginRegistry implements PluginRegistry
4332
type pluginRegistry struct {
@@ -123,15 +112,6 @@ func (r *pluginRegistry) UnregisterFactory(typeKey string) error {
123112
return fmt.Errorf("plugin (%s) not found", typeKey)
124113
}
125114

126-
// Removes a plugin instance by type key
127-
func (r *pluginRegistry) UnregisterPlugin(typeKey string) error {
128-
if _, ok := r.plugins[typeKey]; ok {
129-
delete(r.plugins, typeKey)
130-
return nil
131-
}
132-
return fmt.Errorf("plugin (%s) not found", typeKey)
133-
}
134-
135115
// ListPlugins lists all registered plugins
136116
func (r *pluginRegistry) ListPlugins() []string {
137117
typeKeys := make([]string, 0, len(r.plugins))
@@ -160,12 +140,6 @@ func (r *pluginRegistry) GetPlugins() map[string]bbrplugins.BBRPlugin {
160140
return r.plugins
161141
}
162142

163-
// Clear removes all registered factories and plugins
164-
func (r *pluginRegistry) Clear() {
165-
r.pluginsFactory = make(map[string]PluginFactoryFunc)
166-
r.plugins = make(map[string]bbrplugins.BBRPlugin)
167-
}
168-
169143
// Checks for presense of a factory in this registry
170144
func (r *pluginRegistry) ContainsFactory(typeKey string) bool {
171145
_, exists := r.pluginsFactory[typeKey]
@@ -186,17 +160,13 @@ func (r *pluginRegistry) String() string {
186160

187161
// PluginsChain is a sequence of plugins to be executed in order inside the ext_proc server
188162
type pluginsChain struct {
189-
plugins []string
190-
sharedChatCompletion openai.ChatCompletionNewParams //will be nil if an instance of pluginsChain does not contain a plugin that requires full parsing
191-
sharedCompletion openai.CompletionNewParams //likewise
163+
plugins []string
192164
}
193165

194166
// NewPluginsChain creates a new PluginsChain instance
195167
func NewPluginsChain() PluginsChain {
196168
return &pluginsChain{
197-
plugins: []string{},
198-
sharedChatCompletion: openai.ChatCompletionNewParams{},
199-
sharedCompletion: openai.CompletionNewParams{},
169+
plugins: []string{},
200170
}
201171
}
202172

@@ -212,17 +182,6 @@ func (pc *pluginsChain) AddPlugin(typeKey string, r PluginRegistry) error {
212182
return nil
213183
}
214184

215-
// DeletePlugin deletes a plugin from the chain
216-
func (pc *pluginsChain) DeletePlugin(p string) error {
217-
for i := range len(pc.plugins) {
218-
if pc.plugins[i] == p {
219-
pc.plugins = append(pc.plugins[:i], pc.plugins[i+1:]...)
220-
return nil
221-
}
222-
}
223-
return fmt.Errorf("plugin %s not found in chain", p)
224-
}
225-
226185
// GetPlugin retrieves the next plugin in the chain by index
227186
func (pc *pluginsChain) GetPlugin(index int, r PluginRegistry) (bbrplugins.BBRPlugin, error) {
228187
if index < 0 || index >= len(pc.plugins) {
@@ -255,36 +214,8 @@ func (pc *pluginsChain) AddPluginAtInd(typeKey string, i int, r PluginRegistry)
255214
return nil
256215
}
257216

258-
func (pc *pluginsChain) ParseChatCompletion(data []byte) (openai.ChatCompletionNewParams, error) {
259-
if err := pc.sharedChatCompletion.UnmarshalJSON(data); err != nil {
260-
return pc.sharedChatCompletion, err
261-
}
262-
return pc.sharedChatCompletion, nil
263-
}
264-
265-
func (pc *pluginsChain) ParseCompletion(data []byte) (openai.CompletionNewParams, error) {
266-
if err := pc.sharedCompletion.UnmarshalJSON(data); err != nil {
267-
return pc.sharedCompletion, err
268-
}
269-
return pc.sharedCompletion, nil
270-
}
271-
272-
func (pc *pluginsChain) GetSharedChatCompletion() openai.ChatCompletionNewParams {
273-
return pc.sharedChatCompletion
274-
}
275-
276-
func (pc *pluginsChain) GetSharedCompletion() openai.CompletionNewParams {
277-
return pc.sharedCompletion
278-
}
279-
280-
func (pc *pluginsChain) GetSharedMemory(which string) interface{} {
281-
if which == "/v1/completions" {
282-
return pc.sharedCompletion
283-
}
284-
if which == "/v1/chat/completions" {
285-
return pc.sharedChatCompletion
286-
}
287-
return nil
217+
func (pc *pluginsChain) GetPlugins() []string {
218+
return pc.plugins
288219
}
289220

290221
// MergeMaps copies all key/value pairs from src into dst and returns dst.
@@ -315,7 +246,7 @@ func (pc *pluginsChain) Run(
315246
bodyBytes []byte,
316247
metaDataKeys []string,
317248
r PluginRegistry,
318-
) (mutateBodyBytes []byte, headers map[string]string, err error) {
249+
) (headers map[string]string, mutateBodyBytes []byte, err error) {
319250

320251
allHeaders := make(map[string]string)
321252
mutatedBodyBytes := bodyBytes
@@ -327,20 +258,20 @@ func (pc *pluginsChain) Run(
327258
metExtPlugin, err := r.GetPlugin(pluginType)
328259

329260
if err != nil {
330-
return bodyBytes, allHeaders, err
261+
return allHeaders, bodyBytes, err
331262
}
332263

333264
// The plugin i in the chain receives the (potentially mutated) body from plugin i-1 in the chain
334265
headers, mutatedBodyBytes, err := metExtPlugin.Execute(mutatedBodyBytes, metaDataKeys)
335266

336267
if err != nil {
337-
return mutatedBodyBytes, headers, err
268+
return headers, mutatedBodyBytes, err
338269
}
339270

340271
//note that the existing overlapping keys are NOT over-written by merge
341272
MergeMaps(allHeaders, headers)
342273
}
343-
return
274+
return allHeaders, mutatedBodyBytes, nil
344275
}
345276

346277
func (pc *pluginsChain) String() string {

pkg/bbr/handlers/request.go

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package handlers
1818

1919
import (
2020
"context"
21+
"strings"
2122

2223
basepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2324
eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
@@ -31,45 +32,34 @@ import (
3132
// HandleRequestBody handles request bodies.
3233
func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte) ([]*eppb.ProcessingResponse, error) {
3334
logger := log.FromContext(ctx)
35+
var ret []*eppb.ProcessingResponse
3436

35-
allHeaders, mutatedBodyBytes, _ := helpers.RunPluginsChain( //TODO
36-
ctx,
37-
logger,
38-
requestBodyBytes,
39-
s.requestChain,
40-
s.registry,
41-
s.metaDataKeys)
42-
43-
//At this point, we have all the headers and a mutated body (note that actually, the body might not be mutated, but we do not care)
37+
allHeaders, mutatedBodyBytes, err := s.requestChain.Run(requestBodyBytes, s.metaDataKeys, s.registry)
4438

45-
var ret []*eppb.ProcessingResponse
39+
if err != nil {
40+
//TODO: add metric in metrics.go to count "other errors"
41+
logger.V(logutil.DEFAULT).Info("error processing body", "error", err)
42+
ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes)
43+
return ret, nil
44+
}
4645

47-
// process headers
48-
Model := allHeaders[bbrplugins.ModelHeader] //it is required that the ModelHeader is always set (i.e., that there always exist requestPluginsChain with at least one plugin that sets the model header)
46+
model, ok := allHeaders[bbrplugins.ModelHeader]
4947

50-
logger.V(logutil.DEFAULT).Info("model extracted from request body", "model", Model)
48+
if !ok {
49+
//TODO: add metric in metrics.go to count "other errors"
50+
logger.V(logutil.DEFAULT).Info("manadatory header X-Gateway-Model-Name value is undetermined")
51+
ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes)
52+
return ret, nil
53+
}
5154

52-
if Model == "" {
55+
if strings.TrimSpace(model) == "" {
5356
metrics.RecordModelNotInBodyCounter()
54-
55-
if s.streaming {
56-
ret = append(ret, &eppb.ProcessingResponse{
57-
Response: &eppb.ProcessingResponse_RequestHeaders{
58-
RequestHeaders: &eppb.HeadersResponse{},
59-
},
60-
})
61-
ret = addStreamedBodyResponse(ret, requestBodyBytes)
62-
return ret, nil
63-
} else {
64-
ret = append(ret, &eppb.ProcessingResponse{
65-
Response: &eppb.ProcessingResponse_RequestBody{
66-
RequestBody: &eppb.BodyResponse{},
67-
},
68-
})
69-
}
57+
ret, _ := buildEmptyResponsesForMissingModel(s.streaming, requestBodyBytes)
7058
return ret, nil
7159
}
7260

61+
logger.V(logutil.DEFAULT).Info("model extracted from request body", "model", model)
62+
7363
metrics.RecordSuccessCounter()
7464

7565
if s.streaming {
@@ -83,7 +73,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte)
8373
{
8474
Header: &basepb.HeaderValue{
8575
Key: bbrplugins.ModelHeader,
86-
RawValue: []byte(Model),
76+
RawValue: []byte(model),
8777
},
8878
},
8979
},
@@ -111,7 +101,7 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestBodyBytes []byte)
111101
{
112102
Header: &basepb.HeaderValue{
113103
Key: bbrplugins.ModelHeader,
114-
RawValue: []byte(Model),
104+
RawValue: []byte(model),
115105
},
116106
},
117107
},
@@ -168,3 +158,31 @@ func (s *Server) HandleRequestTrailers(trailers *eppb.HttpTrailers) ([]*eppb.Pro
168158
},
169159
}, nil
170160
}
161+
162+
// buildEmptyResponsesForMissingModel is a local helper that returns the appropriate empty responses
163+
// for the "model not found" branch depending on streaming mode.
164+
// It is also used to create empty responses in case of other errors related to running plugins on the body
165+
// This is not very clean and MUST be segregated in the future.
166+
// Corresponding metrics should be defined to make different errors observable
167+
func buildEmptyResponsesForMissingModel(streaming bool, requestBodyBytes []byte) ([]*eppb.ProcessingResponse, error) {
168+
var ret []*eppb.ProcessingResponse
169+
170+
if streaming {
171+
// Emit empty headers response, then stream body unchanged.
172+
ret = append(ret, &eppb.ProcessingResponse{
173+
Response: &eppb.ProcessingResponse_RequestHeaders{
174+
RequestHeaders: &eppb.HeadersResponse{},
175+
},
176+
})
177+
ret = addStreamedBodyResponse(ret, requestBodyBytes)
178+
return ret, nil
179+
}
180+
181+
// Non-streaming: emit empty body response.
182+
ret = append(ret, &eppb.ProcessingResponse{
183+
Response: &eppb.ProcessingResponse_RequestBody{
184+
RequestBody: &eppb.BodyResponse{},
185+
},
186+
})
187+
return ret, nil
188+
}

pkg/bbr/handlers/server_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ func TestProcessRequestBody(t *testing.T) {
128128
}
129129

130130
//Initialize PluginRegistry and request/response PluginsChain instances
131+
//Change to hermetic
131132
registry, requestChain, responseChain, metaDataKeys, _ := bbrutils.InitPlugins()
132133

133134
for _, tc := range cases {

pkg/bbr/plugins/simple_model_extractor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import (
3232

3333
type simpleModelExtractor struct { //implements the MetadataExtractor interface
3434
typedName plugins.TypedName
35-
requiresFullParsing bool
35+
requiresFullParsing bool //this field will be used to determine whether shared struct should be created in this chain
3636
}
3737

3838
// NewSimpleModelExtractor is a factory that constructs SimpleModelExtractor plugin

0 commit comments

Comments
 (0)