Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 48 additions & 44 deletions core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,55 +265,52 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
lastEmittedCount = len(partialResults)
}
} else {
// Try JSON tool call parsing for streaming
// Check if the result looks like JSON tool calls
// Try JSON tool call parsing for streaming.
// Only emit NEW tool calls (same guard as XML parser above).
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true)
if jsonErr == nil && len(jsonResults) > 0 {
// Check if these are tool calls (have "name" and optionally "arguments")
for _, jsonObj := range jsonResults {
if name, ok := jsonObj["name"].(string); ok && name != "" {
// This looks like a tool call
args := "{}"
if argsVal, ok := jsonObj["arguments"]; ok {
if argsStr, ok := argsVal.(string); ok {
args = argsStr
} else {
argsBytes, _ := json.Marshal(argsVal)
args = string(argsBytes)
}
if jsonErr == nil && len(jsonResults) > lastEmittedCount {
for i := lastEmittedCount; i < len(jsonResults); i++ {
jsonObj := jsonResults[i]
name, ok := jsonObj["name"].(string)
if !ok || name == "" {
continue
}
args := "{}"
if argsVal, ok := jsonObj["arguments"]; ok {
if argsStr, ok := argsVal.(string); ok {
args = argsStr
} else {
argsBytes, _ := json.Marshal(argsVal)
args = string(argsBytes)
}
// Emit tool call
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: lastEmittedCount,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
}
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
},
Index: 0,
FinishReason: nil,
}},
Object: "chat.completion.chunk",
}
select {
case responses <- initialMessage:
default:
}
lastEmittedCount++
},
Index: 0,
FinishReason: nil,
}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
}
lastEmittedCount = len(jsonResults)
}
}
return true
Expand Down Expand Up @@ -426,10 +423,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
toolCallID = id
}

if i < lastEmittedCount {
// Already emitted during streaming by the incremental
// JSON/XML parser — skip to avoid duplicate tool calls.
continue
}

// Tool call not yet emitted — send name + args (two chunks).
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
Expand All @@ -454,7 +458,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Model: req.Model,
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
Expand Down
41 changes: 41 additions & 0 deletions tests/e2e/mock-backend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ func (m *MockBackend) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.R
}, nil
}

// Simulate multiple tool calls in a single response (Go-side JSON parser path).
if strings.Contains(in.Prompt, "MULTI_TOOL_CALL") {
return &pb.Reply{
Message: []byte(`{"name": "get_weather", "arguments": {"location": "Rome"}}
{"name": "get_weather", "arguments": {"location": "Paris"}}`),
Tokens: 30,
PromptTokens: 10,
}, nil
}
var response string
toolName := mockToolNameFromRequest(in)
if toolName != "" && !promptHasToolResults(in.Prompt) {
Expand Down Expand Up @@ -219,6 +228,38 @@ func (m *MockBackend) PredictStream(in *pb.PredictOptions, stream pb.Backend_Pre
return nil
}

// Simulate tool calls streamed as whole JSON objects (Go-side parser path).
// Each object is sent as a complete chunk so the incremental parser can
// detect tool calls mid-stream (unlike char-by-char which only parses after
// streaming completes).
if strings.Contains(in.Prompt, "MULTI_TOOL_CALL") {
chunks := []string{
`{"name": "get_weather", "arguments": {"location": "Rome"}}`,
"\n",
`{"name": "get_weather", "arguments": {"location": "Paris"}}`,
}
for i, chunk := range chunks {
if err := stream.Send(&pb.Reply{
Message: []byte(chunk),
Tokens: int32(i + 1),
}); err != nil {
return err
}
}
return nil
}

// Simulate single tool call streamed as whole JSON (Go-side parser path).
if strings.Contains(in.Prompt, "SINGLE_TOOL_CALL") {
if err := stream.Send(&pb.Reply{
Message: []byte(`{"name": "get_weather", "arguments": {"location": "San Francisco"}}`),
Tokens: 1,
}); err != nil {
return err
}
return nil
}

var toStream string
toolName := mockToolNameFromRequest(in)
if toolName != "" && !promptHasToolResults(in.Prompt) {
Expand Down
134 changes: 134 additions & 0 deletions tests/e2e/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,4 +509,138 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() {
})
})
})

// Tests for duplicate tool call emissions during streaming.
// The Go-side incremental JSON parser was emitting the same tool call on
// every streaming token, and the post-streaming default: case re-emitted
// all tool calls again, producing massive duplication.
Describe("Streaming Tool Call Deduplication", Label("ToolDedup"), func() {
// Helper: parse SSE lines and count tool call name/arguments chunks
parseToolCallChunks := func(data []byte) (nameChunks int, argChunks int) {
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
var chunk map[string]any
if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
continue
}
choices, _ := chunk["choices"].([]any)
if len(choices) == 0 {
continue
}
delta, _ := choices[0].(map[string]any)["delta"].(map[string]any)
if delta == nil {
continue
}
toolCalls, _ := delta["tool_calls"].([]any)
for _, tc := range toolCalls {
tcMap, _ := tc.(map[string]any)
fn, _ := tcMap["function"].(map[string]any)
if fn == nil {
continue
}
if name, _ := fn["name"].(string); name != "" {
nameChunks++
}
if args, _ := fn["arguments"].(string); args != "" {
argChunks++
}
}
}
return
}

Context("Single tool call via Go-side JSON parser", func() {
It("should emit exactly one tool call name without duplicates", func() {
body := `{
"model": "mock-model-autoparser",
"messages": [{"role": "user", "content": "SINGLE_TOOL_CALL"}],
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}],
"stream": true
}`
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")

httpClient := &http.Client{Timeout: 60 * time.Second}
resp, err := httpClient.Do(req)
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(Equal(200))

data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())

nameChunks, argChunks := parseToolCallChunks(data)

Expect(nameChunks).To(Equal(1),
"Expected exactly 1 tool call name chunk, got %d. Full SSE:\n%s",
nameChunks, string(data))
Expect(argChunks).To(BeNumerically(">=", 1),
"Expected at least 1 arguments chunk. Full SSE:\n%s", string(data))
})
})

Context("ChatDelta tool calls (regression guard)", func() {
It("should emit exactly one tool call name per tool", func() {
body := `{
"model": "mock-model-autoparser",
"messages": [{"role": "user", "content": "AUTOPARSER_TOOL_CALL"}],
"tools": [{"type": "function", "function": {"name": "search_collections", "description": "Search documents", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}],
"stream": true
}`
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")

httpClient := &http.Client{Timeout: 60 * time.Second}
resp, err := httpClient.Do(req)
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(Equal(200))

data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())

nameChunks, _ := parseToolCallChunks(data)

Expect(nameChunks).To(Equal(1),
"Expected exactly 1 tool call name chunk from ChatDeltas, got %d. Full SSE:\n%s",
nameChunks, string(data))
})
})

Context("Multiple tool calls via Go-side JSON parser", func() {
It("should emit exactly two tool call names without duplicates", func() {
body := `{
"model": "mock-model-autoparser",
"messages": [{"role": "user", "content": "MULTI_TOOL_CALL"}],
"tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}],
"stream": true
}`
req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body))
Expect(err).ToNot(HaveOccurred())
req.Header.Set("Content-Type", "application/json")

httpClient := &http.Client{Timeout: 60 * time.Second}
resp, err := httpClient.Do(req)
Expect(err).ToNot(HaveOccurred())
defer resp.Body.Close()
Expect(resp.StatusCode).To(Equal(200))

data, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred())

nameChunks, argChunks := parseToolCallChunks(data)

Expect(nameChunks).To(Equal(2),
"Expected exactly 2 tool call name chunks (one per tool), got %d. Full SSE:\n%s",
nameChunks, string(data))
Expect(argChunks).To(BeNumerically(">=", 2),
"Expected at least 2 arguments chunks. Full SSE:\n%s", string(data))
})
})
})
})
Loading