diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index b2ec1ac44352..ab715d8b166d 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -265,55 +265,52 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator lastEmittedCount = len(partialResults) } } else { - // Try JSON tool call parsing for streaming - // Check if the result looks like JSON tool calls + // Try JSON tool call parsing for streaming. + // Only emit NEW tool calls (same guard as XML parser above). jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) - if jsonErr == nil && len(jsonResults) > 0 { - // Check if these are tool calls (have "name" and optionally "arguments") - for _, jsonObj := range jsonResults { - if name, ok := jsonObj["name"].(string); ok && name != "" { - // This looks like a tool call - args := "{}" - if argsVal, ok := jsonObj["arguments"]; ok { - if argsStr, ok := argsVal.(string); ok { - args = argsStr - } else { - argsBytes, _ := json.Marshal(argsVal) - args = string(argsBytes) - } + if jsonErr == nil && len(jsonResults) > lastEmittedCount { + for i := lastEmittedCount; i < len(jsonResults); i++ { + jsonObj := jsonResults[i] + name, ok := jsonObj["name"].(string) + if !ok || name == "" { + continue + } + args := "{}" + if argsVal, ok := jsonObj["arguments"]; ok { + if argsStr, ok := argsVal.(string); ok { + args = argsStr + } else { + argsBytes, _ := json.Marshal(argsVal) + args = string(argsBytes) } - // Emit tool call - initialMessage := schema.OpenAIResponse{ - ID: id, - Created: created, - Model: req.Model, - Choices: []schema.Choice{{ - Delta: &schema.Message{ - Role: "assistant", - ToolCalls: []schema.ToolCall{ - { - Index: lastEmittedCount, - ID: id, - Type: "function", - FunctionCall: schema.FunctionCall{ - Name: name, - Arguments: args, - }, + } + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, }, }, }, - Index: 0, - FinishReason: nil, - }}, - Object: "chat.completion.chunk", - } - select { - case responses <- initialMessage: - default: - } - lastEmittedCount++ + }, + Index: 0, + FinishReason: nil, + }}, + Object: "chat.completion.chunk", } + responses <- initialMessage } + lastEmittedCount = len(jsonResults) } } return true @@ -426,10 +423,17 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator toolCallID = id } + if i < lastEmittedCount { + // Already emitted during streaming by the incremental + // JSON/XML parser — skip to avoid duplicate tool calls. + continue + } + + // Tool call not yet emitted — send name + args (two chunks). initialMessage := schema.OpenAIResponse{ ID: id, Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", @@ -454,7 +458,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator responses <- schema.OpenAIResponse{ ID: id, Created: created, - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go index 469795e25bab..c9e096f42a33 100644 --- a/tests/e2e/mock-backend/main.go +++ b/tests/e2e/mock-backend/main.go @@ -115,6 +115,15 @@ func (m *MockBackend) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.R }, nil } + // Simulate multiple tool calls in a single response (Go-side JSON parser path). + if strings.Contains(in.Prompt, "MULTI_TOOL_CALL") { + return &pb.Reply{ + Message: []byte(`{"name": "get_weather", "arguments": {"location": "Rome"}} +{"name": "get_weather", "arguments": {"location": "Paris"}}`), + Tokens: 30, + PromptTokens: 10, + }, nil + } var response string toolName := mockToolNameFromRequest(in) if toolName != "" && !promptHasToolResults(in.Prompt) { @@ -219,6 +228,38 @@ func (m *MockBackend) PredictStream(in *pb.PredictOptions, stream pb.Backend_Pre return nil } + // Simulate tool calls streamed as whole JSON objects (Go-side parser path). + // Each object is sent as a complete chunk so the incremental parser can + // detect tool calls mid-stream (unlike char-by-char which only parses after + // streaming completes). + if strings.Contains(in.Prompt, "MULTI_TOOL_CALL") { + chunks := []string{ + `{"name": "get_weather", "arguments": {"location": "Rome"}}`, + "\n", + `{"name": "get_weather", "arguments": {"location": "Paris"}}`, + } + for i, chunk := range chunks { + if err := stream.Send(&pb.Reply{ + Message: []byte(chunk), + Tokens: int32(i + 1), + }); err != nil { + return err + } + } + return nil + } + + // Simulate single tool call streamed as whole JSON (Go-side parser path). + if strings.Contains(in.Prompt, "SINGLE_TOOL_CALL") { + if err := stream.Send(&pb.Reply{ + Message: []byte(`{"name": "get_weather", "arguments": {"location": "San Francisco"}}`), + Tokens: 1, + }); err != nil { + return err + } + return nil + } + var toStream string toolName := mockToolNameFromRequest(in) if toolName != "" && !promptHasToolResults(in.Prompt) { diff --git a/tests/e2e/mock_backend_test.go b/tests/e2e/mock_backend_test.go index d45301ef047c..9e7cecf88c36 100644 --- a/tests/e2e/mock_backend_test.go +++ b/tests/e2e/mock_backend_test.go @@ -509,4 +509,138 @@ var _ = Describe("Mock Backend E2E Tests", Label("MockBackend"), func() { }) }) }) + + // Tests for duplicate tool call emissions during streaming. + // The Go-side incremental JSON parser was emitting the same tool call on + // every streaming token, and the post-streaming default: case re-emitted + // all tool calls again, producing massive duplication. + Describe("Streaming Tool Call Deduplication", Label("ToolDedup"), func() { + // Helper: parse SSE lines and count tool call name/arguments chunks + parseToolCallChunks := func(data []byte) (nameChunks int, argChunks int) { + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + var chunk map[string]any + if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil { + continue + } + choices, _ := chunk["choices"].([]any) + if len(choices) == 0 { + continue + } + delta, _ := choices[0].(map[string]any)["delta"].(map[string]any) + if delta == nil { + continue + } + toolCalls, _ := delta["tool_calls"].([]any) + for _, tc := range toolCalls { + tcMap, _ := tc.(map[string]any) + fn, _ := tcMap["function"].(map[string]any) + if fn == nil { + continue + } + if name, _ := fn["name"].(string); name != "" { + nameChunks++ + } + if args, _ := fn["arguments"].(string); args != "" { + argChunks++ + } + } + } + return + } + + Context("Single tool call via Go-side JSON parser", func() { + It("should emit exactly one tool call name without duplicates", func() { + body := `{ + "model": "mock-model-autoparser", + "messages": [{"role": "user", "content": "SINGLE_TOOL_CALL"}], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}], + "stream": true + }` + req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + + nameChunks, argChunks := parseToolCallChunks(data) + + Expect(nameChunks).To(Equal(1), + "Expected exactly 1 tool call name chunk, got %d. Full SSE:\n%s", + nameChunks, string(data)) + Expect(argChunks).To(BeNumerically(">=", 1), + "Expected at least 1 arguments chunk. Full SSE:\n%s", string(data)) + }) + }) + + Context("ChatDelta tool calls (regression guard)", func() { + It("should emit exactly one tool call name per tool", func() { + body := `{ + "model": "mock-model-autoparser", + "messages": [{"role": "user", "content": "AUTOPARSER_TOOL_CALL"}], + "tools": [{"type": "function", "function": {"name": "search_collections", "description": "Search documents", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}}}], + "stream": true + }` + req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + + nameChunks, _ := parseToolCallChunks(data) + + Expect(nameChunks).To(Equal(1), + "Expected exactly 1 tool call name chunk from ChatDeltas, got %d. Full SSE:\n%s", + nameChunks, string(data)) + }) + }) + + Context("Multiple tool calls via Go-side JSON parser", func() { + It("should emit exactly two tool call names without duplicates", func() { + body := `{ + "model": "mock-model-autoparser", + "messages": [{"role": "user", "content": "MULTI_TOOL_CALL"}], + "tools": [{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}}}], + "stream": true + }` + req, err := http.NewRequest("POST", apiURL+"/chat/completions", strings.NewReader(body)) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + httpClient := &http.Client{Timeout: 60 * time.Second} + resp, err := httpClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + defer resp.Body.Close() + Expect(resp.StatusCode).To(Equal(200)) + + data, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + + nameChunks, argChunks := parseToolCallChunks(data) + + Expect(nameChunks).To(Equal(2), + "Expected exactly 2 tool call name chunks (one per tool), got %d. Full SSE:\n%s", + nameChunks, string(data)) + Expect(argChunks).To(BeNumerically(">=", 2), + "Expected at least 2 arguments chunks. Full SSE:\n%s", string(data)) + }) + }) + }) })