Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agentstack/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .cli import init_project_builder, list_tools
from .cli import init_project_builder, list_tools, configure_default_model
29 changes: 29 additions & 0 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@
from .agentstack_data import FrameworkData, ProjectMetadata, ProjectStructure, CookiecutterData
from agentstack.logger import log
from agentstack.utils import get_package_path
from agentstack.generation.files import ConfigFile
from agentstack.generation.tool_generation import get_all_tools
from .. import generation
from ..utils import open_json_file, term_color, is_snake_case

PREFERRED_MODELS = [
'openai/gpt-4o',
'anthropic/claude-3-5-sonnet',
'openai/o1-preview',
'openai/gpt-4-turbo',
'anthropic/claude-3-opus',
]

def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False):
if slug_name and not is_snake_case(slug_name):
Expand Down Expand Up @@ -114,6 +122,27 @@ def welcome_message():
print(border)


def configure_default_model(path: Optional[str] = None):
"""Set the default model"""
agentstack_config = ConfigFile(path)
if agentstack_config.default_model:
return # Default model already set

print("Project does not have a default model configured.")
other_msg = f"Other (enter a model name)"
model = inquirer.list_input(
message="Which model would you like to use?",
choices=PREFERRED_MODELS + [other_msg],
)

if model == other_msg: # If the user selects "Other", prompt for a model name
print(f'A list of available models is available at: "https://docs.litellm.ai/docs/providers"')
model = inquirer.text(message="Enter the model name")

with ConfigFile(path) as agentstack_config:
agentstack_config.default_model = model


def ask_framework() -> str:
framework = "CrewAI"
# framework = inquirer.list_input(
Expand Down
7 changes: 3 additions & 4 deletions agentstack/generation/agent_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .gen_utils import insert_code_after_tag, get_crew_components, CrewComponent
from agentstack.utils import verify_agentstack_project, get_framework
from agentstack.generation.files import ConfigFile
import os
from ruamel.yaml import YAML
from ruamel.yaml.scalarstring import FoldedScalarString
Expand All @@ -14,14 +15,15 @@ def generate_agent(
backstory: Optional[str],
llm: Optional[str]
):
agentstack_config = ConfigFile() # TODO path
if not role:
role = 'Add your role here'
if not goal:
goal = 'Add your goal here'
if not backstory:
backstory = 'Add your backstory here'
if not llm:
llm = 'openai/gpt-4o'
llm = agentstack_config.default_model

verify_agentstack_project()

Expand All @@ -37,9 +39,6 @@ def generate_agent(
print(f"Added agent \"{name}\" to your AgentStack project successfully!")





def generate_crew_agent(
name,
role: Optional[str] = 'Add your role here',
Expand Down
3 changes: 3 additions & 0 deletions agentstack/generation/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ class ConfigFile(BaseModel):
A list of tools that are currently installed in the project.
telemetry_opt_out: Optional[bool]
Whether the user has opted out of telemetry.
default_model: Optional[str]
The default model to use when generating agent configurations.
"""
framework: Optional[str] = DEFAULT_FRAMEWORK
tools: list[str] = []
telemetry_opt_out: Optional[bool] = None
default_model: Optional[str] = None

def __init__(self, path: Union[str, Path, None] = None):
path = Path(path) if path else Path.cwd()
Expand Down
4 changes: 3 additions & 1 deletion agentstack/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys

from agentstack.cli import init_project_builder, list_tools
from agentstack.cli import init_project_builder, list_tools, configure_default_model
from agentstack.telemetry import track_cli_command
from agentstack.utils import get_version, get_framework
import agentstack.generation as generation
Expand Down Expand Up @@ -102,6 +102,8 @@ def main():
os.system('python src/main.py')
elif args.command in ['generate', 'g']:
if args.generate_command in ['agent', 'a']:
if not args.llm:
configure_default_model()
generation.generate_agent(args.name, args.role, args.goal, args.backstory, args.llm)
elif args.generate_command in ['task', 't']:
generation.generate_task(args.name, args.description, args.expected_output, args.agent)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_generation_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_read_config(self):
assert config.framework == "crewai"
assert config.tools == ["tool1", "tool2"]
assert config.telemetry_opt_out is None
assert config.default_model is None

def test_write_config(self):
try:
Expand All @@ -25,6 +26,7 @@ def test_write_config(self):
config.framework = "crewai"
config.tools = ["tool1", "tool2"]
config.telemetry_opt_out = True
config.default_model = "openai/gpt-4o"

tmp_data = open(BASE_PATH/"tmp/agentstack.json").read()
assert tmp_data == """{
Expand All @@ -33,7 +35,8 @@ def test_write_config(self):
"tool1",
"tool2"
],
"telemetry_opt_out": true
"telemetry_opt_out": true,
"default_model": "openai/gpt-4o"
}"""
except Exception as e:
raise e
Expand Down
Loading