Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/application/use_cases/stream_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def execute(self, thread_id: str, message: str) -> AsyncGenerator[StreamEv
elif event.type == StreamEventType.MESSAGE:
final_message = Message.model_validate_json(event.data)
if final_message and final_message.structured_response is not None:
yield StreamEvent(
event = StreamEvent(
type=StreamEventType.STRUCTURED,
data=json.dumps(final_message.structured_response)
)
Expand Down
45 changes: 38 additions & 7 deletions src/infrastructure/deepagent/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,29 @@ def _build_config(self, thread_id: str) -> dict:
config["callbacks"] = callbacks
return config

def _extract_structured_response(self, messages: list) -> dict | None:
"""Extract structured_response from tool_calls in messages."""
if not messages:
return None
# Walk messages in reverse to find the most recent structured_response tool call
for msg in reversed(messages):
tool_calls = getattr(msg, "tool_calls", None) or []
if isinstance(tool_calls, list):
for tc in tool_calls:
if tc.get("name") == "structured_response":
args = tc.get("args")
if args:
if isinstance(args, dict):
return args
if isinstance(args, str):
try:
parsed = json.loads(args)
if isinstance(parsed, dict):
return parsed
except (json.JSONDecodeError, TypeError):
pass
return None

def _build_response(self, result: dict, config: dict, thinking: str | None) -> Message:
messages = result.get("messages", [])
if not messages:
Expand All @@ -74,15 +97,23 @@ def _build_response(self, result: dict, config: dict, thinking: str | None) -> M
all_tool_calls = getattr(last_message, "tool_calls", None) or []
state = self._graph.get_state(config)
status = MessageStatus.AWAITING_HITL if state.interrupts else MessageStatus.COMPLETED
structured_response = None
raw_structured = result.get("structured_response")
if raw_structured is not None:
if hasattr(raw_structured, "model_dump"):
structured_response = raw_structured.model_dump()
elif isinstance(raw_structured, dict):
structured_response = raw_structured

# 1. Try extracting structured_response from tool_calls (ToolStrategy mode)
structured_response = self._extract_structured_response(messages)

# 2. Fallback to result structured_response (ProviderStrategy/ToolStrategy native mode)
if structured_response is None:
raw_structured = result.get("structured_response")
if raw_structured is not None:
if hasattr(raw_structured, "model_dump"):
structured_response = raw_structured.model_dump()
elif isinstance(raw_structured, dict):
structured_response = raw_structured

# 3. Fallback to parsing the last message content as JSON
if structured_response is None:
structured_response = self._try_parse_json(last_message.content)

return Message(
role=MessageRole.AI,
content=last_message.content,
Expand Down
13 changes: 11 additions & 2 deletions src/infrastructure/deepagent/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from deepagents import create_deep_agent
from deepagents.backends import FilesystemBackend, StoreBackend
from langchain.agents.structured_output import ProviderStrategy
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore
Expand Down Expand Up @@ -239,7 +238,17 @@ async def create_agent_from_config(
kwargs["skills"] = config.skills

if config.response_format:
kwargs["response_format"] = ProviderStrategy(config.response_format)
# Use tool-based structured output instead of ProviderStrategy to bypass
# Bedrock schema limitations (max 16 anyOf, max 24 optionals, max grammar size).
# The structured_response tool is injected into the agent's tool list so the LLM
# sees it, but we do NOT pass response_format to avoid create_agent forcing
# tool_choice="any", which suppresses intermediate streaming messages.
# The system prompt is augmented with an instruction to use the tool.
response_tool = _create_response_tool(config.response_format)
all_tools = (all_tools or []) + [response_tool]
kwargs["tools"] = all_tools
current_prompt = kwargs.get("system_prompt", "")
kwargs["system_prompt"] = (current_prompt or "") + STRUCTURED_OUTPUT_INSTRUCTION

subagents = await _resolve_subagents(config, mcp_tool_loader, prompt_manager)
if subagents:
Expand Down
12 changes: 7 additions & 5 deletions tests/unit/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,17 @@ def test_tool_args_schema_has_properties(self):

class TestResponseFormatIntegration:
@patch("src.infrastructure.deepagent.factory.create_deep_agent")
async def test_passes_provider_strategy_when_response_format_set(self, mock_create):
async def test_injects_structured_response_tool_when_response_format_set(self, mock_create):
"""When response_format is set, inject structured_response tool and instruction."""
mock_create.return_value = MagicMock()
config = AgentConfig(name="test", response_format=WEATHER_SCHEMA)
await create_agent_from_config(config)
kwargs = mock_create.call_args.kwargs
assert "response_format" in kwargs
from langchain.agents.structured_output import ProviderStrategy

assert isinstance(kwargs["response_format"], ProviderStrategy)
assert "response_format" not in kwargs
assert "tools" in kwargs
tool_names = [t.name for t in kwargs["tools"]]
assert "structured_response" in tool_names
assert "structured_response" in kwargs.get("system_prompt", "")

@patch("src.infrastructure.deepagent.factory.create_deep_agent")
async def test_omits_response_format_when_none(self, mock_create):
Expand Down
43 changes: 19 additions & 24 deletions tests/unit/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ async def test_stream_ends_with_done(self, client):
]
assert data_lines[-1] == "[DONE]", f"Expected [DONE] as last data line, got: {data_lines[-1]}"

async def test_stream_emits_message_json_before_done(self, client):
"""Stream emits Message JSON before [DONE], even with a preceding structured event."""
async def test_stream_emits_structured_event_before_done(self, client):
"""Stream emits a STRUCTURED event (from MESSAGE with structured_response) before [DONE]."""
create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"})
thread_id = create_resp.json()["id"]
resp = await client.post(
Expand All @@ -443,21 +443,19 @@ async def test_stream_emits_message_json_before_done(self, client):

assert data_lines[-1] == "[DONE]"

# Find the last message event before [DONE]
message_event = None
structured_event = None
for line in reversed(data_lines[:-1]):
event = json.loads(line)
if event.get("type") == "message":
message_event = event
if event.get("type") == "structured":
structured_event = event
break

assert message_event is not None, "No message event found in stream"
message_json = json.loads(message_event["data"])
assert message_json["role"] == "ai"
assert message_json["structured_response"] == {"key": "value"}
assert structured_event is not None, "No structured event found in stream"
structured_data = json.loads(structured_event["data"])
assert structured_data == {"key": "value"}

async def test_stream_message_format_matches_sync(self, client):
"""The Message JSON from stream has the same fields as the sync endpoint."""
"""The stream yields content events and a final structured event with expected fields."""
create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"})
thread_id = create_resp.json()["id"]

Expand All @@ -483,20 +481,17 @@ async def test_stream_message_format_matches_sync(self, client):
if line.strip().startswith("data:")
]

# Locate the message event (ignore structured events, [DONE], etc.)
message_event = None
content_events = []
structured_event = None
for line in data_lines:
if line == "[DONE]":
continue
event = json.loads(line)
if event.get("type") == "message":
message_event = event
break

assert message_event is not None, "No message event found in stream"
message_json = json.loads(message_event["data"])

for field in ["role", "content", "timestamp", "status"]:
assert field in message_json, f"Missing field {field!r} in stream Message: {message_json}"

assert message_json["role"] == sync_data["role"]
if event.get("type") == "content":
content_events.append(event)
elif event.get("type") == "structured":
structured_event = event

assert len(content_events) > 0, "No content events found in stream"
assert structured_event is not None, "No structured event found in stream"
assert sync_data["role"] == "ai"
Loading