diff --git a/agentstack/cli/__init__.py b/agentstack/cli/__init__.py index 3c35ec37..afd42af5 100644 --- a/agentstack/cli/__init__.py +++ b/agentstack/cli/__init__.py @@ -1 +1 @@ -from .cli import init_project_builder, list_tools +from .cli import init_project_builder, list_tools, configure_default_model diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 9b560d16..f10866b3 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -16,10 +16,18 @@ from .agentstack_data import FrameworkData, ProjectMetadata, ProjectStructure, CookiecutterData from agentstack.logger import log from agentstack.utils import get_package_path +from agentstack.generation.files import ConfigFile from agentstack.generation.tool_generation import get_all_tools from .. import generation from ..utils import open_json_file, term_color, is_snake_case +PREFERRED_MODELS = [ + 'openai/gpt-4o', + 'anthropic/claude-3-5-sonnet', + 'openai/o1-preview', + 'openai/gpt-4-turbo', + 'anthropic/claude-3-opus', +] def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False): if slug_name and not is_snake_case(slug_name): @@ -114,6 +122,27 @@ def welcome_message(): print(border) +def configure_default_model(path: Optional[str] = None): + """Set the default model""" + agentstack_config = ConfigFile(path) + if agentstack_config.default_model: + return # Default model already set + + print("Project does not have a default model configured.") + other_msg = f"Other (enter a model name)" + model = inquirer.list_input( + message="Which model would you like to use?", + choices=PREFERRED_MODELS + [other_msg], + ) + + if model == other_msg: # If the user selects "Other", prompt for a model name + print(f'A list of available models is available at: "https://docs.litellm.ai/docs/providers"') + model = inquirer.text(message="Enter the model name") + + with ConfigFile(path) as agentstack_config: + agentstack_config.default_model = model + + def ask_framework() -> str: framework = "CrewAI" # framework = inquirer.list_input( diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index f13a5d9f..bf64dd2e 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -2,6 +2,7 @@ from .gen_utils import insert_code_after_tag, get_crew_components, CrewComponent from agentstack.utils import verify_agentstack_project, get_framework +from agentstack.generation.files import ConfigFile import os from ruamel.yaml import YAML from ruamel.yaml.scalarstring import FoldedScalarString @@ -14,6 +15,7 @@ def generate_agent( backstory: Optional[str], llm: Optional[str] ): + agentstack_config = ConfigFile() # TODO path if not role: role = 'Add your role here' if not goal: @@ -21,7 +23,7 @@ def generate_agent( if not backstory: backstory = 'Add your backstory here' if not llm: - llm = 'openai/gpt-4o' + llm = agentstack_config.default_model verify_agentstack_project() @@ -37,9 +39,6 @@ def generate_agent( print(f"Added agent \"{name}\" to your AgentStack project successfully!") - - - def generate_crew_agent( name, role: Optional[str] = 'Add your role here', diff --git a/agentstack/generation/files.py b/agentstack/generation/files.py index 0fc1fb14..b1c226c3 100644 --- a/agentstack/generation/files.py +++ b/agentstack/generation/files.py @@ -31,10 +31,13 @@ class ConfigFile(BaseModel): A list of tools that are currently installed in the project. telemetry_opt_out: Optional[bool] Whether the user has opted out of telemetry. + default_model: Optional[str] + The default model to use when generating agent configurations. """ framework: Optional[str] = DEFAULT_FRAMEWORK tools: list[str] = [] telemetry_opt_out: Optional[bool] = None + default_model: Optional[str] = None def __init__(self, path: Union[str, Path, None] = None): path = Path(path) if path else Path.cwd() diff --git a/agentstack/main.py b/agentstack/main.py index 14a448cf..77a7ed7f 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -2,7 +2,7 @@ import os import sys -from agentstack.cli import init_project_builder, list_tools +from agentstack.cli import init_project_builder, list_tools, configure_default_model from agentstack.telemetry import track_cli_command from agentstack.utils import get_version, get_framework import agentstack.generation as generation @@ -102,6 +102,8 @@ def main(): os.system('python src/main.py') elif args.command in ['generate', 'g']: if args.generate_command in ['agent', 'a']: + if not args.llm: + configure_default_model() generation.generate_agent(args.name, args.role, args.goal, args.backstory, args.llm) elif args.generate_command in ['task', 't']: generation.generate_task(args.name, args.description, args.expected_output, args.agent) diff --git a/tests/test_generation_files.py b/tests/test_generation_files.py index 8f8549e3..e2d80d7e 100644 --- a/tests/test_generation_files.py +++ b/tests/test_generation_files.py @@ -14,6 +14,7 @@ def test_read_config(self): assert config.framework == "crewai" assert config.tools == ["tool1", "tool2"] assert config.telemetry_opt_out is None + assert config.default_model is None def test_write_config(self): try: @@ -25,6 +26,7 @@ def test_write_config(self): config.framework = "crewai" config.tools = ["tool1", "tool2"] config.telemetry_opt_out = True + config.default_model = "openai/gpt-4o" tmp_data = open(BASE_PATH/"tmp/agentstack.json").read() assert tmp_data == """{ @@ -33,7 +35,8 @@ def test_write_config(self): "tool1", "tool2" ], - "telemetry_opt_out": true + "telemetry_opt_out": true, + "default_model": "openai/gpt-4o" }""" except Exception as e: raise e