Commit 07d2e2a

Add streaming response process.
1 parent eec941c commit 07d2e2a

File tree

7 files changed: +400, -123 lines changed

pkg/epp/handlers/response.go

Lines changed: 30 additions & 59 deletions
@@ -17,16 +17,15 @@ limitations under the License.
 package handlers
 
 import (
+	"bytes"
 	"context"
-	"encoding/json"
-	"fmt"
-	"strings"
 
 	configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
@@ -36,49 +35,56 @@ const (
 )
 
 // HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
-func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
+func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) {
 	logger := log.FromContext(ctx)
-	responseBytes, err := json.Marshal(response)
+	llmResponse, err := types.NewLLMResponseFromBytes(body)
 	if err != nil {
-		return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err)
-	}
-	if response["usage"] != nil {
-		usg := response["usage"].(map[string]any)
-		usage := Usage{
-			PromptTokens:     int(usg["prompt_tokens"].(float64)),
-			CompletionTokens: int(usg["completion_tokens"].(float64)),
-			TotalTokens:      int(usg["total_tokens"].(float64)),
+		logger.Error(err, "failed to create LLMResponse from bytes")
+	} else {
+		reqCtx.SchedulingResponse = llmResponse
+		if usage := reqCtx.SchedulingResponse.Usage(); usage != nil {
+			reqCtx.Usage = usage
+			logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage)
 		}
-		reqCtx.Usage = usage
-		logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage)
 	}
-	reqCtx.ResponseSize = len(responseBytes)
+	reqCtx.ResponseSize = len(body)
 	// ResponseComplete is to indicate the response is complete. In non-streaming
 	// case, it will be set to be true once the response is processed; in
 	// streaming case, it will be set to be true once the last chunk is processed.
 	// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178)
 	// will add the processing for streaming case.
 	reqCtx.ResponseComplete = true
 
-	reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
+	reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
 
 	return s.director.HandleResponseBodyComplete(ctx, reqCtx)
 }
 
 // The function is to handle streaming response if the modelServer is streaming.
-func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
+func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
 	logger := log.FromContext(ctx)
 	_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
 	if err != nil {
 		logger.Error(err, "error in HandleResponseBodyStreaming")
 	}
-	if strings.Contains(responseText, streamingEndMsg) {
+}
+
+func (s *StreamingServer) HandleResponseBodyModelStreamingComplete(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
+	logger := log.FromContext(ctx)
+	if bytes.Contains(streamBody, []byte(streamingEndMsg)) {
 		reqCtx.ResponseComplete = true
-		resp := parseRespForUsage(ctx, responseText)
-		reqCtx.Usage = resp.Usage
-		metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
-		metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
-		_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
+		resp, err := types.NewLLMResponseFromStream(streamBody)
+		if err != nil {
+			logger.Error(err, "error in converting stream response to LLMResponse.")
+		} else {
+			reqCtx.SchedulingResponse = resp
+			if usage := resp.Usage(); usage != nil {
+				reqCtx.Usage = usage
+				metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens)
+				metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens)
			}
+		}
+		_, err = s.director.HandleResponseBodyComplete(ctx, reqCtx)
 		if err != nil {
 			logger.Error(err, "error in HandleResponseBodyComplete")
 		}
@@ -153,41 +159,6 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con
 	return headers
 }
 
-// Example message if "stream_options": {"include_usage": "true"} is included in the request:
-// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[],
-// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
-//
-// data: [DONE]
-//
-// Noticed that vLLM returns two entries in one response.
-// We need to strip the `data:` prefix and next Data: [DONE] from the message to fetch response data.
-//
-// If include_usage is not included in the request, `data: [DONE]` is returned separately, which
-// indicates end of streaming.
-func parseRespForUsage(ctx context.Context, responseText string) ResponseBody {
-	response := ResponseBody{}
-	logger := log.FromContext(ctx)
-
-	lines := strings.Split(responseText, "\n")
-	for _, line := range lines {
-		if !strings.HasPrefix(line, streamingRespPrefix) {
-			continue
-		}
-		content := strings.TrimPrefix(line, streamingRespPrefix)
-		if content == "[DONE]" {
-			continue
-		}
-
-		byteSlice := []byte(content)
-		if err := json.Unmarshal(byteSlice, &response); err != nil {
-			logger.Error(err, "unmarshaling response body")
-			continue
-		}
-	}
-
-	return response
-}
-
 type ResponseBody struct {
 	Usage Usage `json:"usage"`
 }
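The deleted parseRespForUsage comment above documents the SSE framing ("data: ..." lines followed by a "data: [DONE]" sentinel) that the new types.NewLLMResponseFromStream helper now has to handle. As a rough illustration only — the real implementation lives in pkg/epp/scheduling/types and is not shown in this commit — a stream-to-usage conversion along those lines could look like the sketch below, where all type and field names are hypothetical:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"strings"
)

// usage mirrors the OpenAI-style usage block carried by the final chunk.
type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// chunk is a hypothetical, minimal view of one SSE data payload.
type chunk struct {
	Usage *usage `json:"usage"`
}

// usageFromStream scans "data:" lines, skips the "[DONE]" sentinel, and keeps
// the usage from the last chunk that carries a non-null usage object.
func usageFromStream(stream []byte) *usage {
	var last *usage
	sc := bufio.NewScanner(bytes.NewReader(stream))
	for sc.Scan() {
		line := strings.TrimPrefix(sc.Text(), "data: ")
		if line == sc.Text() || line == "[DONE]" {
			continue // not a data line, or the end-of-stream sentinel
		}
		var c chunk
		if err := json.Unmarshal([]byte(line), &c); err != nil {
			continue // tolerate malformed chunks; real code would log the error
		}
		if c.Usage != nil {
			last = c.Usage
		}
	}
	return last
}

func main() {
	body := []byte(`data: {"choices":[],"usage":null}
data: {"choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}
data: [DONE]`)
	fmt.Printf("%+v\n", usageFromStream(body)) // &{PromptTokens:5 CompletionTokens:7 TotalTokens:12}
}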

pkg/epp/handlers/response_test.go

Lines changed: 37 additions & 20 deletions
@@ -18,12 +18,12 @@ package handlers
 
 import (
 	"context"
-	"encoding/json"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
@@ -52,12 +52,33 @@ const (
 	}
 	`
 
-	streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null}
-	`
+	streamingBodyWithoutUsage = `
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
 
-	streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
-data: [DONE]
-	`
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null}
+
+data: [DONE]
+`
+
+	streamingBodyWithUsage = `
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
+
+data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}
+
+data: [DONE]
+`
 )
 
 type mockDirector struct{}
@@ -88,13 +109,13 @@ func TestHandleResponseBody(t *testing.T) {
 		name    string
 		body    []byte
 		reqCtx  *RequestContext
-		want    Usage
+		want    *types.Usage
 		wantErr bool
 	}{
 		{
 			name: "success",
 			body: []byte(body),
-			want: Usage{
+			want: &types.Usage{
 				PromptTokens:     11,
 				TotalTokens:      111,
 				CompletionTokens: 100,
@@ -110,12 +131,7 @@ func TestHandleResponseBody(t *testing.T) {
 			if reqCtx == nil {
 				reqCtx = &RequestContext{}
 			}
-			var responseMap map[string]any
-			marshalErr := json.Unmarshal(test.body, &responseMap)
-			if marshalErr != nil {
-				t.Error(marshalErr, "Error unmarshaling request body")
-			}
-			_, err := server.HandleResponseBody(ctx, reqCtx, responseMap)
+			_, err := server.HandleResponseBody(ctx, reqCtx, test.body)
 			if err != nil {
 				if !test.wantErr {
 					t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr)
@@ -136,7 +152,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
 		name    string
 		body    string
 		reqCtx  *RequestContext
-		want    Usage
+		want    *types.Usage
 		wantErr bool
 	}{
 		{
@@ -155,10 +171,10 @@ func TestHandleStreamedResponseBody(t *testing.T) {
 				modelServerStreaming: true,
 			},
 			wantErr: false,
-			want: Usage{
-				PromptTokens:     7,
-				TotalTokens:      17,
-				CompletionTokens: 10,
+			want: &types.Usage{
+				PromptTokens:     5,
+				TotalTokens:      12,
+				CompletionTokens: 7,
 			},
 		},
 	}
@@ -171,7 +187,8 @@ func TestHandleStreamedResponseBody(t *testing.T) {
 			if reqCtx == nil {
 				reqCtx = &RequestContext{}
 			}
-			server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body)
+			server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body))
+			server.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, []byte(test.body))
 
 			if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
 				t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
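Since both handlers now take raw bytes, these tests exercise the full parse path end to end. One plausible invocation from the repository root (the package path follows this repo's layout; the -run regex selects just the two tests touched here):

	go test ./pkg/epp/handlers/ -run 'TestHandleResponseBody|TestHandleStreamedResponseBody' -v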

pkg/epp/handlers/server.go

Lines changed: 34 additions & 37 deletions
@@ -85,14 +85,15 @@ type RequestContext struct {
 	RequestReceivedTimestamp  time.Time
 	ResponseCompleteTimestamp time.Time
 	RequestSize               int
-	Usage                     Usage
+	Usage                     *schedulingtypes.Usage
 	ResponseSize              int
 	ResponseComplete          bool
 	ResponseStatusCode        string
 	RequestRunning            bool
 	Request                   *Request
 
-	SchedulingRequest *schedulingtypes.LLMRequest
+	SchedulingRequest  *schedulingtypes.LLMRequest
+	SchedulingResponse *schedulingtypes.LLMResponse
 
 	RequestState         StreamRequestState
 	modelServerStreaming bool
@@ -115,7 +116,6 @@ type Request struct {
 }
 type Response struct {
 	Headers map[string]string
-	Body    []byte
 }
 type StreamRequestState int
 
@@ -268,53 +268,50 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 		reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)
 
 	case *extProcPb.ProcessingRequest_ResponseBody:
+		body = append(body, v.ResponseBody.Body...)
 		if reqCtx.modelServerStreaming {
 			// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
-
-			responseText := string(v.ResponseBody.Body)
-			s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText)
+			s.HandleResponseBodyModelStreaming(ctx, reqCtx, v.ResponseBody.Body)
 			if v.ResponseBody.EndOfStream {
 				loggerTrace.Info("stream completed")
+				s.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, body)
 
 				reqCtx.ResponseCompleteTimestamp = time.Now()
 				metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
 				metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
 			}
 
 			reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream)
-		} else {
-			body = append(body, v.ResponseBody.Body...)
-
-			// Message is buffered, we can read and decode.
-			if v.ResponseBody.EndOfStream {
-				loggerTrace.Info("stream completed")
-				// Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes.
-				// We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message.
-				// Using the standard 'err' var will send an immediate error response back to the caller.
-				var responseErr error
-				responseErr = json.Unmarshal(body, &responseBody)
-				if responseErr != nil {
-					if logger.V(logutil.DEBUG).Enabled() {
-						logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-					} else {
-						logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body))
-					}
-					reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
-					break
+		} else if v.ResponseBody.EndOfStream {
+			loggerTrace.Info("stream completed")
+			// Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes.
+			// We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message.
+			// Using the standard 'err' var will send an immediate error response back to the caller.
+			var responseErr error
+			responseErr = json.Unmarshal(body, &responseBody)
+			if responseErr != nil {
+				if logger.V(logutil.DEBUG).Enabled() {
+					logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body))
+				} else {
+					logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body))
 				}
+				reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
+				break
+			}
 
-				reqCtx.Response.Body = body
-				reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
-				if responseErr != nil {
-					if logger.V(logutil.DEBUG).Enabled() {
-						logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
-					} else {
-						logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
-					}
-				} else if reqCtx.ResponseComplete {
-					reqCtx.ResponseCompleteTimestamp = time.Now()
-					metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
-					metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
+			reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body)
+			if responseErr != nil {
+				if logger.V(logutil.DEBUG).Enabled() {
+					logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)
+				} else {
+					logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body")
+				}
+			} else if reqCtx.ResponseComplete {
+				reqCtx.ResponseCompleteTimestamp = time.Now()
+				metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
+				metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize)
+				if reqCtx.Usage != nil {
+					// Response complete does not guarantee the Usage is populated.
 					metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.PromptTokens)
 					metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.CompletionTokens)
 				}
pkg/epp/requestcontrol/director.go

Lines changed: 4 additions & 4 deletions
@@ -20,6 +20,7 @@ package requestcontrol
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"math/rand"
 	"net"
@@ -292,12 +293,11 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
 	requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
 	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
 	logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
-	llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
-	if err != nil {
-		logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
+	if reqCtx.SchedulingResponse == nil {
+		err := errors.New("nil scheduling response from reqCtx")
 		return reqCtx, err
 	}
-	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
+	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod)
 
 	logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
 	return reqCtx, nil
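With this change the director no longer re-parses the response body itself; it assumes that HandleResponseBody or HandleResponseBodyModelStreamingComplete has already populated reqCtx.SchedulingResponse, and fails fast otherwise. A tiny sketch of that implied ordering contract (hypothetical types, not the project's API):

package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins for the scheduling response and request context.
type llmResponse struct{}
type requestContext struct{ SchedulingResponse *llmResponse }

// handleComplete mirrors the new guard: return an error instead of re-parsing.
func handleComplete(reqCtx *requestContext) error {
	if reqCtx.SchedulingResponse == nil {
		return errors.New("nil scheduling response from reqCtx")
	}
	// ... run response-complete plugins with reqCtx.SchedulingResponse ...
	return nil
}

func main() {
	fmt.Println(handleComplete(&requestContext{}))                                   // error: response never parsed upstream
	fmt.Println(handleComplete(&requestContext{SchedulingResponse: &llmResponse{}})) // <nil>
}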
