Skip to content

Commit

Permalink
Add monitoring to Agent and HfEngine children
Browse files: browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Oct 31, 2024
1 parent c443d8d commit f7386ba
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 83 deletions.
177 changes: 118 additions & 59 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import logging
import re
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from .. import is_torch_available
Expand Down Expand Up @@ -44,7 +45,7 @@
get_tool_description_with_args,
load_tool,
)

from .monitoring import Monitor

if is_pygments_available():
from pygments import highlight
Expand Down Expand Up @@ -353,17 +354,23 @@ class Agent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = HfApiEngine(),
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template=None,
additional_args={},
llm_engine: Callable = None,
system_prompt:Optional[str] = None,
tool_description_template: Optional[str] = None,
additional_args: Dict = {},
max_iterations: int = 6,
tool_parser=parse_json_tool_call,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: int = 0,
grammar: Dict[str, str] = None,
managed_agents: List = None,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
monitor_metrics: bool = True,
):
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_parser is None:
tool_parser = parse_json_tool_call
self.agent_name = self.__class__.__name__
self.llm_engine = llm_engine
self.system_prompt_template = system_prompt
Expand Down Expand Up @@ -406,6 +413,15 @@ def __init__(
elif verbose == 2:
logger.setLevel(logging.DEBUG)

# Initialize step callbacks
self.step_callbacks = step_callbacks if step_callbacks is not None else []

# Initialize Monitor if monitor_metrics is True
self.monitor = None
if monitor_metrics:
self.monitor = Monitor(self.llm_engine)
self.step_callbacks.append(self.monitor.update_metrics)

@property
def toolbox(self) -> Toolbox:
"""Get the toolbox currently available to the agent"""
Expand Down Expand Up @@ -578,13 +594,19 @@ class CodeAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
Expand Down Expand Up @@ -700,15 +722,24 @@ class ReactAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
plan_type: Optional[str] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
if plan_type is None:
plan_type = SUPPORTED_PLAN_TYPES[0]
else:
assert plan_type in SUPPORTED_PLAN_TYPES, f"plan type {plan_type} is not supported"
super().__init__(
tools=tools,
llm_engine=llm_engine,
Expand Down Expand Up @@ -776,16 +807,24 @@ def stream_run(self, task: str):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
output = self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1
yield self.logs[-1]
yield step_log_entry

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
Expand All @@ -794,6 +833,8 @@ def stream_run(self, task: str):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
for callback in self.step_callbacks:
callback(final_step_log)
yield final_step_log

yield final_answer
Expand All @@ -805,16 +846,24 @@ def direct_run(self, task: str):
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
output = self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
step_log_entry["error"] = e
finally:
step_end_time = time.time()
step_log_entry["step_end_time"] = step_end_time
step_log_entry["step_duration"] = step_end_time - step_start_time
self.logs.append(step_log_entry)
for callback in self.step_callbacks:
callback(step_log_entry)
iteration += 1

if final_answer is None and iteration == self.max_iterations:
Expand All @@ -824,6 +873,8 @@ def direct_run(self, task: str):
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
for callback in self.step_callbacks:
callback(final_step_log)

return final_answer

Expand Down Expand Up @@ -937,13 +988,19 @@ class ReactJsonAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_JSON_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
Expand All @@ -954,7 +1011,7 @@ def __init__(
**kwargs,
)

def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
Expand All @@ -965,9 +1022,7 @@ def step(self):
self.logger.debug("===== New step =====")

# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with this last message: =====")
self.logger.info(self.prompt[-1])
Expand All @@ -981,7 +1036,7 @@ def step(self):
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output

# Parse
self.logger.debug("===== Extracting action =====")
Expand All @@ -992,8 +1047,8 @@ def step(self):
except Exception as e:
raise AgentParsingError(f"Could not parse the given action: {e}.")

current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}

# Execute
self.logger.warning("=== Agent thoughts:")
Expand All @@ -1011,8 +1066,8 @@ def step(self):
answer = arguments
else:
answer = arguments
current_step_logs["final_answer"] = answer
return current_step_logs
log_entry["final_answer"] = answer
return answer
else:
if arguments is None:
arguments = {}
Expand All @@ -1030,8 +1085,8 @@ def step(self):
else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
current_step_logs["observation"] = updated_information
return current_step_logs
log_entry["observation"] = updated_information
return log_entry


class ReactCodeAgent(ReactAgent):
Expand All @@ -1044,14 +1099,20 @@ class ReactCodeAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
llm_engine: Optional[Callable] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
if llm_engine is None:
llm_engine = HfApiEngine()
if system_prompt is None:
system_prompt = DEFAULT_REACT_CODE_SYSTEM_PROMPT
if tool_description_template is None:
tool_description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
super().__init__(
tools=tools,
llm_engine=llm_engine,
Expand All @@ -1075,21 +1136,18 @@ def __init__(
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.custom_tools = {}

def step(self):
def step(self, log_entry: Dict[str, Any]):
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
The errors are raised here, they are caught and logged in the run() method.
"""
agent_memory = self.write_inner_memory_from_logs()

self.prompt = agent_memory.copy()

self.logger.debug("===== New step =====")

# Add new step in logs
current_step_logs = {}
self.logs.append(current_step_logs)
current_step_logs["agent_memory"] = agent_memory.copy()
log_entry["agent_memory"] = agent_memory.copy()

self.logger.info("===== Calling LLM with these last messages: =====")
self.logger.info(self.prompt[-2:])
Expand All @@ -1104,7 +1162,7 @@ def step(self):

self.logger.debug("=== Output message of the LLM:")
self.logger.debug(llm_output)
current_step_logs["llm_output"] = llm_output
log_entry["llm_output"] = llm_output

# Parse
self.logger.debug("=== Extracting action ===")
Expand All @@ -1120,8 +1178,8 @@ def step(self):
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
raise AgentParsingError(error_msg)

current_step_logs["rationale"] = rationale
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
log_entry["rationale"] = rationale
log_entry["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}

# Execute
self.log_rationale_code_action(rationale, code_action)
Expand All @@ -1147,7 +1205,7 @@ def step(self):
observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None:
observation += "Last output from code snippet:\n" + str(result)[:100000]
current_step_logs["observation"] = observation
log_entry["observation"] = observation
except Exception as e:
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
Expand All @@ -1157,9 +1215,10 @@ def step(self):
if line[: len("final_answer")] == "final_answer":
self.logger.log(33, "Final answer:")
self.logger.log(32, result)
current_step_logs["final_answer"] = result
return current_step_logs
log_entry["final_answer"] = result
return result

LENGTH_TRUNCATE_REPORTS = 1000

class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
Expand Down Expand Up @@ -1201,10 +1260,10 @@ def __call__(self, request, **kwargs):
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
content = message["content"]
if len(str(content)) < 1000 or "[FACTS LIST]" in str(content):
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
answer += "\n" + str(content) + "\n---"
else:
answer += "\n" + str(content)[:1000] + "\n(...Step was truncated because too long)...\n---"
answer += "\n" + str(content)[:LENGTH_TRUNCATE_REPORTS] + "\n(...Step was truncated because too long)...\n---"
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer
else:
Expand Down
Loading

0 comments on commit f7386ba

Please sign in to comment.