Commit a131fd0

Fix tests
1 parent f7386ba commit a131fd0

3 files changed: +35 −26 lines

src/transformers/agents/agents.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -18,14 +18,15 @@
 import logging
 import re
 import time
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from .. import is_torch_available
 from ..utils import logging as transformers_logging
 from ..utils.import_utils import is_pygments_available
 from .agent_types import AgentAudio, AgentImage
 from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
 from .llm_engine import HfApiEngine, MessageRole
+from .monitoring import Monitor
 from .prompts import (
     DEFAULT_CODE_SYSTEM_PROMPT,
     DEFAULT_REACT_CODE_SYSTEM_PROMPT,
@@ -45,7 +46,7 @@
     get_tool_description_with_args,
     load_tool,
 )
-from .monitoring import Monitor
+
 
 if is_pygments_available():
     from pygments import highlight
@@ -355,7 +356,7 @@ def __init__(
         self,
         tools: Union[List[Tool], Toolbox],
         llm_engine: Callable = None,
-        system_prompt:Optional[str] = None,
+        system_prompt: Optional[str] = None,
         tool_description_template: Optional[str] = None,
         additional_args: Dict = {},
         max_iterations: int = 6,
@@ -810,7 +811,7 @@ def stream_run(self, task: str):
             step_start_time = time.time()
             step_log_entry = {"iteration": iteration, "start_time": step_start_time}
             try:
-                output = self.step(step_log_entry)
+                self.step(step_log_entry)
                 if "final_answer" in step_log_entry:
                     final_answer = step_log_entry["final_answer"]
             except AgentError as e:
@@ -851,7 +852,7 @@ def direct_run(self, task: str):
             try:
                 if self.planning_interval is not None and iteration % self.planning_interval == 0:
                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
-                output = self.step(step_log_entry)
+                self.step(step_log_entry)
                 if "final_answer" in step_log_entry:
                     final_answer = step_log_entry["final_answer"]
             except AgentError as e:
```
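Why the `output =` bindings were dropped: `step()` reports its result through the mutable `step_log_entry` dict (see the `log_entry["final_answer"] = result` hunk below), so both run loops read the answer from the dict and the local variable was unused. A minimal, self-contained sketch of that pattern, with a toy `step` standing in for the real `Agent.step`:

```python
from typing import Any, Dict, Optional


def step(log_entry: Dict[str, Any]) -> None:
    # Toy stand-in for Agent.step: the result travels through the
    # shared log entry, not through the return value.
    log_entry["final_answer"] = "42"


final_answer: Optional[str] = None
iteration = 0
while final_answer is None and iteration < 6:
    step_log_entry: Dict[str, Any] = {"iteration": iteration}
    step(step_log_entry)
    if "final_answer" in step_log_entry:
        final_answer = step_log_entry["final_answer"]
    iteration += 1

print(final_answer)  # "42"
```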
```diff
@@ -1218,8 +1219,10 @@ def step(self, log_entry: Dict[str, Any]):
         log_entry["final_answer"] = result
         return result
 
+
 LENGTH_TRUNCATE_REPORTS = 1000
 
+
 class ManagedAgent:
     def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
         self.agent = agent
@@ -1263,7 +1266,11 @@ def __call__(self, request, **kwargs):
                 if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
                     answer += "\n" + str(content) + "\n---"
                 else:
-                    answer += "\n" + str(content)[:LENGTH_TRUNCATE_REPORTS] + "\n(...Step was truncated because too long)...\n---"
+                    answer += (
+                        "\n"
+                        + str(content)[:LENGTH_TRUNCATE_REPORTS]
+                        + "\n(...Step was truncated because too long)...\n---"
+                    )
                 answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
                 return answer
             else:
```
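The last hunk only reflows an over-long concatenation; the truncation behavior is unchanged. To make that behavior concrete, here is the same branch in isolation, with a hypothetical `content` and the limit shrunk from 1000 to 20 characters so the cut is visible:

```python
LENGTH_TRUNCATE_REPORTS = 20  # the real module uses 1000

content = "A" * 50  # hypothetical over-long step report
answer = ""
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
    answer += "\n" + str(content) + "\n---"
else:
    # Keep only the first LENGTH_TRUNCATE_REPORTS characters and flag the cut.
    answer += (
        "\n"
        + str(content)[:LENGTH_TRUNCATE_REPORTS]
        + "\n(...Step was truncated because too long)...\n---"
    )
print(answer)  # 20 "A"s followed by the truncation notice
```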

src/transformers/agents/llm_engine.py

Lines changed: 14 additions & 13 deletions
```diff
@@ -20,12 +20,14 @@
 
 from huggingface_hub import InferenceClient
 
-from ..pipelines.base import Pipeline
 from .. import AutoTokenizer
+from ..pipelines.base import Pipeline
 from ..utils import logging
 
+
 logger = logging.get_logger(__name__)
 
+
 class MessageRole(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
```
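`MessageRole` feeds the role-conversion mapping seen in the next hunk, where `MessageRole.TOOL_RESPONSE` is folded into `MessageRole.USER` for providers that only accept user/assistant turns. A simplified sketch of that remapping (the enum below lists only the members visible in this diff, and `convert_roles` is a hypothetical stand-in for `get_clean_message_list`, which additionally merges consecutive same-role messages):

```python
from enum import Enum
from typing import Dict, List


class MessageRole(str, Enum):
    # Only the members visible in this diff; the real enum has more.
    USER = "user"
    ASSISTANT = "assistant"
    TOOL_RESPONSE = "tool-response"


# Same shape as the mapping in the next hunk: fold tool responses
# into user turns before sending messages to the provider.
role_conversions = {MessageRole.TOOL_RESPONSE: MessageRole.USER}


def convert_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # Hypothetical stand-in for get_clean_message_list.
    return [
        {"role": role_conversions.get(m["role"], m["role"]), "content": m["content"]}
        for m in messages
    ]


msgs = [
    {"role": MessageRole.USER, "content": "What is 2+2?"},
    {"role": MessageRole.TOOL_RESPONSE, "content": "4"},
]
print(convert_roles(msgs))  # the tool-response turn becomes a user turn
```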
```diff
@@ -69,6 +71,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
     MessageRole.TOOL_RESPONSE: MessageRole.USER,
 }
 
+
 class HfEngine:
     def __init__(self, model_id: Optional[str] = None):
         self.last_input_token_count = None
@@ -84,13 +87,15 @@ def __init__(self, model_id: Optional[str] = None):
 
     def get_token_counts(self):
         return {
-            'input_token_count': self.last_input_token_count,
-            'output_token_count': self.last_output_token_count,
+            "input_token_count": self.last_input_token_count,
+            "output_token_count": self.last_output_token_count,
         }
-
-    def generate(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None):
+
+    def generate(
+        self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
+    ):
         raise NotImplementedError
-
+
 
     def __call__(
         self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
@@ -108,20 +113,16 @@ def __call__(
             response = response[: -len(stop_seq)]
         return response
 
+
 class HfApiEngine(HfEngine):
     """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""
 
-    def __init__(
-        self,
-        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
-        timeout: int = 120
-    ):
+    def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct", timeout: int = 120):
         super().__init__(model_id=model)
         self.model = model
         self.timeout = timeout
         self.client = InferenceClient(self.model, timeout=self.timeout)
 
-
     def generate(
         self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
     ) -> str:
@@ -152,7 +153,7 @@ def generate(
         messages: List[Dict[str, str]],
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
-        max_length: int = 1500
+        max_length: int = 1500,
     ) -> str:
         # Get clean message list
         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
```
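`HfEngine` acts as a template-method base: subclasses implement `generate`, while the shared `__call__` delegates to it and strips trailing stop sequences (and, not shown in this diff, counts tokens with a tokenizer). A self-contained sketch of that shape; `Engine` and `EchoEngine` are hypothetical stand-ins rather than the real classes:

```python
from typing import Dict, List, Optional


class Engine:
    # Minimal stand-in for HfEngine: subclasses fill in generate();
    # __call__ adds the shared post-processing.
    def __init__(self):
        self.last_input_token_count: Optional[int] = None
        self.last_output_token_count: Optional[int] = None

    def get_token_counts(self):
        return {
            "input_token_count": self.last_input_token_count,
            "output_token_count": self.last_output_token_count,
        }

    def generate(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None) -> str:
        raise NotImplementedError

    def __call__(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None) -> str:
        response = self.generate(messages, stop_sequences)
        # Strip trailing stop sequences, as the real __call__ does.
        for stop_seq in stop_sequences or []:
            if response.endswith(stop_seq):
                response = response[: -len(stop_seq)]
        return response


class EchoEngine(Engine):
    """Toy subclass: echoes the last message instead of calling a model."""

    def generate(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None) -> str:
        # Crude whitespace "token" counts, just to exercise get_token_counts.
        self.last_input_token_count = sum(len(m["content"].split()) for m in messages)
        response = messages[-1]["content"]
        self.last_output_token_count = len(response.split())
        return response


engine = EchoEngine()
print(engine([{"role": "user", "content": "ping <end>"}], stop_sequences=["<end>"]))  # "ping "
print(engine.get_token_counts())  # {'input_token_count': 2, 'output_token_count': 2}
```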

src/transformers/agents/monitoring.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -14,11 +14,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .agent_types import AgentAudio, AgentImage, AgentText
 from ..utils import logging
+from .agent_types import AgentAudio, AgentImage, AgentText
+
 
 logger = logging.get_logger(__name__)
 
+
 def pull_message(step_log: dict):
     try:
         from gradio import ChatMessage
@@ -85,14 +87,13 @@ def __init__(self, tracked_llm_engine):
         self.total_output_token_count = 0
 
     def update_metrics(self, step_log):
-        step_duration = step_log.get('step_duration', None)
+        step_duration = step_log.get("step_duration", None)
         self.step_durations.append(step_duration)
+        logger.info(f"Step {len(self.step_durations)}:")
+        logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
 
         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
             self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
             self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
-
-            logger.info(f"Step {len(self.step_durations)}:")
-            logger.info(f" Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
-            logger.info(f" Input tokens: {self.total_input_token_count}")
-            logger.info(f" Output tokens: {self.total_output_token_count}")
+            logger.info(f"- Input tokens: {self.total_input_token_count}")
+            logger.info(f"- Output tokens: {self.total_output_token_count}")
```

0 commit comments
