Skip to content

Commit

Permalink
improve the readability of get_champion_model()
Browse files Browse the repository at this point in the history
improve the readability of get_champion_model()
  • Loading branch information
chjuncn committed Jan 17, 2025
1 parent c6dddef commit 1fe61a2
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions src/palimpzest/utils/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,36 @@ def get_models(include_vision: Optional[bool] = False) -> List[Model]:

return models


def get_champion_model(available_models, vision=False):
champion_model = None

# non-vision
if not vision and Model.GPT_4o in available_models:
champion_model = Model.GPT_4o
elif not vision and Model.GPT_4o_MINI in available_models:
champion_model = Model.GPT_4o_MINI
elif not vision and Model.LLAMA3 in available_models:
champion_model = Model.LLAMA3
elif not vision and Model.MIXTRAL in available_models:
champion_model = Model.MIXTRAL

# vision
elif vision and Model.GPT_4o_V in available_models:
champion_model = Model.GPT_4o_V
elif vision and Model.GPT_4o_MINI_V in available_models:
champion_model = Model.GPT_4o_MINI_V
elif vision and Model.LLAMA3_V in available_models:
champion_model = Model.LLAMA3_V

else:
raise Exception(
"No models available to create physical plans! You must set at least one of the following environment"
" variables: [OPENAI_API_KEY, TOGETHER_API_KEY, GOOGLE_API_KEY]\n"
f"available_models: {available_models}"
)

return champion_model
# The order is the priority of the model
TEXT_MODEL_PRIORITY = [
Model.GPT_4o,
Model.GPT_4o_MINI,
Model.LLAMA3,
Model.MIXTRAL
]

VISION_MODEL_PRIORITY = [
Model.GPT_4o_V,
Model.GPT_4o_MINI_V,
Model.LLAMA3_V
]
def get_champion_model(available_models, vision=False):
# Select appropriate priority list based on task
model_priority = VISION_MODEL_PRIORITY if vision else TEXT_MODEL_PRIORITY

# Return first available model from priority list
for model in model_priority:
if model in available_models:
return model
# If no suitable model found, raise informative error
task_type = "vision" if vision else "text"
raise Exception(
f"No {task_type} models available to create physical plans!\n"
"You must set at least one of the following environment variables:\n"
"[OPENAI_API_KEY, TOGETHER_API_KEY, GOOGLE_API_KEY]\n"
f"Available models: {available_models}"
)


def get_conventional_fallback_model(available_models, vision=False):
Expand Down

0 comments on commit 1fe61a2

Please sign in to comment.