Skip to content

Commit

Permalink
Small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Feb 11, 2025
1 parent 973ef29 commit 427e763
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,46 +344,37 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str
:param chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all `StreamingChunk` objects.
"""

text = "".join([chunk.content for chunk in chunks])
tool_calls = []

# are there any tool calls in the chunks?
if any(chunk.meta.get("tool_calls") for chunk in chunks):
payloads = {} # Use a dict to track tool calls by ID
for chunk_payload in chunks:
deltas = chunk_payload.meta.get("tool_calls") or []

# deltas is a list of ChoiceDeltaToolCall
for delta in deltas:
if delta.id not in payloads:
payloads[delta.id] = {"id": delta.id, "arguments": "", "name": "", "type": None}
# ChoiceDeltaToolCall has a 'function' field of type ChoiceDeltaToolCallFunction
# Process tool calls if present in any chunk
tool_call_data = {} # Track tool calls by ID
for chunk_payload in chunks:
tool_calls_meta = chunk_payload.meta.get("tool_calls")
if tool_calls_meta:
for delta in tool_calls_meta:
if not delta.id in tool_call_data:
tool_call_data[delta.id] = {"id": delta.id, "name": "", "arguments": ""}

if delta.function:
# For tool calls with the same ID, use the latest values
if delta.function.name is not None:
payloads[delta.id]["name"] = delta.function.name
if delta.function.arguments is not None:
# Use the latest arguments value
payloads[delta.id]["arguments"] = delta.function.arguments
if delta.type is not None:
payloads[delta.id]["type"] = delta.type

for payload in payloads.values():
arguments_str = payload["arguments"]
try:
# Try to parse the concatenated arguments string as JSON
arguments = json.loads(arguments_str)
tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
except json.JSONDecodeError:
logger.warning(
"OpenAI returned a malformed JSON string for tool call arguments. This tool call "
"will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
_id=payload["id"],
_name=payload["name"],
_arguments=arguments_str,
)
if delta.function.name:
tool_call_data[delta.id]["name"] = delta.function.name
if delta.function.arguments:
tool_call_data[delta.id]["arguments"] = delta.function.arguments

# Convert accumulated tool call data into ToolCall objects
for call_data in tool_call_data.values():
try:
arguments = json.loads(call_data["arguments"])
tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
except json.JSONDecodeError:
logger.warning(
"Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. "
"Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
_id=call_data["id"],
_name=call_data["name"],
_arguments=call_data["arguments"],
)

meta = {
"model": chunk.model,
Expand Down

0 comments on commit 427e763

Please sign in to comment.