From caffbf966e136a686cb3178fb9694c18c3a63bab Mon Sep 17 00:00:00 2001
From: karboom
Date: Tue, 23 Sep 2025 15:53:22 +0800
Subject: [PATCH] feat: add stream hook for ChatAgent.step() and
 ChatAgent.astep()

---
 camel/agents/chat_agent.py          | 41 +++++++++++++++++++++++++++++++++++++----
 examples/agents/chatagent_stream.py | 12 ++++++++++++
 test/agents/test_chat_agent.py      | 43 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py
index 899e429f0d..bed1a74814 100644
--- a/camel/agents/chat_agent.py
+++ b/camel/agents/chat_agent.py
@@ -403,6 +403,9 @@ class ChatAgent(BaseAgent):
             updates return accumulated content (current behavior). When
             False, partial updates return only the incremental delta.
             (default: :obj:`True`)
+        stream_hook (Optional[Callable], optional): Callback invoked for
+            each streamed chunk (:obj:`ChatAgentResponse`) and for the
+            final message (:obj:`BaseMessage`). (default: :obj:`None`)
     """
 
     def __init__(
@@ -447,6 +450,9 @@ def __init__(
         retry_delay: float = 1.0,
         step_timeout: Optional[float] = None,
         stream_accumulate: bool = True,
+        stream_hook: Optional[
+            Callable[[Union[ChatAgentResponse, BaseMessage]], None]
+        ] = None,
     ) -> None:
         if isinstance(model, ModelManager):
             self.model_backend = model
@@ -536,6 +542,7 @@ def __init__(
         self.retry_delay = max(0.0, retry_delay)
         self.step_timeout = step_timeout
         self.stream_accumulate = stream_accumulate
+        self.stream_hook = stream_hook
 
     def reset(self):
         r"""Resets the :obj:`ChatAgent` to its initial state."""
@@ -2463,13 +2470,21 @@ def _stream(
         try:
             openai_messages, num_tokens = self.memory.get_context()
         except RuntimeError as e:
-            yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+            error = self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+
+            if self.stream_hook is not None:
+                self.stream_hook(error)
+
+            yield error
             return
 
         # Start streaming response
-        yield from self._stream_response(
+        for response in self._stream_response(
             openai_messages, num_tokens, response_format
-        )
+        ):
+            if self.stream_hook is not None:
+                self.stream_hook(response)
+            yield response
 
     def _get_token_count(self, content: str) -> int:
         r"""Get token count for content with fallback."""
@@ -2615,6 +2630,8 @@ def _stream_response(
                     ),  # type: ignore[arg-type]
                 )
 
+                if self.stream_hook is not None:
+                    self.stream_hook(final_message)
                 self.record_message(final_message)
 
                 # Create final response
@@ -2743,6 +2760,9 @@ def _process_stream_chunks_with_accumulator(
                         final_message, response_format
                     )
 
+                    if self.stream_hook is not None:
+                        self.stream_hook(final_message)
+
                     self.record_message(final_message)
             elif chunk.usage and not chunk.choices:
                 # Handle final chunk with usage but empty choices
@@ -3204,7 +3224,12 @@ async def _astream(
         try:
             openai_messages, num_tokens = self.memory.get_context()
         except RuntimeError as e:
-            yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+            result = self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+
+            if self.stream_hook is not None:
+                self.stream_hook(result)
+            yield result
+
             return
 
         # Start async streaming response
@@ -3213,6 +3238,9 @@ async def _astream(
             openai_messages, num_tokens, response_format
         ):
             last_response = response
+
+            if self.stream_hook is not None:
+                self.stream_hook(response)
             yield response
 
         # Clean tool call messages from memory after response generation
@@ -3372,6 +3400,8 @@ async def _astream_response(
                     ),  # type: ignore[arg-type]
                 )
 
+                if self.stream_hook is not None:
+                    self.stream_hook(final_message)
                 self.record_message(final_message)
 
                 # Create final response
@@ -3540,6 +3570,9 @@ async def _aprocess_stream_chunks_with_accumulator(
                         final_message, response_format
                     )
 
+                    if self.stream_hook is not None:
+                        self.stream_hook(final_message)
+
                     self.record_message(final_message)
             elif chunk.usage and not chunk.choices:
                 # Handle final chunk with usage but empty choices
diff --git a/examples/agents/chatagent_stream.py b/examples/agents/chatagent_stream.py
index a8dbb2ab30..0247c89cb6 100644
--- a/examples/agents/chatagent_stream.py
+++ b/examples/agents/chatagent_stream.py
@@ -12,7 +12,9 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from camel.agents import ChatAgent
+from camel.messages import BaseMessage
 from camel.models import ModelFactory
+from camel.responses import ChatAgentResponse
 from camel.types import ModelPlatformType, ModelType
 
 # Create a streaming model
@@ -46,11 +48,21 @@
 
 print("\n\n---\nDelta streaming mode (stream_accumulate=False):\n")
 
+
+def stream_hook(message):
+    # e.g. push the message onto a queue for another consumer
+    if isinstance(message, ChatAgentResponse):
+        print("message chunk", message)
+    elif isinstance(message, BaseMessage):
+        print("message complete", message)
+
+
 # Create an agent that yields delta chunks instead of accumulated content
 agent_delta = ChatAgent(
     system_message="You are a helpful assistant that provides concise "
     "and informative responses.",
     model=streaming_model,
+    stream_hook=stream_hook,
     stream_accumulate=False,  # Only yield the delta part per chunk
 )
 
diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py
index 770c004323..ad4cda1070 100644
--- a/test/agents/test_chat_agent.py
+++ b/test/agents/test_chat_agent.py
@@ -36,6 +36,7 @@
 from camel.memories import MemoryRecord
 from camel.messages import BaseMessage
 from camel.models import AnthropicModel, ModelFactory, OpenAIModel
+from camel.responses import ChatAgentResponse
 from camel.terminators import ResponseWordsTerminator
 from camel.toolkits import (
     FunctionTool,
@@ -742,6 +743,48 @@ def test_chat_agent_stream_output(step_call_count=3):
             + stream_usage["prompt_tokens"]
         ), f"Error in calling round {i+1}"
 
+
+@pytest.mark.model_backend
+def test_chat_agent_stream_hook(step_call_count=3):
+    system_msg = BaseMessage(
+        "Assistant",
+        RoleType.ASSISTANT,
+        meta_dict=None,
+        content="You are a helpful assistant.",
+    )
+    user_msg = BaseMessage(
+        role_name="User",
+        role_type=RoleType.USER,
+        meta_dict=dict(),
+        content="Tell me a joke.",
+    )
+
+    stream_model_config = ChatGPTConfig(temperature=0, n=2, stream=True)
+    model = ModelFactory.create(
+        model_platform=ModelPlatformType.OPENAI,
+        model_type=ModelType.GPT_5_MINI,
+        model_config_dict=stream_model_config.as_dict(),
+    )
+    model.run = MagicMock(return_value=model_backend_rsp_base)
+
+    message_holder = []
+
+    def stream_hook(message):
+        if isinstance(message, ChatAgentResponse):
+            message_holder.append(message)
+
+    stream_assistant = ChatAgent(
+        system_msg, model=model, stream_hook=stream_hook
+    )
+    stream_assistant.reset()
+    for i in range(step_call_count):
+        stream_assistant_response = stream_assistant.step(user_msg)
+
+        assert (
+            stream_assistant_response.msgs[0] == message_holder[-1].msgs[0]
+        ), f"Error in calling round {i+1}"
+
+
 @pytest.mark.model_backend
 def test_chat_agent_stream_accumulate_mode_accumulated():
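
Note for reviewers: below is a minimal consumer-side sketch of the hook API this
patch introduces, handing streamed messages off to a queue.Queue drained by a
worker thread (the use case hinted at by the example's comment). It is not part
of the patch: the model type (GPT_4O_MINI) and config dict are placeholders for
any streaming-capable backend, and it assumes step() returns an iterable
streaming response that must be consumed, as in examples/agents/chatagent_stream.py.

    import queue
    import threading

    from camel.agents import ChatAgent
    from camel.messages import BaseMessage
    from camel.models import ModelFactory
    from camel.responses import ChatAgentResponse
    from camel.types import ModelPlatformType, ModelType

    message_queue: queue.Queue = queue.Queue()


    def stream_hook(message):
        # Per this patch, the hook receives a ChatAgentResponse for each
        # streamed chunk and a BaseMessage for the final recorded message.
        message_queue.put(message)


    def consumer():
        while True:
            message = message_queue.get()
            if message is None:  # sentinel: streaming finished
                break
            if isinstance(message, ChatAgentResponse) and message.msgs:
                print("chunk:", message.msgs[0].content)
            elif isinstance(message, BaseMessage):
                print("complete:", message.content)


    worker = threading.Thread(target=consumer)
    worker.start()

    # Placeholder model setup, mirroring the example script in this patch.
    streaming_model = ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=ModelType.GPT_4O_MINI,
        model_config_dict={"stream": True},
    )
    agent = ChatAgent(
        system_message="You are a helpful assistant.",
        model=streaming_model,
        stream_hook=stream_hook,
    )

    # Consuming the streaming response is what drives the hook calls.
    for _ in agent.step("Tell me a joke."):
        pass
    message_queue.put(None)  # unblock and stop the consumer
    worker.join()

Since the hook runs synchronously inside the streaming loop, it should stay
cheap; handing messages off to a queue keeps a slow consumer from stalling
generation.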