From abf1d873cbf601976a8670d692d0203cac54f32e Mon Sep 17 00:00:00 2001
From: Bohan Qu
Date: Mon, 27 Jan 2025 21:58:16 +0800
Subject: [PATCH] feat: support deepseek-reasoner

---
 haystack/components/generators/chat/openai.py | 21 ++++++++++++++++---
 haystack/dataclasses/chat_message.py          | 15 ++++++++++++-
 haystack/dataclasses/streaming_chunk.py       |  3 ++-
 3 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index b30de1b43d..08197d42d4 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -348,6 +348,14 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str
         text = "".join([chunk.content for chunk in chunks])
         tool_calls = []
 
+        reasoning_content = None
+        for streaming_chunk in chunks:
+            if streaming_chunk.reasoning_content is not None:
+                if reasoning_content is not None:
+                    reasoning_content += streaming_chunk.reasoning_content
+                else:
+                    reasoning_content = streaming_chunk.reasoning_content
+
         # if it's a tool call , we need to build the payload dict from all the chunks
         if bool(chunks[0].meta.get("tool_calls")):
             tools_len = len(chunks[0].meta.get("tool_calls", []))
@@ -386,7 +394,12 @@ def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[Str
             "usage": {},  # we don't have usage data for streaming responses
         }
 
-        return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
+        return ChatMessage.from_assistant(
+            text=text,
+            tool_calls=tool_calls,
+            reasoning_content=reasoning_content,
+            meta=meta
+        )
 
     def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
         """
@@ -398,6 +411,7 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c
         """
         message: ChatCompletionMessage = choice.message
         text = message.content
+        reasoning_content = getattr(message, "reasoning_content", None)
         tool_calls = []
         if openai_tool_calls := message.tool_calls:
             for openai_tc in openai_tool_calls:
@@ -415,7 +429,7 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c
                     _arguments=arguments_str,
                 )
 
-        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
+        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, reasoning_content=reasoning_content)
         chat_message._meta.update(
             {
                 "model": completion.model,
@@ -437,7 +451,8 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
         # we stream the content of the chunk if it's not a tool or function call
         choice: ChunkChoice = chunk.choices[0]
         content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
+        reasoning_content = getattr(choice.delta, "reasoning_content", None)
+        chunk_message = StreamingChunk(content, reasoning_content=reasoning_content)
         # but save the tool calls and function call in the meta if they are present
         # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
         chunk_message.meta.update(
diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py
index 925259359f..70922ef847 100644
--- a/haystack/dataclasses/chat_message.py
+++ b/haystack/dataclasses/chat_message.py
@@ -94,6 +94,7 @@ class ChatMessage:
 
     _role: ChatRole
     _content: Sequence[ChatMessageContentT]
+    _reasoning_content: Optional[str] = None
     _name: Optional[str] = None
     _meta: Dict[str, Any] = field(default_factory=dict, hash=False)
 
@@ -211,6 +212,10 @@ def tool_call_result(self) -> Optional[ToolCallResult]:
             return tool_call_results[0]
         return None
 
+    @property
+    def reasoning_content(self) -> Optional[str]:
+        return self._reasoning_content
+
     def is_from(self, role: Union[ChatRole, str]) -> bool:
         """
         Check if the message is from a specific role.
@@ -253,6 +258,7 @@ def from_assistant(
         meta: Optional[Dict[str, Any]] = None,
         name: Optional[str] = None,
         tool_calls: Optional[List[ToolCall]] = None,
+        reasoning_content: Optional[str] = None,
     ) -> "ChatMessage":
         """
         Create a message from the assistant.
@@ -269,7 +275,13 @@ def from_assistant(
         if tool_calls:
             content.extend(tool_calls)
 
-        return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name)
+        return cls(
+            _role=ChatRole.ASSISTANT,
+            _content=content,
+            _meta=meta or {},
+            _name=name,
+            _reasoning_content=reasoning_content
+        )
 
     @classmethod
     def from_tool(
@@ -301,6 +313,7 @@ def to_dict(self) -> Dict[str, Any]:
         serialized["_role"] = self._role.value
         serialized["_meta"] = self._meta
         serialized["_name"] = self._name
+        serialized["_reasoning_content"] = self._reasoning_content
         content: List[Dict[str, Any]] = []
         for part in self._content:
             if isinstance(part, TextContent):
diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py
index 455caa09e2..a55eea912a 100644
--- a/haystack/dataclasses/streaming_chunk.py
+++ b/haystack/dataclasses/streaming_chunk.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 
 @dataclass
@@ -18,4 +18,5 @@ class StreamingChunk:
     """
 
     content: str
+    reasoning_content: Optional[str] = None
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
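
Note for reviewers: a minimal, self-contained sketch (not part of the patch) of the dataclass surface this change adds. It uses only names introduced above and should hold once the patch applies.

from haystack.dataclasses import ChatMessage, StreamingChunk

# from_assistant() now threads reasoning_content into the message,
# and the new read-only property exposes it.
msg = ChatMessage.from_assistant(
    text="The answer is 408.",
    reasoning_content="17 * 24 = 17 * 20 + 17 * 4 = 340 + 68 = 408.",
)
assert msg.reasoning_content.startswith("17 * 24")

# to_dict() now serializes the field alongside _role, _meta, and _name.
assert msg.to_dict()["_reasoning_content"] == msg.reasoning_content

# Streaming chunks carry per-delta reasoning the same way; the field
# defaults to None for models that do not emit it.
chunk = StreamingChunk("partial answer", reasoning_content="partial thought")
assert chunk.reasoning_content == "partial thought"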
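
And an end-to-end usage sketch under stated assumptions: the base URL (https://api.deepseek.com), the model name deepseek-reasoner, and the DEEPSEEK_API_KEY environment variable come from DeepSeek's OpenAI-compatible API documentation, not from this patch; the api_key, api_base_url, and model parameters of OpenAIChatGenerator already exist in Haystack.

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

# Point the existing OpenAI generator at DeepSeek to exercise the new
# reasoning_content path.
generator = OpenAIChatGenerator(
    api_key=Secret.from_env_var("DEEPSEEK_API_KEY"),  # assumed env var
    api_base_url="https://api.deepseek.com",
    model="deepseek-reasoner",
)

result = generator.run(messages=[ChatMessage.from_user("What is 17 * 24?")])
reply = result["replies"][0]

print(reply.text)               # the final answer
print(reply.reasoning_content)  # the chain of thought, or None if absent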