From 427e7633952fb132885ad53dca73915befae93c4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 11 Feb 2025 09:56:17 +0100 Subject: [PATCH] Small refactoring --- haystack/components/generators/chat/openai.py | 63 ++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 058084ec9d..db430d2e06 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -344,46 +344,37 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str :param chunk: The last chunk returned by the OpenAI API. :param chunks: The list of all `StreamingChunk` objects. """ - text = "".join([chunk.content for chunk in chunks]) tool_calls = [] - # are there any tool calls in the chunks? - if any(chunk.meta.get("tool_calls") for chunk in chunks): - payloads = {} # Use a dict to track tool calls by ID - for chunk_payload in chunks: - deltas = chunk_payload.meta.get("tool_calls") or [] - - # deltas is a list of ChoiceDeltaToolCall - for delta in deltas: - if delta.id not in payloads: - payloads[delta.id] = {"id": delta.id, "arguments": "", "name": "", "type": None} - # ChoiceDeltaToolCall has a 'function' field of type ChoiceDeltaToolCallFunction + # Process tool calls if present in any chunk + tool_call_data = {} # Track tool calls by ID + for chunk_payload in chunks: + tool_calls_meta = chunk_payload.meta.get("tool_calls") + if tool_calls_meta: + for delta in tool_calls_meta: + if not delta.id in tool_call_data: + tool_call_data[delta.id] = {"id": delta.id, "name": "", "arguments": ""} + if delta.function: - # For tool calls with the same ID, use the latest values - if delta.function.name is not None: - payloads[delta.id]["name"] = delta.function.name - if delta.function.arguments is not None: - # Use the latest arguments value - payloads[delta.id]["arguments"] = delta.function.arguments - if delta.type is not None: - payloads[delta.id]["type"] = delta.type - - for payload in payloads.values(): - arguments_str = payload["arguments"] - try: - # Try to parse the concatenated arguments string as JSON - arguments = json.loads(arguments_str) - tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments)) - except json.JSONDecodeError: - logger.warning( - "OpenAI returned a malformed JSON string for tool call arguments. This tool call " - "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. " - "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", - _id=payload["id"], - _name=payload["name"], - _arguments=arguments_str, - ) + if delta.function.name: + tool_call_data[delta.id]["name"] = delta.function.name + if delta.function.arguments: + tool_call_data[delta.id]["arguments"] = delta.function.arguments + + # Convert accumulated tool call data into ToolCall objects + for call_data in tool_call_data.values(): + try: + arguments = json.loads(call_data["arguments"]) + tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments)) + except json.JSONDecodeError: + logger.warning( + "Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. " + "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", + _id=call_data["id"], + _name=call_data["name"], + _arguments=call_data["arguments"], + ) meta = { "model": chunk.model,