32 changes: 28 additions & 4 deletions camel/agents/chat_agent.py
@@ -403,6 +403,9 @@ class ChatAgent(BaseAgent):
updates return accumulated content (current behavior). When False,
partial updates return only the incremental delta. (default:
:obj:`True`)
stream_hook (Optional[Callable[[Union[ChatAgentResponse, BaseMessage]],
    None]], optional): Callback invoked for each streamed chunk
    response and for the final message once streaming completes.
    (default: :obj:`None`)
"""

def __init__(
@@ -447,6 +450,7 @@ def __init__(
retry_delay: float = 1.0,
step_timeout: Optional[float] = None,
stream_accumulate: bool = True,
stream_hook: Optional[
    Callable[[Union[ChatAgentResponse, BaseMessage]], None]
] = None,
) -> None:
if isinstance(model, ModelManager):
self.model_backend = model
@@ -536,6 +540,7 @@ def __init__(
self.retry_delay = max(0.0, retry_delay)
self.step_timeout = step_timeout
self.stream_accumulate = stream_accumulate
self.stream_hook = stream_hook

def reset(self):
r"""Resets the :obj:`ChatAgent` to its initial state."""
@@ -2463,13 +2468,19 @@ def _stream(
try:
openai_messages, num_tokens = self.memory.get_context()
except RuntimeError as e:
yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
error = self._step_terminate(e.args[1], [], "max_tokens_exceeded")

if self.stream_hook is not None:
    self.stream_hook(error)

yield error
return

# Start streaming response
yield from self._stream_response(
for response in self._stream_response(
openai_messages, num_tokens, response_format
)
):
if self.stream_hook is not None:
    self.stream_hook(response)
yield response

def _get_token_count(self, content: str) -> int:
r"""Get token count for content with fallback."""
@@ -2615,6 +2626,7 @@ def _stream_response(
), # type: ignore[arg-type]
)

if self.stream_hook is not None:
    self.stream_hook(final_message)
self.record_message(final_message)

# Create final response
@@ -2743,6 +2755,9 @@ def _process_stream_chunks_with_accumulator(
final_message, response_format
)

if self.stream_hook is not None:
self.stream_hook(final_message)

self.record_message(final_message)
elif chunk.usage and not chunk.choices:
# Handle final chunk with usage but empty choices
@@ -3204,7 +3219,11 @@ async def _astream(
try:
openai_messages, num_tokens = self.memory.get_context()
except RuntimeError as e:
yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
result = self._step_terminate(e.args[1], [], "max_tokens_exceeded")

if self.stream_hook is not None:
    self.stream_hook(result)
yield result

return

# Start async streaming response
@@ -3213,6 +3232,8 @@
openai_messages, num_tokens, response_format
):
last_response = response

if self.stream_hook is not None:
    self.stream_hook(response)
yield response

# Clean tool call messages from memory after response generation
@@ -3372,6 +3393,7 @@ async def _astream_response(
), # type: ignore[arg-type]
)

if self.stream_hook is not None:
    self.stream_hook(final_message)
self.record_message(final_message)

# Create final response
@@ -3540,6 +3562,8 @@ async def _aprocess_stream_chunks_with_accumulator(
final_message, response_format
)

if self.stream_hook is not None:
    self.stream_hook(final_message)

self.record_message(final_message)
elif chunk.usage and not chunk.choices:
# Handle final chunk with usage but empty choices
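A minimal sketch (not part of the diff) of how a caller might wire the new stream_hook parameter, assuming the signature shown above: the callback receives ChatAgentResponse objects for intermediate chunks and a BaseMessage for the final recorded message. The queue-based handler, the push_to_queue name, and the omitted streaming model setup are illustrative assumptions only.

import queue

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.responses import ChatAgentResponse

# Thread-safe buffer that a UI thread or other worker could drain.
chunk_queue: queue.Queue = queue.Queue()


def push_to_queue(message):
    # Intermediate streamed chunks arrive as ChatAgentResponse objects.
    if isinstance(message, ChatAgentResponse) and message.msgs:
        chunk_queue.put(("chunk", message.msgs[0].content))
    # The final message recorded to memory arrives as a BaseMessage.
    elif isinstance(message, BaseMessage):
        chunk_queue.put(("final", message.content))


# A streaming-capable model (as configured in the example below) is assumed.
agent = ChatAgent(
    system_message="You are a helpful assistant.",
    stream_hook=push_to_queue,
)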
10 changes: 10 additions & 0 deletions examples/agents/chatagent_stream.py
@@ -12,7 +12,9 @@
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.responses import ChatAgentResponse
from camel.types import ModelPlatformType, ModelType

# Create a streaming model
@@ -46,11 +48,19 @@

print("\n\n---\nDelta streaming mode (stream_accumulate=False):\n")

def stream_hook(message):
    # Handle each streamed message here, e.g. push it onto a queue for a UI.
    if isinstance(message, ChatAgentResponse):
        print("message chunk", message)
    elif isinstance(message, BaseMessage):
        print("message complete", message)

# Create an agent that yields delta chunks instead of accumulated content
agent_delta = ChatAgent(
system_message="You are a helpful assistant that provides concise "
"and informative responses.",
model=streaming_model,
stream_hook=stream_hook,
stream_accumulate=False, # Only yield the delta part per chunk
)

37 changes: 37 additions & 0 deletions test/agents/test_chat_agent.py
@@ -36,6 +36,7 @@
from camel.memories import MemoryRecord
from camel.messages import BaseMessage
from camel.models import AnthropicModel, ModelFactory, OpenAIModel
from camel.responses import ChatAgentResponse
from camel.terminators import ResponseWordsTerminator
from camel.toolkits import (
FunctionTool,
@@ -742,6 +743,42 @@ def test_chat_agent_stream_output(step_call_count=3):
+ stream_usage["prompt_tokens"]
), f"Error in calling round {i+1}"

@pytest.mark.model_backend
def test_chat_agent_stream_hook(step_call_count=3):
system_msg = BaseMessage(
"Assistant",
RoleType.ASSISTANT,
meta_dict=None,
content="You are a helpful assistant.",
)
user_msg = BaseMessage(
role_name="User",
role_type=RoleType.USER,
meta_dict=dict(),
content="Tell me a joke.",
)

stream_model_config = ChatGPTConfig(temperature=0, n=2, stream=True)
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_5_MINI,
model_config_dict=stream_model_config.as_dict(),
)
model.run = MagicMock(return_value=model_backend_rsp_base)

message_holder = []

def stream_hook(message):
    if isinstance(message, ChatAgentResponse):
        message_holder.append(message)

stream_assistant = ChatAgent(
    system_msg, model=model, stream_hook=stream_hook
)
stream_assistant.reset()
for i in range(step_call_count):
stream_assistant_response = stream_assistant.step(user_msg)

assert (
    stream_assistant_response.msgs[0] == message_holder[0].msgs[0]
), f"Error in calling round {i+1}"



@pytest.mark.model_backend
def test_chat_agent_stream_accumulate_mode_accumulated():