diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index 5c36c2922fa2a1..456c6172a77cb0 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -68,25 +68,84 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: class HfApiEngine: - """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint.""" + """A class to interact with Hugging Face's Inference API for language model interaction. + + This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization. + + Parameters: + model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`): + The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub. + token (`str`, *optional*): + The Hugging Face API token for authentication. If not provided, the class will use the token stored in the Hugging Face CLI configuration. + max_tokens (`int`, *optional*, defaults to 1500): + The maximum number of tokens allowed in the output. + timeout (`int`, *optional*, defaults to 120): + Timeout for the API request, in seconds. + + Raises: + ValueError: + If the model name is not provided. + """ + + def __init__( + self, + model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct", + token: Optional[str] = None, + max_tokens: Optional[int] = 1500, + timeout: Optional[int] = 120, + ): + """Initialize the HfApiEngine.""" + if not model: + raise ValueError("Model name must be provided.") - def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"): self.model = model - self.client = InferenceClient(self.model, timeout=120) + self.client = InferenceClient(self.model, token=token, timeout=timeout) + self.max_tokens = max_tokens def __call__( - self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None + self, + messages: List[Dict[str, str]], + stop_sequences: Optional[List[str]] = [], + grammar: Optional[str] = None, ) -> str: + """Process the input messages and return the model's response. + + This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization. + + Parameters: + messages (`List[Dict[str, str]]`): + A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`. + stop_sequences (`List[str]`, *optional*): + A list of strings that will stop the generation if encountered in the model's output. + grammar (`str`, *optional*): + The grammar or formatting structure to use in the model's response. + + Returns: + `str`: The text content of the model's response. + + Example: + ```python + >>> engine = HfApiEngine( + ... model="meta-llama/Meta-Llama-3.1-8B-Instruct", + ... token="your_hf_token_here", + ... max_tokens=2000 + ... ) + >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}] + >>> response = engine(messages, stop_sequences=["END"]) + >>> print(response) + "Quantum mechanics is the branch of physics that studies..." + ``` + """ # Get clean message list messages = get_clean_message_list(messages, role_conversions=llama_role_conversions) - # Get LLM output + # Send messages to the Hugging Face Inference API if grammar is not None: response = self.client.chat_completion( - messages, stop=stop_sequences, max_tokens=1500, response_format=grammar + messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar ) else: - response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500) + response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens) response = response.choices[0].message.content