|
18 | 18 | import logging |
19 | 19 | import time |
20 | 20 | from dataclasses import dataclass, field |
21 | | -from typing import Any, Callable, Tuple |
| 21 | +from typing import Any, Callable, Tuple, cast |
22 | 22 |
|
23 | 23 | from opentelemetry import trace as trace_api |
24 | 24 |
|
@@ -127,7 +127,7 @@ def _validate_json_serializable(self, value: Any) -> None: |
127 | 127 | class SwarmState: |
128 | 128 | """Current state of swarm execution.""" |
129 | 129 |
|
130 | | - current_node: SwarmNode # The agent currently executing |
| 130 | + current_node: SwarmNode | None # The agent currently executing |
131 | 131 | task: str | list[ContentBlock] # The original task from the user that is being executed |
132 | 132 | completion_status: Status = Status.PENDING # Current swarm execution status |
133 | 133 | shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents |
@@ -232,7 +232,7 @@ def __init__( |
232 | 232 | self.shared_context = SharedContext() |
233 | 233 | self.nodes: dict[str, SwarmNode] = {} |
234 | 234 | self.state = SwarmState( |
235 | | - current_node=SwarmNode("", Agent()), # Placeholder, will be set properly |
| 235 | + current_node=None, # Placeholder, will be set properly |
236 | 236 | task="", |
237 | 237 | completion_status=Status.PENDING, |
238 | 238 | ) |
@@ -291,7 +291,8 @@ async def invoke_async( |
291 | 291 | span = self.tracer.start_multiagent_span(task, "swarm") |
292 | 292 | with trace_api.use_span(span, end_on_exit=True): |
293 | 293 | try: |
294 | | - logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id) |
| 294 | + current_node = cast(SwarmNode, self.state.current_node) |
| 295 | + logger.debug("current_node=<%s> | starting swarm execution with node", current_node.node_id) |
295 | 296 | logger.debug( |
296 | 297 | "max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config", |
297 | 298 | self.max_handoffs, |
@@ -438,7 +439,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st |
438 | 439 | return |
439 | 440 |
|
440 | 441 | # Update swarm state |
441 | | - previous_agent = self.state.current_node |
| 442 | + previous_agent = cast(SwarmNode, self.state.current_node) |
442 | 443 | self.state.current_node = target_node |
443 | 444 |
|
444 | 445 | # Store handoff message for the target agent |
|
0 commit comments