Draft: Add solar implementation and video, fix conflict #48

Open · wants to merge 9 commits into base: main
Changes from all commits
3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ eval/diambra
 results.csv
 results/
 
-.DS_Store
+.DS_Store
+.idea
9 changes: 9 additions & 0 deletions README.md
@@ -28,6 +28,15 @@ https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/79b58e26-7902-

https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/5d3d386b-150a-48a5-8f68-7e2954ec18db

### 1 VS 1: Mistral vs Solar
#### mistral-small-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/2a7e681d-d022-486c-9250-68fedff4b069
#### mistral-medium-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/d0532e43-11e2-447e-b2b3-6023b9760f11
#### mistral-large-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/4757d562-f800-40ef-8f1a-675b0e23b8ed


## A new kind of benchmark ?

Street Fighter III assesses the ability of LLMs to understand their environment and take actions based on a specific context.
144 changes: 144 additions & 0 deletions agent/actions.py
@@ -0,0 +1,144 @@
"""
Take observations and return actions for the Robot to use
"""

import os
import random
import re
import time
from typing import List, Optional
from rich import print

from loguru import logger

from agent.language_models import get_sync_client, get_provider_and_model

from .config import MOVES, META_INSTRUCTIONS, META_INSTRUCTIONS_WITH_LOWER
from .prompts import build_main_prompt, build_system_prompt


def call_llm(
context_prompt: str,
character: str,
model: str = "mistral:mistral-large-latest",
temperature: float = 0.3,
max_tokens: int = 20,
top_p: float = 1.0,
wrong_answer: Optional[str] = None,
) -> str:
"""
Get actions from the language model
context_prompt: str, the prompt to describe the situation to the LLM. Will be placed inside the main prompt template.
"""
# Get the correct provider, default is mistral
provider_name, model_name = get_provider_and_model(model)
print("Provider", provider_name, model_name)
client = get_sync_client(provider_name)

# Generate the prompts
system_prompt = build_system_prompt(
character=character, context_prompt=context_prompt
)
main_prompt = build_main_prompt()

start_time = time.time()

completion = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": main_prompt},
],
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
logger.debug(f"LLM call to {model}: {system_prompt}\n\n\n{main_prompt}")
logger.debug(f"LLM call to {model}: {time.time() - start_time} s")

llm_response = completion.choices[0].message.content.strip()

return llm_response


def get_simple_actions_from_llm(
context_prompt: str,
character: str,
model: str = "mistral:mistral-large-latest",
temperature: float = 0.1,
max_tokens: int = 20,
top_p: float = 1.0,
) -> List[int]:
"""
Get actions from the language model
context_prompt: str, the prompt to describe the situation to the LLM.
Return one action and then wait for the next observation

Will be placed inside the main prompt template.
"""
pass


def get_actions_from_llm(
context_prompt: str,
character: str,
model: str = "mistral:mistral-large-latest",
temperature: float = 0.1,
max_tokens: int = 20,
top_p: float = 1.0,
player_nb: int = 0, # 0 for unspecified, 1 for player 1, 2 for player 2
) -> List[str]:
"""
Get actions from the language model
context_prompt: str, the prompt to describe the situation to the LLM.

Will be placed inside the main prompt template.
"""

# Filter the moves that are not in the list of moves
invalid_moves = []
valid_moves = []
wrong_answer = None

# If we are in the test environment, we don't want to call the LLM
if os.getenv("DISABLE_LLM", "False") == "True":
# Choose a random int from the list of moves
logger.debug("DISABLE_LLM is True, returning a random move")
return [random.choice(list(MOVES.values()))]

while len(valid_moves) == 0:
llm_response = call_llm(
context_prompt=context_prompt,
character=character,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
wrong_answer=wrong_answer,
)

# The response is a bullet point list of moves. Use regex
matches = re.findall(r"- ([\w ]+)", llm_response)
moves = ["".join(match) for match in matches]
invalid_moves = []
valid_moves = []
for move in moves:
cleaned_move_name = move.strip().lower()
if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER.keys():
if player_nb == 1:
print(f"[red] Player {player_nb} move: {cleaned_move_name}")
elif player_nb == 2:
print(f"[green] Player {player_nb} move: {cleaned_move_name}")
valid_moves.append(cleaned_move_name)
else:
logger.debug(f"Invalid completion: {move}")
logger.debug(f"Cleaned move name: {cleaned_move_name}")
invalid_moves.append(move)

if len(invalid_moves) > 2:
logger.warning(f"Too many invalid moves: {invalid_moves}")
wrong_answer = llm_response

logger.debug(f"Next moves: {valid_moves}")

return valid_moves
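
For context, here is a minimal usage sketch of `get_actions_from_llm` (not part of the diff). It assumes the repository's requirements are installed; the context prompt and character are hypothetical, the dummy `MISTRAL_API_KEY` only satisfies the import-time assert in `agent/language_models.py`, and `DISABLE_LLM=True` makes the function return a random entry from `MOVES` instead of calling an API.

```python
import os

# Dummy key only to satisfy the import-time assert in agent/language_models.py;
# DISABLE_LLM=True short-circuits the LLM call and returns a random move.
os.environ.setdefault("MISTRAL_API_KEY", "dummy")
os.environ["DISABLE_LLM"] = "True"

from agent.actions import get_actions_from_llm

# Hypothetical context prompt and character, just to exercise the function.
moves = get_actions_from_llm(
    context_prompt="Your opponent is close and currently blocking.",
    character="Ken",
    model="solar:solar-1-mini-chat",
    player_nb=1,
)
print(moves)  # a single random entry from MOVES while DISABLE_LLM is set
```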
14 changes: 9 additions & 5 deletions agent/config.py
@@ -4,16 +4,20 @@

 MODELS = {
     "OPENAI": {
-        "openai:gpt-4-0125-preview",
-        "openai:gpt-4",
+        # "openai:gpt-4-0125-preview",
+        # "openai:gpt-4",
         "openai:gpt-3.5-turbo-0125",
         # "openai:gpt-3.5-turbo-instruct", # not a chat model
     },
     "MISTRAL": {
         "mistral:mistral-small-latest",
-        "mistral:mistral-medium-latest",
-        "mistral:mistral-large-latest",
-        # "groq:mistral-8x6b-32768",
+        # "mistral:mistral-medium-latest",
+        # "mistral:mistral-large-latest",
+        # "mistral:open-mistral-7b",
+        # "mistral:open-mixtral-8x7b",
     },
+    "Solar": {
+        "solar:solar-1-mini-chat",
+    },
 }
 
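
Illustrative sketch (not part of the diff): after this change, `MODELS` maps a provider label to a set of `provider:model` strings, so collecting the models available for a match is a simple flatten. The `active_models` helper and the abbreviated inline dict below are hypothetical; the real dict lives in `agent/config.py`.

```python
from itertools import chain

# Mirrors the shape of MODELS after this change (entries abbreviated).
MODELS = {
    "OPENAI": {"openai:gpt-3.5-turbo-0125"},
    "MISTRAL": {"mistral:mistral-small-latest"},
    "Solar": {"solar:solar-1-mini-chat"},
}


def active_models(models: dict) -> list:
    """Flatten the nested dict into the provider-prefixed strings call_llm() expects."""
    return sorted(chain.from_iterable(models.values()))


print(active_models(MODELS))
# ['mistral:mistral-small-latest', 'openai:gpt-3.5-turbo-0125', 'solar:solar-1-mini-chat']
```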
71 changes: 71 additions & 0 deletions agent/language_models.py
@@ -0,0 +1,71 @@
import os
import dotenv
from typing import Tuple

dotenv.load_dotenv()

try:
    from openai import AsyncOpenAI, OpenAI
except ImportError:
    pass

# Check we can access the environment variables
assert os.getenv("MISTRAL_API_KEY") is not None


def get_async_client(provider: str) -> AsyncOpenAI:
    """
    Provider can be "openai", "mistral" or "ollama".
    """
    if provider == "openai":
        return AsyncOpenAI()
    if provider == "mistral":
        return AsyncOpenAI(
            base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
        )
    if provider == "solar":
        return AsyncOpenAI(
            base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
        )
    if provider == "ollama":
        return AsyncOpenAI(base_url="http://localhost:11434/v1/")
    raise NotImplementedError(f"Provider {provider} is not supported.")


def get_sync_client(provider: str) -> OpenAI:
    if provider == "openai":
        return OpenAI()
    if provider == "mistral":
        return OpenAI(
            base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
        )
    if provider == "solar":
        return OpenAI(
            base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
        )
    if provider == "ollama":
        return OpenAI(base_url="http://localhost:11434/v1/")
    if provider == "groq":
        return OpenAI(
            base_url="https://api.groq.com/openai/v1/",
            api_key=os.getenv("GROK_API_KEY"),
        )
    raise NotImplementedError(f"Provider {provider} is not supported.")


def get_provider_and_model(model: str) -> Tuple[str, str]:
    """
    Get the provider and model from a string in the format "provider:model"
    If no provider is specified, it defaults to "openai"

    Args:
        model (str): The model string in the format "provider:model"

    Returns:
        tuple: A tuple with the provider and model
    """

    split_result = model.split(":")
    if len(split_result) == 1:
        return "openai", split_result[0]
    return split_result[0], split_result[1]
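
A quick sanity check of the parsing behavior above (hypothetical usage, not part of the diff). It assumes the `openai` and `python-dotenv` packages are installed; importing the module requires `MISTRAL_API_KEY` to be set (the assert at the top), and actually creating a Solar client additionally needs `SOLAR_API_KEY` in the environment or in `.env`.

```python
import os

os.environ.setdefault("MISTRAL_API_KEY", "dummy")  # the module asserts this at import time

from agent.language_models import get_provider_and_model

# "provider:model" strings split into their two parts ...
assert get_provider_and_model("solar:solar-1-mini-chat") == ("solar", "solar-1-mini-chat")
# ... and a bare model name falls back to the "openai" provider.
assert get_provider_and_model("gpt-3.5-turbo-0125") == ("openai", "gpt-3.5-turbo-0125")
```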