Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add solar implementation and video #45

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/79b58e26-7902-

https://github.com/OpenGenerativeAI/llm-colosseum/assets/19614572/5d3d386b-150a-48a5-8f68-7e2954ec18db


### 1 VS 1: Mistral vs Solar
#### mistral-small-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/2a7e681d-d022-486c-9250-68fedff4b069
#### mistral-medium-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/d0532e43-11e2-447e-b2b3-6023b9760f11
#### mistral-large-latest vs solar-1-mini-chat
https://github.com/Tokkiu/llm-colosseum/assets/13414571/4757d562-f800-40ef-8f1a-675b0e23b8ed


## A new kind of benchmark ?
Expand Down
1 change: 1 addition & 0 deletions agent/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def call_llm(
"""
# Get the correct provider, default is mistral
provider_name, model_name = get_provider_and_model(model)
print("Provider", provider_name, model_name)
client = get_sync_client(provider_name)

# Generate the prompts
Expand Down
14 changes: 9 additions & 5 deletions agent/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
MODELS = {
"OPENAI": {
"openai:gpt-4-0125-preview",
"openai:gpt-4",
# "openai:gpt-4-0125-preview",
# "openai:gpt-4",
"openai:gpt-3.5-turbo-0125",
# "openai:gpt-3.5-turbo-instruct", # not a chat model
},
"MISTRAL": {
"mistral:mistral-small-latest",
"mistral:mistral-medium-latest",
"mistral:mistral-large-latest",
# "groq:mistral-8x6b-32768",
# "mistral:mistral-medium-latest",
# "mistral:mistral-large-latest",
# "mistral:open-mistral-7b",
# "mistral:open-mixtral-8x7b",
},
"Solar": {
"solar:solar-1-mini-chat",
},
}

Expand Down
8 changes: 8 additions & 0 deletions agent/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def get_async_client(provider: str) -> AsyncOpenAI:
return AsyncOpenAI(
base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
)
if provider == "solar":
return AsyncOpenAI(
base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
)
if provider == "ollama":
return AsyncOpenAI(base_url="http://localhost:11434/v1/")
raise NotImplementedError(f"Provider {provider} is not supported.")
Expand All @@ -35,6 +39,10 @@ def get_sync_client(provider: str) -> OpenAI:
return OpenAI(
base_url="https://api.mistral.ai/v1/", api_key=os.getenv("MISTRAL_API_KEY")
)
if provider == "solar":
return OpenAI(
base_url="https://api.upstage.ai/v1/solar", api_key=os.getenv("SOLAR_API_KEY")
)
if provider == "ollama":
return OpenAI(base_url="http://localhost:11434/v1/")
if provider == "groq":
Expand Down
25 changes: 17 additions & 8 deletions eval/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
from agent.config import MODELS


def generate_model(openai: bool = False, mistral: bool = True):
def generate_model(openai: bool = False, mistral: bool = True, solar: bool = False):
models_available = []

for model, models in MODELS.items():
if openai and model == "OPENAI":
models_available.extend(models)
if mistral and model == "MISTRAL":
models_available.extend(models)

if solar and model == "Solar":
models_available.extend(models)
print('models:', models_available)
random.seed()
# Generate a pair of random two models
random_model = random.choice(models_available)
Expand All @@ -49,14 +51,15 @@ def __init__(
model: Optional[str] = None,
openai: bool = False,
mistral: bool = True,
solar: bool = False,
):
self.nickname = nickname
self.model = model or generate_model(openai=openai, mistral=mistral)
self.model = model or generate_model(openai=openai, mistral=mistral, solar=solar)
self.openai = False
self.mistral = True
self.robot = Robot(
action_space=None,
character="Ken",
character="Mistral",
side=0,
character_color=KEN_RED,
ennemy_color=KEN_GREEN,
Expand All @@ -74,12 +77,13 @@ def __init__(
model: Optional[str] = None,
openai: bool = False,
mistral: bool = True,
solar: bool = False,
):
self.nickname = nickname
self.model = model or generate_model(openai=openai, mistral=mistral)
self.model = model or generate_model(openai=openai, mistral=mistral, solar=solar)
self.robot = Robot(
action_space=None,
character="Ken",
character="Solar",
side=1,
character_color=KEN_GREEN,
ennemy_color=KEN_RED,
Expand Down Expand Up @@ -120,6 +124,7 @@ class Game:
render: Optional[bool] = False
splash_screen: Optional[bool] = False
save_game: Optional[bool] = False
# characters: Optional[List[str]] = ["Makoto", "Sean"]
characters: Optional[List[str]] = ["Ken", "Ken"]
outfits: Optional[List[int]] = [1, 3]
frame_shape: Optional[List[int]] = [0, 0, 0]
Expand All @@ -137,6 +142,7 @@ def __init__(
save_game: bool = False,
splash_screen: bool = False,
characters: List[str] = ["Ken", "Ken"],
# characters: List[str] = ["Makoto", "Sean"],
super_arts: List[int] = [3,3],
outfits: List[int] = [1, 3],
frame_shape: List[int] = [0, 0, 0],
Expand All @@ -145,6 +151,7 @@ def __init__(
player_2: Player2 = None,
openai: bool = False,
mistral: bool = True,
solar: bool = False,
):
"""_summary_

Expand All @@ -169,17 +176,19 @@ def __init__(
self.observation, self.info = self.env.reset(seed=self.seed)
self.openai = openai
self.mistral = mistral
self.solar = solar
print("GAME", openai, mistral, solar, player_1, player_2)
self.player_1 = (
player_1
if player_1
else Player1(nickname="Player 1", openai=self.openai, mistral=self.mistral)
else Player1(nickname="Player 1", openai=False, mistral=False, solar=True)
# else Player1(nickname="Player 1", model="grok:mixtral-8x7b-32768")
)
self.player_2 = (
player_2
if player_2
# else Player2(nickname="Player 2", model="openai:gpt-4-turbo-preview")
else Player2(nickname="Player 2", openai=self.openai, mistral=self.mistral)
else Player2(nickname="Player 2", openai=False, mistral=True, solar=False)
)

def _init_settings(self) -> EnvironmentSettingsMultiAgent:
Expand Down
2 changes: 1 addition & 1 deletion script.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
# Environment Settings
# Environment Settings
while True:
game = Game(render=True, save_game=True, openai=True)
game = Game(render=True, save_game=True, openai=True, solar=True, mistral=False)

game.run()

Expand Down