diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py
index 73b7186d25a3c7..c461c50f29592c 100644
--- a/src/transformers/agents/agents.py
+++ b/src/transformers/agents/agents.py
@@ -1141,11 +1141,10 @@ def step(self):
             )
             self.logger.warning("Print outputs:")
             self.logger.log(32, self.state["print_outputs"])
+            observation = "Print outputs:\n" + self.state["print_outputs"]
             if result is not None:
                 self.logger.warning("Last output from code snippet:")
                 self.logger.log(32, str(result))
-            observation = "Print outputs:\n" + self.state["print_outputs"]
-            if result is not None:
                 observation += "Last output from code snippet:\n" + str(result)[:100000]
             current_step_logs["observation"] = observation
         except Exception as e:
diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py
index 8e28a72deb2a3e..755418d35a56a3 100644
--- a/src/transformers/agents/monitoring.py
+++ b/src/transformers/agents/monitoring.py
@@ -18,11 +18,19 @@
 from .agents import ReactAgent
 
 
-def pull_message(step_log: dict):
+def pull_message(step_log: dict, test_mode: bool = True):
     try:
         from gradio import ChatMessage
     except ImportError:
-        raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+        if test_mode:
+
+            class ChatMessage:
+                def __init__(self, role, content, metadata=None):
+                    self.role = role
+                    self.content = content
+                    self.metadata = metadata
+        else:
+            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
 
     if step_log.get("rationale"):
         yield ChatMessage(role="assistant", content=step_log["rationale"])
@@ -46,30 +54,40 @@ def pull_message(step_log: dict):
         )
 
 
-def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
+def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
 
     try:
         from gradio import ChatMessage
     except ImportError:
-        raise ImportError("Gradio should be installed in order to launch a gradio demo.")
+        if test_mode:
+
+            class ChatMessage:
+                def __init__(self, role, content, metadata=None):
+                    self.role = role
+                    self.content = content
+                    self.metadata = metadata
+        else:
+            raise ImportError("Gradio should be installed in order to launch a gradio demo.")
 
     for step_log in agent.run(task, stream=True, **kwargs):
         if isinstance(step_log, dict):
-            for message in pull_message(step_log):
+            for message in pull_message(step_log, test_mode=test_mode):
                 yield message
 
-    if isinstance(step_log, AgentText):
-        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
-    elif isinstance(step_log, AgentImage):
+    final_answer = step_log  # Last log is the run's final_answer
+
+    if isinstance(final_answer, AgentText):
+        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```")
+    elif isinstance(final_answer, AgentImage):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "image/png"},
+            content={"path": final_answer.to_string(), "mime_type": "image/png"},
         )
-    elif isinstance(step_log, AgentAudio):
+    elif isinstance(final_answer, AgentAudio):
         yield ChatMessage(
             role="assistant",
-            content={"path": step_log.to_string(), "mime_type": "audio/wav"},
+            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
         )
     else:
-        yield ChatMessage(role="assistant", content=str(step_log))
+        yield ChatMessage(role="assistant", content=str(final_answer))
diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py
index fbece2bebd350f..6e90f356cb928e 100644
--- a/src/transformers/agents/python_interpreter.py
+++ b/src/transformers/agents/python_interpreter.py
@@ -848,6 +848,13 @@ def evaluate_ast(
     raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
 
 
+def truncate_print_outputs(print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT) -> str:
+    if len(print_outputs) < max_len_outputs:
+        return print_outputs
+    else:
+        return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
+
+
 def evaluate_python_code(
     code: str,
     static_tools: Optional[Dict[str, Callable]] = None,
@@ -890,25 +897,12 @@ def evaluate_python_code(
     PRINT_OUTPUTS = ""
     global OPERATIONS_COUNT
     OPERATIONS_COUNT = 0
-    for node in expression.body:
-        try:
+    try:
+        for node in expression.body:
             result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
-        except InterpreterError as e:
-            msg = ""
-            if len(PRINT_OUTPUTS) > 0:
-                if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
-                    msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
-                else:
-                    msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
-            msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
-            raise InterpreterError(msg)
-        finally:
-            if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
-                state["print_outputs"] = PRINT_OUTPUTS
-            else:
-                state["print_outputs"] = (
-                    PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
-                    + f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
-                )
-
-    return result
+        state["print_outputs"] = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+        return result
+    except InterpreterError as e:
+        msg = truncate_print_outputs(PRINT_OUTPUTS, max_len_outputs=MAX_LEN_OUTPUT)
+        msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
+        raise InterpreterError(msg)
diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py
index 994e1bdd817b0c..84bcf0fde61f18 100644
--- a/src/transformers/agents/tools.py
+++ b/src/transformers/agents/tools.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import ast
 import base64
 import importlib
 import inspect
@@ -141,15 +142,19 @@ def validate_arguments(self, do_validate_forward: bool = True):
         required_attributes = {
             "description": str,
             "name": str,
-            "inputs": Dict,
+            "inputs": dict,
             "output_type": str,
         }
         authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
 
         for attr, expected_type in required_attributes.items():
             attr_value = getattr(self, attr, None)
+            if attr_value is None:
+                raise TypeError(f"You must set an attribute {attr}.")
             if not isinstance(attr_value, expected_type):
-                raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
+                raise TypeError(
+                    f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
+                )
         for input_name, input_content in self.inputs.items():
             assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
             assert (
@@ -248,7 +253,6 @@ def save(self, output_dir):
     def from_hub(
         cls,
         repo_id: str,
-        model_repo_id: Optional[str] = None,
         token: Optional[str] = None,
         **kwargs,
     ):
@@ -266,9 +270,6 @@ def from_hub(
         Args:
             repo_id (`str`):
                 The name of the repo on the Hub where your tool is defined.
-            model_repo_id (`str`, *optional*):
-                If your tool uses a model and you want to use a different model than the default, you can pass a second
-                repo ID or an endpoint url to this argument.
             token (`str`, *optional*):
                 The token to identify you on hf.co. If unset, will use the token generated when running
                 `huggingface-cli login` (stored in `~/.huggingface`).
@@ -354,6 +355,9 @@ def from_hub(
         if tool_class.output_type != custom_tool["output_type"]:
             tool_class.output_type = custom_tool["output_type"]
 
+        if not isinstance(tool_class.inputs, dict):
+            tool_class.inputs = ast.literal_eval(tool_class.inputs)
+
         return tool_class(**kwargs)
 
     def push_to_hub(
diff --git a/tests/agents/test_monitoring.py b/tests/agents/test_monitoring.py
new file mode 100644
index 00000000000000..c43c9cb8bf86dd
--- /dev/null
+++ b/tests/agents/test_monitoring.py
@@ -0,0 +1,82 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.agents.agent_types import AgentImage
+from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
+from transformers.agents.monitoring import stream_to_gradio
+
+
+class MonitoringTester(unittest.TestCase):
+    def test_streaming_agent_text_output(self):
+        def dummy_llm_engine(prompt, **kwargs):
+            return """
+Code:
+````
+final_answer('This is the final answer.')
+```"""
+
+        agent = ReactCodeAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
+
+        self.assertEqual(len(outputs), 3)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIn("This is the final answer.", final_message.content)
+
+    def test_streaming_agent_image_output(self):
+        def dummy_llm_engine(prompt, **kwargs):
+            return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
+
+        agent = ReactJsonAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
+
+        self.assertEqual(len(outputs), 2)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIsInstance(final_message.content, dict)
+        self.assertEqual(final_message.content["path"], "path.png")
+        self.assertEqual(final_message.content["mime_type"], "image/png")
+
+    def test_streaming_with_agent_error(self):
+        def dummy_llm_engine(prompt, **kwargs):
+            raise AgentError("Simulated agent error")
+
+        agent = ReactCodeAgent(
+            tools=[],
+            llm_engine=dummy_llm_engine,
+            max_iterations=1,
+        )
+
+        # Use stream_to_gradio to capture the output
+        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
+
+        self.assertEqual(len(outputs), 3)
+        final_message = outputs[-1]
+        self.assertEqual(final_message.role, "assistant")
+        self.assertIn("Simulated agent error", final_message.content)
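
For reference, a minimal usage sketch (not part of the patch) showing what the new `test_mode` flag buys: with Gradio absent, `stream_to_gradio` falls back to the stub `ChatMessage` class defined above, so the stream can be consumed in plain Python. `dummy_llm_engine` and its canned reply are hypothetical stand-ins mirroring the tests.

```python
# Minimal sketch, assuming the patched transformers.agents modules are importable.
# `dummy_llm_engine` is a hypothetical engine used only for illustration.
from transformers.agents.agents import ReactCodeAgent
from transformers.agents.monitoring import stream_to_gradio


def dummy_llm_engine(prompt, **kwargs):
    # Always reply with a fenced code action that calls final_answer.
    return "Code:\n```py\nfinal_answer('Done.')\n```"


agent = ReactCodeAgent(tools=[], llm_engine=dummy_llm_engine, max_iterations=1)

# With test_mode=True, a missing gradio install no longer raises ImportError:
# each yielded message is a stub ChatMessage carrying .role / .content / .metadata.
for message in stream_to_gradio(agent, task="Test task", test_mode=True):
    print(message.role, message.content)
```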