Skip to content

Fix(multi-agent #20): inherit model from parent/caller in swarm, agent graph, workflow, use_llm and think tools #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ errors/
repl_state/
venv/
*.egg-info
.idea
2 changes: 2 additions & 0 deletions src/strands_tools/agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ def agent_graph(tool: ToolUse, **kwargs: Any) -> ToolResult:
inference_config = kwargs.get("inference_config")
messages = kwargs.get("messages")
tool_config = kwargs.get("tool_config")
agent = kwargs.get("agent")

try:
# Create tool context
Expand All @@ -522,6 +523,7 @@ def agent_graph(tool: ToolUse, **kwargs: Any) -> ToolResult:
"inference_config": inference_config,
"messages": messages,
"tool_config": tool_config,
"agent": agent,
}

# Get manager instance thread-safely
Expand Down
2 changes: 1 addition & 1 deletion src/strands_tools/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def memory(
next_token: Token for pagination in 'list' or 'retrieve' action (optional).
query: The search query for semantic search (required for 'retrieve' action).
min_score: Minimum relevance score threshold (0.0-1.0) for 'retrieve' action. Default is 0.4.
region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable.
region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable.
If AWS_REGION is not specified, it will default to us-west-2.

Returns:
Expand Down
2 changes: 2 additions & 0 deletions src/strands_tools/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,12 +586,14 @@ def swarm(tool: ToolUse, **kwargs: Any) -> ToolResult:
inference_config = kwargs.get("inference_config")
messages = kwargs.get("messages")
tool_config = kwargs.get("tool_config")
agent = kwargs.get("agent")
# Create tool context
tool_context = {
"system_prompt": system_prompt,
"inference_config": inference_config,
"messages": messages,
"tool_config": tool_config,
"agent": agent,
}
if "callback_handler" in kwargs:
tool_context["callback_handler"] = kwargs["callback_handler"]
Expand Down
10 changes: 9 additions & 1 deletion src/strands_tools/think.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,21 @@ def process_cycle(
# Get tools from parent agent if available
tools = []
trace_attributes = {}
extra_kwargs = {}
parent_agent = kwargs.get("agent")
if parent_agent:
tools = list(parent_agent.tool_registry.registry.values())
trace_attributes = parent_agent.trace_attributes
extra_kwargs["model"] = parent_agent.model

# Initialize the new Agent with provided parameters
agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes)
agent = Agent(
messages=[],
tools=tools,
system_prompt=custom_system_prompt,
trace_attributes=trace_attributes,
**extra_kwargs,
)

# Run the agent with the provided prompt
result = agent(prompt)
Expand Down
1 change: 1 addition & 0 deletions src/strands_tools/use_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult:
tools = list(parent_agent.tool_registry.registry.values())
trace_attributes = parent_agent.trace_attributes
extra_kwargs["callback_handler"] = parent_agent.callback_handler
extra_kwargs["model"] = parent_agent.model
if "callback_handler" in kwargs:
extra_kwargs["callback_handler"] = kwargs["callback_handler"]

Expand Down
11 changes: 10 additions & 1 deletion src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def __new__(cls, tool_context: Dict[str, Any]):
def __init__(self, tool_context: Dict[str, Any]):
if not hasattr(self, "initialized"):
# Initialize core attributes
extra_kwargs = {}
parent_agent = tool_context.get("agent", None)
if parent_agent:
extra_kwargs["tools"] = list(parent_agent.tool_registry.registry.values())
extra_kwargs["trace_attributes"] = parent_agent.trace_attributes
extra_kwargs["callback_handler"] = parent_agent.callback_handler
extra_kwargs["model"] = parent_agent.model
self.system_prompt = tool_context["system_prompt"]
self.inference_config = tool_context["inference_config"]
self.messages = tool_context["messages"]
Expand All @@ -221,7 +228,7 @@ def __init__(self, tool_context: Dict[str, Any]):
self.task_executor = TaskExecutor()

# Initialize base agent for task execution
self.base_agent = Agent(system_prompt=self.system_prompt)
self.base_agent = Agent(system_prompt=self.system_prompt, **extra_kwargs)

# Start file watching if not already started
if not self._observer:
Expand Down Expand Up @@ -748,6 +755,7 @@ def workflow(tool: ToolUse, **kwargs: Any) -> ToolResult:
inference_config = kwargs.get("inference_config")
messages = kwargs.get("messages")
tool_config = kwargs.get("tool_config")
agent = kwargs.get("agent")

try:
tool_use_id = tool.get("toolUseId", str(uuid.uuid4()))
Expand All @@ -761,6 +769,7 @@ def workflow(tool: ToolUse, **kwargs: Any) -> ToolResult:
"inference_config": inference_config,
"messages": messages,
"tool_config": tool_config,
"agent": agent,
}
)

Expand Down