Update llm_engine.py (#33332)
* Update llm_engine.py
- Added support for optional token and max_tokens parameters in the constructor.
- Provided usage examples and detailed documentation for each method.
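
A minimal usage sketch of the new constructor options (illustrative only: the prompt and stop sequence are placeholders, and leaving `token` unset falls back to the token stored by `huggingface-cli login`):

```python
from transformers.agents.llm_engine import HfApiEngine

# All three new knobs are optional; the defaults match the previous behavior.
engine = HfApiEngine(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    token=None,       # falls back to the locally saved Hugging Face token
    max_tokens=1000,  # previously hard-coded to 1500
    timeout=60,       # previously hard-coded to 120
)

messages = [{"role": "user", "content": "Give one sentence on transformers."}]
print(engine(messages, stop_sequences=["\n\n"]))
```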
louisbrulenaudet authored Nov 10, 2024
1 parent 768f3c0 commit 134ba90
Showing 1 changed file with 66 additions and 7 deletions.
src/transformers/agents/llm_engine.py: 66 additions & 7 deletions
@@ -68,25 +68,84 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
 
 
 class HfApiEngine:
-    """This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""
+    """A class to interact with Hugging Face's Inference API for language models.
+
+    This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in serverless mode or with a dedicated endpoint, and supports features such as stop sequences and grammar customization.
+
+    Parameters:
+        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3.1-8B-Instruct"`):
+            The Hugging Face model ID to be used for inference. This can be a path or a model identifier from the Hugging Face model hub.
+        token (`str`, *optional*):
+            The Hugging Face API token for authentication. If not provided, the class will use the token stored in the Hugging Face CLI configuration.
+        max_tokens (`int`, *optional*, defaults to 1500):
+            The maximum number of tokens allowed in the output.
+        timeout (`int`, *optional*, defaults to 120):
+            Timeout for the API request, in seconds.
+
+    Raises:
+        ValueError:
+            If the model name is not provided.
+    """
 
-    def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
+    def __init__(
+        self,
+        model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        token: Optional[str] = None,
+        max_tokens: Optional[int] = 1500,
+        timeout: Optional[int] = 120,
+    ):
+        """Initialize the HfApiEngine."""
+        if not model:
+            raise ValueError("Model name must be provided.")
+
         self.model = model
-        self.client = InferenceClient(self.model, timeout=120)
+        self.client = InferenceClient(self.model, token=token, timeout=timeout)
+        self.max_tokens = max_tokens
 
     def __call__(
-        self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = [],
+        grammar: Optional[str] = None,
     ) -> str:
+        """Process the input messages and return the model's response.
+
+        This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.
+
+        Parameters:
+            messages (`List[Dict[str, str]]`):
+                A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
+            stop_sequences (`List[str]`, *optional*):
+                A list of strings that will stop the generation if encountered in the model's output.
+            grammar (`str`, *optional*):
+                The grammar or formatting structure to use in the model's response.
+
+        Returns:
+            `str`: The text content of the model's response.
+
+        Example:
+            ```python
+            >>> engine = HfApiEngine(
+            ...     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+            ...     token="your_hf_token_here",
+            ...     max_tokens=2000
+            ... )
+            >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+            >>> response = engine(messages, stop_sequences=["END"])
+            >>> print(response)
+            "Quantum mechanics is the branch of physics that studies..."
+            ```
+        """
         # Get clean message list
         messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
 
-        # Get LLM output
+        # Send messages to the Hugging Face Inference API
         if grammar is not None:
             response = self.client.chat_completion(
-                messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
+                messages, stop=stop_sequences, max_tokens=self.max_tokens, response_format=grammar
            )
         else:
-            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
+            response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=self.max_tokens)
 
         response = response.choices[0].message.content
 
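The new docstring advertises grammar customization, but the embedded example never exercises it. A hedged sketch of the `grammar` path follows: the value is passed straight through to `chat_completion(response_format=...)`, and with a TGI-backed endpoint a JSON-grammar spec of the shape below is one accepted form. The schema itself is illustrative, and the exact accepted format depends on your `huggingface_hub` version, despite the `str` type hint on `grammar`.

```python
from transformers.agents.llm_engine import HfApiEngine

engine = HfApiEngine(max_tokens=500)

# TGI-style JSON grammar spec; forwarded unchanged to `response_format`.
# The accepted shape may vary across huggingface_hub versions.
json_grammar = {
    "type": "json",
    "value": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
}

messages = [{"role": "user", "content": "Answer with a JSON object holding an 'answer' key."}]
print(engine(messages, grammar=json_grammar))
```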
