diff --git a/.gitignore b/.gitignore
index bc116dd..327b809 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,5 @@
 eval/diambra
 results.csv
 results/
-.DS_Store
\ No newline at end of file
+.DS_Store
+.idea
\ No newline at end of file
diff --git a/README.md b/README.md
index a64beee..cf2953f 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,15 @@ https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/79b58e26-7902-
 https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/5d3d386b-150a-48a5-8f68-7e2954ec18db
 
 
+### 1 VS 1: Mistral vs Solar
+#### mistral-small-latest vs solar-1-mini-chat
+https://github.com/Tokkiu/llm-colosseum/assets/13414571/2a7e681d-d022-486c-9250-68fedff4b069
+#### mistral-medium-latest vs solar-1-mini-chat
+https://github.com/Tokkiu/llm-colosseum/assets/13414571/d0532e43-11e2-447e-b2b3-6023b9760f11
+#### mistral-large-latest vs solar-1-mini-chat
+https://github.com/Tokkiu/llm-colosseum/assets/13414571/4757d562-f800-40ef-8f1a-675b0e23b8ed
+
+
 ## A new kind of benchmark ?
 
 Street Fighter III assesses the ability of LLMs to understand their environment and take actions based on a specific context.
diff --git a/agent/actions.py b/agent/actions.py
new file mode 100644
index 0000000..8b24ab0
--- /dev/null
+++ b/agent/actions.py
@@ -0,0 +1,144 @@
+"""
+Take observations and return actions for the Robot to use
+"""
+
+import os
+import random
+import re
+import time
+from typing import List, Optional
+from rich import print
+
+from loguru import logger
+
+from agent.language_models import get_sync_client, get_provider_and_model
+
+from .config import MOVES, META_INSTRUCTIONS, META_INSTRUCTIONS_WITH_LOWER
+from .prompts import build_main_prompt, build_system_prompt
+
+
+def call_llm(
+    context_prompt: str,
+    character: str,
+    model: str = "mistral:mistral-large-latest",
+    temperature: float = 0.3,
+    max_tokens: int = 20,
+    top_p: float = 1.0,
+    wrong_answer: Optional[str] = None,
+) -> str:
+    """
+    Get actions from the language model
+    context_prompt: str, the prompt to describe the situation to the LLM. Will be placed inside the main prompt template.
+    """
+    # Get the correct provider, default is mistral
+    provider_name, model_name = get_provider_and_model(model)
+    print("Provider", provider_name, model_name)
+    client = get_sync_client(provider_name)
+
+    # Generate the prompts
+    system_prompt = build_system_prompt(
+        character=character, context_prompt=context_prompt
+    )
+    main_prompt = build_main_prompt()
+
+    start_time = time.time()
+
+    completion = client.chat.completions.create(
+        model=model_name,
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": main_prompt},
+        ],
+        temperature=temperature,
+        max_tokens=max_tokens,
+        top_p=top_p,
+    )
+    logger.debug(f"LLM call to {model}: {system_prompt}\n\n\n{main_prompt}")
+    logger.debug(f"LLM call to {model}: {time.time() - start_time} s")
+
+    llm_response = completion.choices[0].message.content.strip()
+
+    return llm_response
+
+
+def get_simple_actions_from_llm(
+    context_prompt: str,
+    character: str,
+    model: str = "mistral:mistral-large-latest",
+    temperature: float = 0.1,
+    max_tokens: int = 20,
+    top_p: float = 1.0,
+) -> List[int]:
+    """
+    Get actions from the language model.
+    context_prompt: str, the prompt to describe the situation to the LLM.
+        Will be placed inside the main prompt template.
+
+    Return one action and then wait for the next observation.
+    """
+    pass
+
+
+def get_actions_from_llm(
+    context_prompt: str,
+    character: str,
+    model: str = "mistral:mistral-large-latest",
+    temperature: float = 0.1,
+    max_tokens: int = 20,
+    top_p: float = 1.0,
+    player_nb: int = 0,  # 0 for unspecified, 1 for player 1, 2 for player 2
+) -> List[str]:
+    """
+    Get actions from the language model
+    context_prompt: str, the prompt to describe the situation to the LLM.
+
+    Will be placed inside the main prompt template.
+    """
+
+    # Filter the moves that are not in the list of moves
+    invalid_moves = []
+    valid_moves = []
+    wrong_answer = None
+
+    # If we are in the test environment, we don't want to call the LLM
+    if os.getenv("DISABLE_LLM", "False") == "True":
+        # Choose a random int from the list of moves
+        logger.debug("DISABLE_LLM is True, returning a random move")
+        return [random.choice(list(MOVES.values()))]
+
+    while len(valid_moves) == 0:
+        llm_response = call_llm(
+            context_prompt=context_prompt,
+            character=character,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            wrong_answer=wrong_answer,
+        )
+
+        # The response is a bullet point list of moves. Use regex
+        matches = re.findall(r"- ([\w ]+)", llm_response)
+        moves = ["".join(match) for match in matches]
+        invalid_moves = []
+        valid_moves = []
+        for move in moves:
+            cleaned_move_name = move.strip().lower()
+            if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
+                if player_nb == 1:
+                    print(f"[red] Player {player_nb} move: {cleaned_move_name}")
+                elif player_nb == 2:
+                    print(f"[green] Player {player_nb} move: {cleaned_move_name}")
+                valid_moves.append(cleaned_move_name)
+            else:
+                logger.debug(f"Invalid completion: {move}")
+                logger.debug(f"Cleaned move name: {cleaned_move_name}")
+                invalid_moves.append(move)
+
+        if len(invalid_moves) > 2:
+            logger.warning(f"Too many invalid moves: {invalid_moves}")
+            wrong_answer = llm_response
+
+    logger.debug(f"Next moves: {valid_moves}")
+
+    return valid_moves
diff --git a/agent/config.py b/agent/config.py
index 25c26ca..b98f9f2 100644
--- a/agent/config.py
+++ b/agent/config.py
@@ -4,16 +4,20 @@
 
 MODELS = {
     "OPENAI": {
-        "openai:gpt-4-0125-preview",
-        "openai:gpt-4",
+        # "openai:gpt-4-0125-preview",
+        # "openai:gpt-4",
         "openai:gpt-3.5-turbo-0125",
         # "openai:gpt-3.5-turbo-instruct", # not a chat model
     },
     "MISTRAL": {
         "mistral:mistral-small-latest",
-        "mistral:mistral-medium-latest",
-        "mistral:mistral-large-latest",
-        # "groq:mistral-8x6b-32768",
+        # "mistral:mistral-medium-latest",
+        # "mistral:mistral-large-latest",
+        # "mistral:open-mistral-7b",
+        # "mistral:open-mixtral-8x7b",
+    },
+    "Solar": {
+        "solar:solar-1-mini-chat",
     },
 }
 
diff --git a/agent/language_models.py b/agent/language_models.py
new file mode 100644
index 0000000..52d3eea
--- /dev/null
+++ b/agent/language_models.py
@@ -0,0 +1,71 @@
+import os
+import dotenv
+from typing import Tuple
+
+dotenv.load_dotenv()
+
+try:
+    from openai import AsyncOpenAI, OpenAI
+except ImportError:
+    pass
+
+# Check we can access the environment variables
+assert os.getenv("MISTRAL_API_KEY") is not None
+
+
+def get_async_client(provider: str) -> AsyncOpenAI:
+    """
+    Provider can be "openai", "mistral", "solar" or "ollama".
+    """
+    if provider == "openai":
+        return AsyncOpenAI()
+    if provider == "mistral":
+        return AsyncOpenAI(
+            base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
+        )
+    if provider == "solar":
+        return AsyncOpenAI(
+            base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
+        )
+    if provider == "ollama":
+        return AsyncOpenAI(base_url="http://localhost:11434/v1/")
+    raise NotImplementedError(f"Provider {provider} is not supported.")
+
+
+def get_sync_client(provider: str) -> OpenAI:
+    if provider == "openai":
+        return OpenAI()
+    if provider == "mistral":
+        return OpenAI(
+            base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
+        )
+    if provider == "solar":
+        return OpenAI(
+            base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
+        )
+    if provider == "ollama":
+        return OpenAI(base_url="http://localhost:11434/v1/")
+    if provider == "groq":
+        return OpenAI(
+            base_url="https://api.groq.com/openai/v1/",
+            api_key=os.getenv("GROK_API_KEY"),
+        )
+    raise NotImplementedError(f"Provider {provider} is not supported.")
+
+
+def get_provider_and_model(model: str) -> Tuple[str, str]:
+    """
+    Get the provider and model from a string in the format "provider:model"
+    If no provider is specified, it defaults to "openai"
+
+    Args:
+        model (str): The model string in the format "provider:model"
+
+    Returns:
+        tuple: A tuple with the provider and model
+    """
+
+    split_result = model.split(":")
+    if len(split_result) == 1:
+        return "openai", split_result[0]
+    return split_result[0], split_result[1]
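
Usage sketch (not part of the patch above): a minimal illustration of how the new provider plumbing in agent/language_models.py and the "solar:solar-1-mini-chat" entry added to agent/config.py are expected to fit together. It assumes SOLAR_API_KEY is set in the environment (and MISTRAL_API_KEY as well, since the module asserts it at import time); the prompt text is purely illustrative.

# Minimal sketch of the intended call path through the new helpers.
# Assumes SOLAR_API_KEY (and MISTRAL_API_KEY, asserted at import) are set.
from agent.language_models import get_provider_and_model, get_sync_client

provider, model_name = get_provider_and_model("solar:solar-1-mini-chat")
# provider == "solar", model_name == "solar-1-mini-chat"

client = get_sync_client(provider)  # OpenAI-compatible client pointed at api.upstage.ai
completion = client.chat.completions.create(
    model=model_name,
    messages=[{"role": "user", "content": "Reply with one word."}],
    max_tokens=20,
)
print(completion.choices[0].message.content)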