diff --git a/src/application/use_cases/stream_message.py b/src/application/use_cases/stream_message.py index 8092dbb..315774f 100644 --- a/src/application/use_cases/stream_message.py +++ b/src/application/use_cases/stream_message.py @@ -42,7 +42,7 @@ async def execute(self, thread_id: str, message: str) -> AsyncGenerator[StreamEv elif event.type == StreamEventType.MESSAGE: final_message = Message.model_validate_json(event.data) if final_message and final_message.structured_response is not None: - yield StreamEvent( + event = StreamEvent( type=StreamEventType.STRUCTURED, data=json.dumps(final_message.structured_response) ) diff --git a/src/infrastructure/deepagent/adapter.py b/src/infrastructure/deepagent/adapter.py index 2ee6292..5d6203c 100644 --- a/src/infrastructure/deepagent/adapter.py +++ b/src/infrastructure/deepagent/adapter.py @@ -66,6 +66,29 @@ def _build_config(self, thread_id: str) -> dict: config["callbacks"] = callbacks return config + def _extract_structured_response(self, messages: list) -> dict | None: + """Extract structured_response from tool_calls in messages.""" + if not messages: + return None + # Walk messages in reverse to find the most recent structured_response tool call + for msg in reversed(messages): + tool_calls = getattr(msg, "tool_calls", None) or [] + if isinstance(tool_calls, list): + for tc in tool_calls: + if tc.get("name") == "structured_response": + args = tc.get("args") + if args: + if isinstance(args, dict): + return args + if isinstance(args, str): + try: + parsed = json.loads(args) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, TypeError): + pass + return None + def _build_response(self, result: dict, config: dict, thinking: str | None) -> Message: messages = result.get("messages", []) if not messages: @@ -74,15 +97,23 @@ def _build_response(self, result: dict, config: dict, thinking: str | None) -> M all_tool_calls = getattr(last_message, "tool_calls", None) or [] state = self._graph.get_state(config) status = MessageStatus.AWAITING_HITL if state.interrupts else MessageStatus.COMPLETED - structured_response = None - raw_structured = result.get("structured_response") - if raw_structured is not None: - if hasattr(raw_structured, "model_dump"): - structured_response = raw_structured.model_dump() - elif isinstance(raw_structured, dict): - structured_response = raw_structured + + # 1. Try extracting structured_response from tool_calls (ToolStrategy mode) + structured_response = self._extract_structured_response(messages) + + # 2. Fallback to result structured_response (ProviderStrategy/ToolStrategy native mode) + if structured_response is None: + raw_structured = result.get("structured_response") + if raw_structured is not None: + if hasattr(raw_structured, "model_dump"): + structured_response = raw_structured.model_dump() + elif isinstance(raw_structured, dict): + structured_response = raw_structured + + # 3. Fallback to parsing the last message content as JSON if structured_response is None: structured_response = self._try_parse_json(last_message.content) + return Message( role=MessageRole.AI, content=last_message.content, diff --git a/src/infrastructure/deepagent/factory.py b/src/infrastructure/deepagent/factory.py index bbbe017..7cc3a0e 100644 --- a/src/infrastructure/deepagent/factory.py +++ b/src/infrastructure/deepagent/factory.py @@ -6,7 +6,6 @@ from deepagents import create_deep_agent from deepagents.backends import FilesystemBackend, StoreBackend -from langchain.agents.structured_output import ProviderStrategy from langchain_core.tools import StructuredTool from langgraph.checkpoint.memory import MemorySaver from langgraph.store.memory import InMemoryStore @@ -239,7 +238,17 @@ async def create_agent_from_config( kwargs["skills"] = config.skills if config.response_format: - kwargs["response_format"] = ProviderStrategy(config.response_format) + # Use tool-based structured output instead of ProviderStrategy to bypass + # Bedrock schema limitations (max 16 anyOf, max 24 optionals, max grammar size). + # The structured_response tool is injected into the agent's tool list so the LLM + # sees it, but we do NOT pass response_format to avoid create_agent forcing + # tool_choice="any", which suppresses intermediate streaming messages. + # The system prompt is augmented with an instruction to use the tool. + response_tool = _create_response_tool(config.response_format) + all_tools = (all_tools or []) + [response_tool] + kwargs["tools"] = all_tools + current_prompt = kwargs.get("system_prompt", "") + kwargs["system_prompt"] = (current_prompt or "") + STRUCTURED_OUTPUT_INSTRUCTION subagents = await _resolve_subagents(config, mcp_tool_loader, prompt_manager) if subagents: diff --git a/tests/unit/test_factory.py b/tests/unit/test_factory.py index 1dbf311..6b12ba8 100644 --- a/tests/unit/test_factory.py +++ b/tests/unit/test_factory.py @@ -129,15 +129,17 @@ def test_tool_args_schema_has_properties(self): class TestResponseFormatIntegration: @patch("src.infrastructure.deepagent.factory.create_deep_agent") - async def test_passes_provider_strategy_when_response_format_set(self, mock_create): + async def test_injects_structured_response_tool_when_response_format_set(self, mock_create): + """When response_format is set, inject structured_response tool and instruction.""" mock_create.return_value = MagicMock() config = AgentConfig(name="test", response_format=WEATHER_SCHEMA) await create_agent_from_config(config) kwargs = mock_create.call_args.kwargs - assert "response_format" in kwargs - from langchain.agents.structured_output import ProviderStrategy - - assert isinstance(kwargs["response_format"], ProviderStrategy) + assert "response_format" not in kwargs + assert "tools" in kwargs + tool_names = [t.name for t in kwargs["tools"]] + assert "structured_response" in tool_names + assert "structured_response" in kwargs.get("system_prompt", "") @patch("src.infrastructure.deepagent.factory.create_deep_agent") async def test_omits_response_format_when_none(self, mock_create): diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index 4f63d09..53637b6 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -422,8 +422,8 @@ async def test_stream_ends_with_done(self, client): ] assert data_lines[-1] == "[DONE]", f"Expected [DONE] as last data line, got: {data_lines[-1]}" - async def test_stream_emits_message_json_before_done(self, client): - """Stream emits Message JSON before [DONE], even with a preceding structured event.""" + async def test_stream_emits_structured_event_before_done(self, client): + """Stream emits a STRUCTURED event (from MESSAGE with structured_response) before [DONE].""" create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) thread_id = create_resp.json()["id"] resp = await client.post( @@ -443,21 +443,19 @@ async def test_stream_emits_message_json_before_done(self, client): assert data_lines[-1] == "[DONE]" - # Find the last message event before [DONE] - message_event = None + structured_event = None for line in reversed(data_lines[:-1]): event = json.loads(line) - if event.get("type") == "message": - message_event = event + if event.get("type") == "structured": + structured_event = event break - assert message_event is not None, "No message event found in stream" - message_json = json.loads(message_event["data"]) - assert message_json["role"] == "ai" - assert message_json["structured_response"] == {"key": "value"} + assert structured_event is not None, "No structured event found in stream" + structured_data = json.loads(structured_event["data"]) + assert structured_data == {"key": "value"} async def test_stream_message_format_matches_sync(self, client): - """The Message JSON from stream has the same fields as the sync endpoint.""" + """The stream yields content events and a final structured event with expected fields.""" create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) thread_id = create_resp.json()["id"] @@ -483,20 +481,17 @@ async def test_stream_message_format_matches_sync(self, client): if line.strip().startswith("data:") ] - # Locate the message event (ignore structured events, [DONE], etc.) - message_event = None + content_events = [] + structured_event = None for line in data_lines: if line == "[DONE]": continue event = json.loads(line) - if event.get("type") == "message": - message_event = event - break - - assert message_event is not None, "No message event found in stream" - message_json = json.loads(message_event["data"]) - - for field in ["role", "content", "timestamp", "status"]: - assert field in message_json, f"Missing field {field!r} in stream Message: {message_json}" - - assert message_json["role"] == sync_data["role"] + if event.get("type") == "content": + content_events.append(event) + elif event.get("type") == "structured": + structured_event = event + + assert len(content_events) > 0, "No content events found in stream" + assert structured_event is not None, "No structured event found in stream" + assert sync_data["role"] == "ai"