Skip to content

Commit

Permalink
Fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
aymeric-roucher committed Nov 11, 2024
1 parent f7386ba commit a131fd0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
19 changes: 13 additions & 6 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
import logging
import re
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
Expand All @@ -45,7 +46,7 @@
get_tool_description_with_args,
load_tool,
)
from .monitoring import Monitor


if is_pygments_available():
from pygments import highlight
Expand Down Expand Up @@ -355,7 +356,7 @@ def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = None,
system_prompt:Optional[str] = None,
system_prompt: Optional[str] = None,
tool_description_template: Optional[str] = None,
additional_args: Dict = {},
max_iterations: int = 6,
Expand Down Expand Up @@ -810,7 +811,7 @@ def stream_run(self, task: str):
step_start_time = time.time()
step_log_entry = {"iteration": iteration, "start_time": step_start_time}
try:
output = self.step(step_log_entry)
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
Expand Down Expand Up @@ -851,7 +852,7 @@ def direct_run(self, task: str):
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
output = self.step(step_log_entry)
self.step(step_log_entry)
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
Expand Down Expand Up @@ -1218,8 +1219,10 @@ def step(self, log_entry: Dict[str, Any]):
log_entry["final_answer"] = result
return result


LENGTH_TRUNCATE_REPORTS = 1000


class ManagedAgent:
def __init__(self, agent, name, description, additional_prompting=None, provide_run_summary=False):
self.agent = agent
Expand Down Expand Up @@ -1263,7 +1266,11 @@ def __call__(self, request, **kwargs):
if len(str(content)) < LENGTH_TRUNCATE_REPORTS or "[FACTS LIST]" in str(content):
answer += "\n" + str(content) + "\n---"
else:
answer += "\n" + str(content)[:LENGTH_TRUNCATE_REPORTS] + "\n(...Step was truncated because too long)...\n---"
answer += (
"\n"
+ str(content)[:LENGTH_TRUNCATE_REPORTS]
+ "\n(...Step was truncated because too long)...\n---"
)
answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
return answer
else:
Expand Down
27 changes: 14 additions & 13 deletions src/transformers/agents/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from huggingface_hub import InferenceClient

from ..pipelines.base import Pipeline
from .. import AutoTokenizer
from ..pipelines.base import Pipeline
from ..utils import logging


logger = logging.get_logger(__name__)


class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
Expand Down Expand Up @@ -69,6 +71,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}


class HfEngine:
def __init__(self, model_id: Optional[str] = None):
self.last_input_token_count = None
Expand All @@ -84,13 +87,15 @@ def __init__(self, model_id: Optional[str] = None):

def get_token_counts(self):
return {
'input_token_count': self.last_input_token_count,
'output_token_count': self.last_output_token_count,
"input_token_count": self.last_input_token_count,
"output_token_count": self.last_output_token_count,
}

def generate(self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None):

def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
):
raise NotImplementedError

def __call__(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
) -> str:
Expand All @@ -108,20 +113,16 @@ def __call__(
response = response[: -len(stop_seq)]
return response


class HfApiEngine(HfEngine):
"""This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""

def __init__(
self,
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
timeout: int = 120
):
def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct", timeout: int = 120):
super().__init__(model_id=model)
self.model = model
self.timeout = timeout
self.client = InferenceClient(self.model, timeout=self.timeout)


def generate(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
) -> str:
Expand Down Expand Up @@ -152,7 +153,7 @@ def generate(
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_length: int = 1500
max_length: int = 1500,
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/agents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_types import AgentAudio, AgentImage, AgentText
from ..utils import logging
from .agent_types import AgentAudio, AgentImage, AgentText


logger = logging.get_logger(__name__)


def pull_message(step_log: dict):
try:
from gradio import ChatMessage
Expand Down Expand Up @@ -85,14 +87,13 @@ def __init__(self, tracked_llm_engine):
self.total_output_token_count = 0

def update_metrics(self, step_log):
step_duration = step_log.get('step_duration', None)
step_duration = step_log.get("step_duration", None)
self.step_durations.append(step_duration)
logger.info(f"Step {len(self.step_durations)}:")
logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")

if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count

logger.info(f"Step {len(self.step_durations)}:")
logger.info(f" Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
logger.info(f" Input tokens: {self.total_input_token_count}")
logger.info(f" Output tokens: {self.total_output_token_count}")
logger.info(f"- Input tokens: {self.total_input_token_count}")
logger.info(f"- Output tokens: {self.total_output_token_count}")

0 comments on commit a131fd0

Please sign in to comment.