Merge streaming versions into one more elegant function
aymeric-roucher committed Oct 31, 2024
1 parent df8640c commit 1659e66
Showing 3 changed files with 114 additions and 56 deletions.
71 changes: 22 additions & 49 deletions src/transformers/agents/agents.py
@@ -748,6 +748,8 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
Args:
task (`str`): The task to perform
stream (`bool`, *optional*, defaults to `False`): Whether to stream the logs of the agent's interactions.
reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
Example:
```py
@@ -760,19 +762,12 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
print("OKOK TASK", self.task)
if reset:
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)

def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
@@ -785,47 +780,25 @@ def stream_run(self, task: str):
self.logs[-1]["error"] = e
finally:
iteration += 1
yield self.logs[-1]

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
if stream:
yield self.logs[-1]

if iteration == self.max_iterations:
if final_answer is None:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
else:
final_step_log = self.logs[-1]

if stream:
yield final_step_log
else:
return final_answer

yield final_answer

def direct_run(self, task: str):
"""
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
finally:
iteration += 1

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer

return final_answer

def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
@@ -967,7 +940,7 @@ def step(self):
# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
# current_step_logs["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with this last message: =====")
self.logger.info(self.prompt[-1])
@@ -1089,7 +1062,7 @@ def step(self):
# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
# current_step_logs["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with these last messages: =====")
self.logger.info(self.prompt[-2:])
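To make the control flow above easier to follow outside the diff context, here is a minimal, self-contained sketch of the dispatch pattern: `run` hands off to a generator (`stream_run`) when `stream=True` and to a plain call (`direct_run`) otherwise. `FakeAgent` and its canned `step` are illustrative stand-ins, not the library's classes.

```py
# A minimal sketch of the run/stream_run/direct_run split shown in the diff.
# FakeAgent and its canned step() are stand-ins, not the library's classes.
class FakeAgent:
    max_iterations = 3

    def __init__(self):
        self.logs = []

    def step(self):
        # Pretend the "LLM" answers on the second step.
        log = {"final_answer": "done"} if self.logs else {"thought": "working"}
        self.logs.append(log)
        return log

    def run(self, task, stream=False):
        # Returning the generator lazily keeps run() itself a plain function.
        return self.stream_run(task) if stream else self.direct_run(task)

    def stream_run(self, task):
        final_answer, iteration = None, 0
        while final_answer is None and iteration < self.max_iterations:
            final_answer = self.step().get("final_answer")
            iteration += 1
            yield self.logs[-1]  # streaming mode surfaces every step log
        yield final_answer  # the last yielded item is the answer itself

    def direct_run(self, task):
        final_answer, iteration = None, 0
        while final_answer is None and iteration < self.max_iterations:
            final_answer = self.step().get("final_answer")
            iteration += 1
        return final_answer  # direct mode returns only the end result


steps = list(FakeAgent().run("demo", stream=True))
assert steps[-1] == "done"
assert FakeAgent().run("demo") == "done"
```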
15 changes: 8 additions & 7 deletions src/transformers/agents/monitoring.py
@@ -59,17 +59,18 @@ def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
for message in pull_message(step_log):
yield message

if isinstance(step_log, AgentText):
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
elif isinstance(step_log, AgentImage):

if isinstance(step_log["final_answer"], AgentText):
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log["final_answer"].to_string()}\n```")
elif isinstance(step_log["final_answer"], AgentImage):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "image/png"},
content={"path": step_log["final_answer"].to_string(), "mime_type": "image/png"},
)
elif isinstance(step_log, AgentAudio):
elif isinstance(step_log["final_answer"], AgentAudio):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "audio/wav"},
content={"path": step_log["final_answer"].to_string(), "mime_type": "audio/wav"},
)
else:
yield ChatMessage(role="assistant", content=str(step_log))
yield ChatMessage(role="assistant", content=str(step_log["final_answer"]))
84 changes: 84 additions & 0 deletions tests/agents/test_monitoring.py
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from transformers.agents.monitoring import stream_to_gradio
from transformers.agents.agents import ReactCodeAgent, ReactJsonAgent, AgentError
from transformers.agents.agent_types import AgentImage

class TestMonitoring(unittest.TestCase):
def test_streaming_agent_text_output(self):
# Create a dummy LLM engine that returns a final answer
def dummy_llm_engine(prompt, **kwargs):
return "final_answer('This is the final answer.')"

agent = ReactCodeAgent(
tools=[],
llm_engine=dummy_llm_engine,
max_iterations=1,
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task"))

# Check that the final output is a ChatMessage with the expected content
self.assertEqual(len(outputs), 3)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIn("This is the final answer.", final_message.content)

def test_streaming_agent_image_output(self):
# Create a dummy LLM engine that returns an image as final answer
def dummy_llm_engine(prompt, **kwargs):
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'

agent = ReactJsonAgent(
tools=[],
llm_engine=dummy_llm_engine,
max_iterations=1,
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png")))

# Check that the final output is a ChatMessage with the expected content
self.assertEqual(len(outputs), 3)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIsInstance(final_message.content, dict)
self.assertEqual(final_message.content["path"], "path.png")
self.assertEqual(final_message.content["mime_type"], "image/png")

def test_streaming_with_agent_error(self):
# Create a dummy LLM engine that raises an error
def dummy_llm_engine(prompt, **kwargs):
raise AgentError("Simulated agent error.")

agent = ReactCodeAgent(
tools=[],
llm_engine=dummy_llm_engine,
max_iterations=1,
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task"))

# Check that the error message is yielded
print("OUTPUTTTTS", outputs)
self.assertEqual(len(outputs), 3)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIn("Simulated agent error", final_message.content)
