From 1659e662285ab647e5af3098e4f0d9fe1dcbfe93 Mon Sep 17 00:00:00 2001
From: Aymeric
Date: Thu, 31 Oct 2024 23:19:47 +0100
Subject: [PATCH] Merge the streaming and direct run versions into one
 function

`ReactAgent.run` previously maintained two copies of the step loop,
`stream_run` and `direct_run`. This patch keeps a single step loop in
`stream_run`, which now always yields step logs (the last one carrying
the "final_answer" key), and implements direct mode by exhausting that
generator and returning only the final answer. `stream_to_gradio` is
updated to read the final answer from the last step log, and tests for
the Gradio streaming path are added.
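For illustration, a minimal sketch of the two call modes after this
patch; the `dummy_llm_engine` below exists only for this example,
mirroring the new tests:

    from transformers.agents import ReactCodeAgent

    def dummy_llm_engine(prompt, **kwargs):
        # Toy engine that immediately produces a final answer
        return "final_answer('4')"

    agent = ReactCodeAgent(tools=[], llm_engine=dummy_llm_engine, max_iterations=1)

    # Streaming mode: iterate over step logs as they are produced;
    # the last log always carries the "final_answer" key.
    for step_log in agent.run("What is 2 + 2?", stream=True):
        if "error" in step_log:
            print("Step failed:", step_log["error"])
        elif "final_answer" in step_log:
            print("Final answer:", step_log["final_answer"])

    # Direct mode: the same generator is exhausted internally and only
    # the final answer is returned.
    result = agent.run("What is 2 + 2?")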
---
 src/transformers/agents/agents.py     | 71 +++++++---------------
 src/transformers/agents/monitoring.py | 15 ++---
 tests/agents/test_monitoring.py       | 84 +++++++++++++++++++++++++++
 3 files changed, 114 insertions(+), 56 deletions(-)
 create mode 100644 tests/agents/test_monitoring.py

diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py
index 73b7186d25a3c7..d5ba7e5b6a5404 100644
--- a/src/transformers/agents/agents.py
+++ b/src/transformers/agents/agents.py
@@ -748,6 +748,8 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
         Args:
             task (`str`): The task to perform
+            stream (`bool`, *optional*, defaults to `False`): Whether to stream the agent's step logs as they are produced.
+            reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
 
         Example:
         ```py
@@ -760,75 +762,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
         if len(kwargs) > 0:
             self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
         self.state = kwargs.copy()
         if reset:
             self.initialize_for_run()
         else:
             self.logs.append({"task": task})
 
         if stream:
             return self.stream_run(task)
-        else:
-            return self.direct_run(task)
+        # Direct mode: exhaust the same generator and return only the final answer
+        for step_log in self.stream_run(task):
+            pass
+        return step_log["final_answer"]
 
     def stream_run(self, task: str):
         """
-        Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
+        Runs the agent, yielding each step log as it is executed: should be launched only in the `run` method.
+        The last log yielded always carries the `"final_answer"` key.
         """
         final_answer = None
         iteration = 0
         while final_answer is None and iteration < self.max_iterations:
             try:
                 if self.planning_interval is not None and iteration % self.planning_interval == 0:
                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
                 step_logs = self.step()
                 if "final_answer" in step_logs:
                     final_answer = step_logs["final_answer"]
             except AgentError as e:
                 self.logger.error(e, exc_info=1)
                 self.logs[-1]["error"] = e
             finally:
                 iteration += 1
                 yield self.logs[-1]
 
         if final_answer is None and iteration == self.max_iterations:
             error_message = "Reached max iterations."
             final_step_log = {"error": AgentMaxIterationsError(error_message)}
             self.logs.append(final_step_log)
             self.logger.error(error_message, exc_info=1)
             final_answer = self.provide_final_answer(task)
             final_step_log["final_answer"] = final_answer
             yield final_step_log
-
-        yield final_answer
-
-    def direct_run(self, task: str):
-        """
-        Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
-        """
-        final_answer = None
-        iteration = 0
-        while final_answer is None and iteration < self.max_iterations:
-            try:
-                if self.planning_interval is not None and iteration % self.planning_interval == 0:
-                    self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
-                step_logs = self.step()
-                if "final_answer" in step_logs:
-                    final_answer = step_logs["final_answer"]
-            except AgentError as e:
-                self.logger.error(e, exc_info=1)
-                self.logs[-1]["error"] = e
-            finally:
-                iteration += 1
-
-        if final_answer is None and iteration == self.max_iterations:
-            error_message = "Reached max iterations."
-            final_step_log = {"error": AgentMaxIterationsError(error_message)}
-            self.logs.append(final_step_log)
-            self.logger.error(error_message, exc_info=1)
-            final_answer = self.provide_final_answer(task)
-            final_step_log["final_answer"] = final_answer
-
-        return final_answer
 
     def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
         """
@@ -967,7 +941,6 @@ def step(self):
         # Add new step in logs
         current_step_logs = {}
         self.logs.append(current_step_logs)
-        current_step_logs["agent_memory"] = agent_memory.copy()
         self.logger.info("===== Calling LLM with this last message: =====")
         self.logger.info(self.prompt[-1])
@@ -1089,7 +1062,6 @@ def step(self):
         # Add new step in logs
         current_step_logs = {}
         self.logs.append(current_step_logs)
-        current_step_logs["agent_memory"] = agent_memory.copy()
         self.logger.info("===== Calling LLM with these last messages: =====")
         self.logger.info(self.prompt[-2:])
diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py
index 8e28a72deb2a3e..93101a3ee1f654 100644
--- a/src/transformers/agents/monitoring.py
+++ b/src/transformers/agents/monitoring.py
@@ -59,17 +59,19 @@ def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
         for message in pull_message(step_log):
             yield message
 
-    if isinstance(step_log, AgentText):
-        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
-    elif isinstance(step_log, AgentImage):
+    # The last step log always carries the final answer
+    final_answer = step_log["final_answer"]
+    if isinstance(final_answer, AgentText):
+        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
+    elif isinstance(final_answer, AgentImage):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "image/png"},
+            content={"path": final_answer.to_string(), "mime_type": "image/png"},
         )
-    elif isinstance(step_log, AgentAudio):
+    elif isinstance(final_answer, AgentAudio):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "audio/wav"},
+            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
         )
     else:
-        yield ChatMessage(role="assistant", content=str(step_log))
+        yield ChatMessage(role="assistant", content=str(final_answer))
diff --git a/tests/agents/test_monitoring.py b/tests/agents/test_monitoring.py
new file mode 100644
index 00000000000000..2df00f7efd87e5
--- /dev/null
+++ b/tests/agents/test_monitoring.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.agents.agent_types import AgentImage
+from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
+from transformers.agents.monitoring import stream_to_gradio
+
+
+class TestMonitoring(unittest.TestCase):
+    def test_streaming_agent_text_output(self):
+        # Dummy LLM engine that immediately returns a final answer
+        def dummy_llm_engine(prompt, **kwargs):
+            return "final_answer('This is the final answer.')"
+
+        agent = ReactCodeAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task"))
+
+        # The final output should be a ChatMessage with the expected content
+        self.assertEqual(len(outputs), 3)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIn("This is the final answer.", final_message.content)
+
+    def test_streaming_agent_image_output(self):
+        # Dummy LLM engine that returns an image as the final answer
+        def dummy_llm_engine(prompt, **kwargs):
+            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
+
+        agent = ReactJsonAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png")))
+
+        # The final output should be a ChatMessage carrying the image path
+        self.assertEqual(len(outputs), 3)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIsInstance(final_message.content, dict)
+        self.assertEqual(final_message.content["path"], "path.png")
+        self.assertEqual(final_message.content["mime_type"], "image/png")
+
+    def test_streaming_with_agent_error(self):
+        # Dummy LLM engine that always raises an error
+        def dummy_llm_engine(prompt, **kwargs):
+            raise AgentError("Simulated agent error.")
+
+        agent = ReactCodeAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task"))
+
+        # The error message should be surfaced in the final ChatMessage
+        self.assertEqual(len(outputs), 3)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIn("Simulated agent error", final_message.content)
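-- 
Note: the new tests can be run locally with

    python -m pytest tests/agents/test_monitoring.py -v

For context, a minimal sketch of wiring `stream_to_gradio` into a
Gradio chat UI (the engine choice and UI layout are illustrative, not
part of this patch):

    import gradio as gr

    from transformers.agents import HfApiEngine, ReactCodeAgent
    from transformers.agents.monitoring import stream_to_gradio

    agent = ReactCodeAgent(tools=[], llm_engine=HfApiEngine())

    def interact_with_agent(prompt, messages):
        # Echo the user message, then append agent messages as they stream in
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages
        for msg in stream_to_gradio(agent, task=prompt):
            messages.append(msg)
            yield messages

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        text_input = gr.Textbox(lines=1, label="Chat Message")
        text_input.submit(interact_with_agent, [text_input, chatbot], [chatbot])

    demo.launch()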