Added groq provider
Maximilian-Winter committed Jun 13, 2024
1 parent 1ce4e19 commit 4d54ca5
Showing 8 changed files with 257 additions and 20 deletions.
23 changes: 23 additions & 0 deletions examples/01_Basics/chatbot_using_groq.py
@@ -0,0 +1,23 @@
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers.groq import GroqProvider

provider = GroqProvider(base_url="https://api.groq.com/openai/v1", model="mixtral-8x7b-32768", huggingface_model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="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")
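# Note: the Hugging Face model id above is only used to load a matching tokenizer for
# local token counting; generation itself goes through Groq's OpenAI-compatible API.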

agent = LlamaCppAgent(
    provider,
    system_prompt="You are a helpful assistant.",
    predefined_messages_formatter_type=MessagesFormatterType.MISTRAL,
)

settings = provider.get_provider_default_settings()
settings.stream = True
settings.max_tokens = 512
settings.temperature = 0.65

while True:
    user_input = input(">")
    if user_input == "exit":
        break
    agent_output = agent.get_chat_response(user_input, llm_sampling_settings=settings)
    print(f"Agent: {agent_output.strip()}")
2 changes: 1 addition & 1 deletion examples/07_Memory/MemoryAssistant/main.py
@@ -10,7 +10,7 @@
from prompts import assistant_prompt, memory_prompt, wrap_function_response_in_xml_tags_json_mode, \
generate_write_message, generate_write_message_with_examples, wrap_user_message_in_xml_tags_json_mode

-provider = LlamaCppServerProvider("http://hades.hq.solidrust.net:8084")
+provider = LlamaCppServerProvider("http://localhost:8080")

agent = LlamaCppAgent(
provider,
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ email = "[email protected]"
agent_memory = ["chromadb", "SQLAlchemy", "numpy", "scipy"]
rag = ["ragatouille"]
vllm_provider = ["openai", "transformers", "sentencepiece", "protobuf"]
+groq_provider = ["groq"]
mixtral_agent = ["mistral-common"]
web_search_summarization = ["duckduckgo_search", "trafilatura", "lxml-html-clean", "lxml", "googlesearch-python" , "beautifulsoup4", "readability-lxml"]

2 changes: 1 addition & 1 deletion src/llama_cpp_agent/function_calling.py
@@ -83,7 +83,7 @@ def pydantic_model_to_openai_function_definition(pydantic_model: Type[BaseModel]
function_definition = {
"type": "function",
"function": {
-"name": pydantic_model.__name__.lower(),
+"name": pydantic_model.__name__,
"description": class_description,
"parameters": {
"type": "object",
17 changes: 9 additions & 8 deletions src/llama_cpp_agent/llm_agent.py
@@ -21,7 +21,8 @@
function_calling_function_list_templater, structured_output_templater, \
structured_output_thoughts_and_reasoning_templater

-from .providers.provider_base import LlmProvider, LlmSamplingSettings
+from .providers.provider_base import LlmProvider, LlmSamplingSettings, LlmProviderId


class SystemPromptModulePosition(Enum):
after_system_instructions = 1
@@ -195,7 +196,7 @@ def stream_results():
yield out_text

return structured_output_settings.handle_structured_output(
-full_response_stream
+full_response_stream, provider=self.provider
)

if llm_sampling_settings.is_streaming():
@@ -219,7 +220,7 @@ def stream_results():
print("")
self.last_response = full_response
return structured_output_settings.handle_structured_output(
-full_response
+full_response, provider=self.provider
)
else:
full_response = ""
@@ -229,7 +230,7 @@ def stream_results():
print(full_response)
self.last_response = full_response
return structured_output_settings.handle_structured_output(
-full_response
+full_response, provider=self.provider
)
return "Error: No model loaded!"

@@ -322,7 +323,7 @@ def stream_results():
}
)
return structured_output_settings.handle_structured_output(
-full_response_stream, prompt_suffix=prompt_suffix
+full_response_stream, prompt_suffix=prompt_suffix, provider=self.provider
)

if self.provider:
@@ -358,7 +359,7 @@ def stream_results():
)

return structured_output_settings.handle_structured_output(
-full_response, prompt_suffix=prompt_suffix
+full_response, prompt_suffix=prompt_suffix, provider=self.provider
)
else:
text = completion["choices"][0]["text"]
@@ -377,7 +378,7 @@ def stream_results():
}
)

-return structured_output_settings.handle_structured_output(text, prompt_suffix=prompt_suffix)
+return structured_output_settings.handle_structured_output(text, prompt_suffix=prompt_suffix, provider=self.provider)
return "Error: No model loaded!"

def get_text_completion(
@@ -645,7 +646,7 @@ def get_response_role_and_completion(

return (
self.provider.create_completion(
-prompt,
+prompt if self.provider.get_provider_identifier() is not LlmProviderId.groq else messages,
structured_output_settings,
llm_sampling_settings,
self.messages_formatter.bos_token,
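Since Groq's endpoint is chat-based, the completion call above now passes the Groq provider the raw chat message list instead of the formatter-rendered prompt string, while every other provider keeps receiving the prompt. A minimal standalone sketch of that dispatch (editorial illustration, not part of the commit; the helper name is made up, the import path is taken from the diff):

from llama_cpp_agent.providers.provider_base import LlmProvider, LlmProviderId


def select_completion_input(provider: LlmProvider, prompt: str, messages: list[dict]):
    # Mirrors the dispatch in get_response_role_and_completion: the Groq provider
    # receives the chat message list, every other provider the rendered prompt string.
    return messages if provider.get_provider_identifier() is LlmProviderId.groq else prompt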
4 changes: 3 additions & 1 deletion src/llama_cpp_agent/llm_output_settings/settings.py
@@ -115,6 +115,7 @@ class LlmStructuredOutputSettings(BaseModel):
False,
description="If the output should be a tuple of the output and the generated JSON string by the LLM",
)

class Config:
arbitrary_types_allowed = True

@@ -616,7 +617,8 @@ def add_all_current_functions_to_heartbeat_list(self, excluded: list[str] = None
[tool.model.__name__ for tool in self.function_tools if tool.model.__name__ not in excluded]
)

-def handle_structured_output(self, llm_output: str, prompt_suffix: str = None):
+def handle_structured_output(self, llm_output: str, prompt_suffix: str = None, provider=None):

if self.output_raw_json_string:
return llm_output

209 changes: 209 additions & 0 deletions src/llama_cpp_agent/providers/groq.py
@@ -0,0 +1,209 @@
import json
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Union

from llama_cpp_agent.llm_output_settings import (
    LlmStructuredOutputSettings,
    LlmStructuredOutputType,
)
from llama_cpp_agent.providers.provider_base import (
    LlmProvider,
    LlmProviderId,
    LlmSamplingSettings,
)


@dataclass
class GroqSamplingSettings(LlmSamplingSettings):
    """
    GroqSamplingSettings dataclass
    """

    top_p: float = 1
    temperature: float = 0.7
    max_tokens: int = 16
    stream: bool = False

    def get_provider_identifier(self) -> LlmProviderId:
        return LlmProviderId.groq

    def get_additional_stop_sequences(self) -> Union[List[str], None]:
        return None

    def add_additional_stop_sequences(self, sequences: List[str]):
        pass

    def is_streaming(self):
        return self.stream

    @staticmethod
    def load_from_dict(settings: dict) -> "GroqSamplingSettings":
        """
        Load the settings from a dictionary.
        Args:
            settings (dict): The dictionary containing the settings.
        Returns:
            LlamaCppSamplingSettings: The loaded settings.
        """
        return GroqSamplingSettings(**settings)

    def as_dict(self) -> dict:
        """
        Convert the settings to a dictionary.
        Returns:
            dict: The dictionary representation of the settings.
        """
        return self.__dict__


class GroqProvider(LlmProvider):
    def __init__(self, base_url: str, model: str, huggingface_model: str, api_key: str = None):
        from openai import OpenAI
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(huggingface_model)
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key if api_key else "xxx-xxxxxxxx",
        )
        self.model = model

    def is_using_json_schema_constraints(self):
        return True

    def get_provider_identifier(self) -> LlmProviderId:
        return LlmProviderId.groq

    def get_provider_default_settings(self) -> GroqSamplingSettings:
        return GroqSamplingSettings()

    def create_completion(
        self,
        prompt: str | list[dict],
        structured_output_settings: LlmStructuredOutputSettings,
        settings: GroqSamplingSettings,
        bos_token: str,
    ):
        tools = None
        if (
            structured_output_settings.output_type
            == LlmStructuredOutputType.function_calling
            or structured_output_settings.output_type == LlmStructuredOutputType.parallel_function_calling
        ):
            tools = [tool.to_openai_tool() for tool in structured_output_settings.function_tools]
        top_p = settings.top_p
        stream = settings.stream
        temperature = settings.temperature
        max_tokens = settings.max_tokens

        settings_dict = deepcopy(settings.as_dict())
        settings_dict.pop("top_p")
        settings_dict.pop("stream")
        settings_dict.pop("temperature")
        settings_dict.pop("max_tokens")

        if settings.stream:
            result = self.client.chat.completions.create(
                messages=prompt,
                model=self.model,
                extra_body=settings_dict,
                tools=tools,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            def generate_chunks():
                for chunk in result:
                    if chunk.choices[0].delta.tool_calls is not None:
                        if tools is not None:
                            args = chunk.choices[0].delta.tool_calls[0].function.arguments
                            args_loaded = json.loads(args)
                            function_name = chunk.choices[0].delta.tool_calls[0].function.name
                            function_dict = {structured_output_settings.function_calling_name_field_name: function_name, structured_output_settings.function_calling_content: args_loaded}
                            yield {"choices": [{"text": json.dumps(function_dict)}]}
                    if chunk.choices[0].delta.content is not None:
                        yield {"choices": [{"text": chunk.choices[0].delta.content}]}

            return generate_chunks()
        else:
            result = self.client.chat.completions.create(
                messages=prompt,
                model=self.model,
                extra_body=settings_dict,
                tools=tools,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            if tools is not None:
                args = result.choices[0].message.tool_calls[0].function.arguments
                args_loaded = json.loads(args)
                function_name = result.choices[0].message.tool_calls[0].function.name
                function_dict = {structured_output_settings.function_calling_name_field_name: function_name, structured_output_settings.function_calling_content: args_loaded}
                return {"choices": [{"text": json.dumps(function_dict)}]}
            return {"choices": [{"text": result.choices[0].message.content}]}

    def create_chat_completion(
        self,
        messages: List[Dict[str, str]],
        structured_output_settings: LlmStructuredOutputSettings,
        settings: GroqSamplingSettings
    ):
        grammar = None
        if (
            structured_output_settings.output_type
            != LlmStructuredOutputType.no_structured_output
        ):
            grammar = structured_output_settings.get_json_schema()

        top_p = settings.top_p
        stream = settings.stream
        temperature = settings.temperature
        max_tokens = settings.max_tokens

        settings_dict = copy(settings.as_dict())
        settings_dict.pop("top_p")
        settings_dict.pop("stream")
        settings_dict.pop("temperature")
        settings_dict.pop("max_tokens")
        if grammar is not None:
            settings_dict["guided_json"] = grammar

        if settings.stream:
            result = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                extra_body=settings_dict,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            def generate_chunks():
                for chunk in result:
                    if chunk.choices[0].delta.content is not None:
                        yield {"choices": [{"text": chunk.choices[0].delta.content}]}

            return generate_chunks()
        else:
            result = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                extra_body=settings_dict,
                top_p=top_p,
                stream=stream,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return {"choices": [{"text": result.choices[0].message.content}]}

    def tokenize(self, prompt: str) -> list[int]:
        result = self.tokenizer.encode(text=prompt)
        return result
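Both branches of create_completion above normalize Groq's OpenAI-style responses into the provider-agnostic {"choices": [{"text": ...}]} shape that the rest of the agent pipeline consumes, re-serializing tool calls into a single JSON object keyed by the configured structured-output field names. A standalone sketch of that normalization (editorial illustration, not part of the commit; the key names and the sample tool call are made-up stand-ins for the configured field names and a real API response):

import json

# Illustrative stand-ins for structured_output_settings.function_calling_name_field_name
# and structured_output_settings.function_calling_content.
NAME_FIELD = "function"
CONTENT_FIELD = "arguments"

# A hypothetical tool call as the OpenAI-compatible client would return it:
# the name is a plain string, the arguments arrive as a JSON string.
tool_call_name = "get_weather"
tool_call_arguments = '{"city": "Berlin", "unit": "celsius"}'

# Same steps as the provider: parse the argument string, wrap name and arguments in a
# single dict, and emit it in the generic completion shape the agent expects.
function_dict = {NAME_FIELD: tool_call_name, CONTENT_FIELD: json.loads(tool_call_arguments)}
normalized = {"choices": [{"text": json.dumps(function_dict)}]}

print(normalized["choices"][0]["text"])
# {"function": "get_weather", "arguments": {"city": "Berlin", "unit": "celsius"}}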
19 changes: 10 additions & 9 deletions src/llama_cpp_agent/providers/provider_base.py
@@ -22,6 +22,7 @@ class LlmProviderId(Enum):
llama_cpp_python = "llama_cpp_python"
tgi_server = "text_generation_inference"
vllm_server = "vllm"
+groq = "groq"


class LlmSamplingSettings(ABC):
@@ -146,11 +147,11 @@ def get_provider_default_settings(self) -> LlmSamplingSettings:

@abstractmethod
def create_completion(
-self,
-prompt: str,
-structured_output_settings: LlmStructuredOutputSettings,
-settings: LlmSamplingSettings,
-bos_token: str,
+self,
+prompt: str | list[dict],
+structured_output_settings: LlmStructuredOutputSettings,
+settings: LlmSamplingSettings,
+bos_token: str,
):
"""
Create a completion request with the LLM provider and returns the result.
@@ -168,10 +169,10 @@ def create_completion(

@abstractmethod
def create_chat_completion(
-self,
-messages: List[Dict[str, str]],
-structured_output_settings: LlmStructuredOutputSettings,
-settings: LlmSamplingSettings
+self,
+messages: List[Dict[str, str]],
+structured_output_settings: LlmStructuredOutputSettings,
+settings: LlmSamplingSettings
):
"""
Create a chat completion request with the LLM provider and returns the result.
