diff --git a/README.md b/README.md index 54349d0..23347ac 100644 --- a/README.md +++ b/README.md @@ -444,24 +444,38 @@ curl -N -X POST http://localhost:8000/api/v1/chat/a1b2c3d4-e5f6-7890-abcd-ef1234 Response (Server-Sent Events): ``` -data: Lines -data: of -data: code -data: align -data: ... -event: message -data: {"role":"ai","content":"Lines of code align...","timestamp":"2025-04-24T10:30:05.000000Z","tool_calls":null,"status":"completed","structured_response":null} +data: {"type":"thinking","data":"Hmm, a haiku needs 5-7-5 syllables..."} -event: done -data: +data: {"type":"content","data":"Lines"} + +data: {"type":"content","data":" of"} + +data: {"type":"content","data":" code"} + +data: {"type":"content","data":" align"} + +data: {"type":"content","data":"..."} + +data: {"type":"message","data":"{\"role\":\"ai\",\"content\":\"Lines of code align...\",\"timestamp\":\"2025-04-24T10:30:05.000000Z\",\"tool_calls\":null,\"status\":\"completed\",\"structured_response\":null,\"thinking\":\"Hmm, a haiku needs 5-7-5 syllables...\"}"} + +data: [DONE] ``` -The stream emits: -1. **`data: `** — text chunks as they are generated (keepalive for proxies like Cloudflare) -2. **`event: message`** — the complete `Message` JSON with all fields (role, content, timestamp, tool_calls, status, structured_response), identical in format to the synchronous `POST /chat/{thread_id}` response -3. **`event: done`** — signals the stream is complete +The stream emits **typed `StreamEvent` JSON objects** over SSE: -This design prevents Cloudflare timeout issues (~100s on idle connections) because chunks and SSE pings (every 15s) keep the connection active. Clients that need the full structured response can read the `event: message` data. +| `type` | Description | Persisted? | +|---|---|---| +| `thinking` | Reasoning / chain-of-thought tokens from extended-thinking models (e.g., Claude reasoning). 
| Yes — saved in `Message.thinking` | | `content` | Response text / markdown tokens as they are generated. | Yes — aggregated into `Message.content` | | `message` | The final complete `Message` JSON with all fields (`role`, `content`, `timestamp`, `tool_calls`, `status`, `structured_response`, `thinking`). Identical in format to the synchronous `POST /chat/{thread_id}` response. | Yes — persisted as the AI turn in the thread | +| `structured` | Emitted once, just before `message`, when the agent produced a structured response; `data` is the JSON-encoded `structured_response`. | Yes — in `Message.structured_response` | +| `error` | Emitted if the stream fails; `data` is the error description. | No | + +The stream ends with `data: [DONE]`. + +This design prevents Cloudflare timeout issues (~100s on idle connections) because chunks and SSE pings (every 15s) keep the connection active. Clients can switch rendering based on `type`: + +- Render `thinking` events in a collapsible reasoning panel. - Append `content` events directly to the chat bubble. - Wait for the `message` event to finalize metadata (status, structure, tool calls). ### 7. List All Threads @@ -711,20 +725,18 @@ ws.onopen = () => ws.send(JSON.stringify({ message: "Hello" })); ws.onmessage = (event) => { if (event.data === "[END]") { console.log("Response complete"); - } else { - try { - const data = JSON.parse(event.data); - if (data.type === "message") { - console.log("Final message:", data); - } - } catch { - process.stdout.write(event.data); - } + return; + } + const data = JSON.parse(event.data); + switch (data.type) { + case "thinking": console.log("[Thinking]", data.data); break; + case "content": process.stdout.write(data.data); break; + case "message": console.log("Final message:", data.data); break; } }; ``` -The WebSocket stream emits typed `StreamEvent` JSON objects: `thinking` (reasoning tokens), `content` (response text), `message` (final full `Message` JSON), plus `structured` and `error` when applicable, then `[END]`. 
--- diff --git a/src/alembic/versions/004_add_thinking_column.py b/src/alembic/versions/004_add_thinking_column.py new file mode 100644 index 0000000..12ef8c8 --- /dev/null +++ b/src/alembic/versions/004_add_thinking_column.py @@ -0,0 +1,26 @@ +"""Add thinking column to messages. + +Revision ID: 004 +Revises: 003 +Create Date: 2026-04-29 +""" + +from collections.abc import Sequence + +from sqlalchemy import Column as saColumn +from sqlalchemy import Text + +from alembic import op + +revision: str = "004" +down_revision: str | None = "003" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column("messages", saColumn("thinking", Text, nullable=True)) + + +def downgrade() -> None: + op.drop_column("messages", "thinking") diff --git a/src/application/requests/chat.py b/src/application/requests/chat.py index 18764e6..9307ed2 100644 --- a/src/application/requests/chat.py +++ b/src/application/requests/chat.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class ChatRequest(BaseModel): diff --git a/src/application/routes/agents.py b/src/application/routes/agents.py index 1aca93b..f38e207 100644 --- a/src/application/routes/agents.py +++ b/src/application/routes/agents.py @@ -19,7 +19,7 @@ from src.domain.entities.agent_config import AgentConfig from src.domain.entities.agent_config_metadata import AgentConfigMetadata -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/agents", tags=["agents"]) diff --git a/src/application/routes/chat.py b/src/application/routes/chat.py index 1ebaf06..b67de6c 100644 --- a/src/application/routes/chat.py +++ b/src/application/routes/chat.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import Annotated @@ -14,8 +15,9 @@ get_stream_message_use_case, ) from 
src.domain.entities.message import Message +from src.domain.entities.stream_event import StreamEvent, StreamEventType -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/chat", tags=["chat"]) @@ -46,15 +48,16 @@ async def event_generator(): chunk_count = 0 try: async for event in use_case.execute(thread_id, body.message): - if isinstance(event, str): + if event.type in (StreamEventType.THINKING, StreamEventType.CONTENT): chunk_count += 1 - yield {"data": event} - elif isinstance(event, Message): - yield {"data": event.model_dump_json()} + yield {"data": event.model_dump_json()} yield {"data": "[DONE]"} logger.info("[thread=%s] Stream complete, %d chunks", thread_id, chunk_count) - except Exception: + except asyncio.CancelledError: + raise + except Exception as exc: logger.exception("[thread=%s] Stream error after %d chunks", thread_id, chunk_count) - yield {"event": "error", "data": "stream_error"} + error_event = StreamEvent(type=StreamEventType.ERROR, data=str(exc)) + yield {"data": error_event.model_dump_json()} return EventSourceResponse(event_generator(), sep="\r\n", ping=15) diff --git a/src/application/routes/prompt.py b/src/application/routes/prompt.py index 6cb254d..a7ce97b 100644 --- a/src/application/routes/prompt.py +++ b/src/application/routes/prompt.py @@ -13,7 +13,7 @@ from src.dependencies import get_prompt_manager from src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) router = APIRouter(prefix="/prompts", tags=["prompts"]) diff --git a/src/application/routes/threads.py b/src/application/routes/threads.py index 0d202d3..d8a811b 100644 --- a/src/application/routes/threads.py +++ b/src/application/routes/threads.py @@ -18,7 +18,7 @@ ) from src.domain.entities.thread import Thread -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) router = 
APIRouter(prefix="/api/v1/threads", tags=["threads"]) diff --git a/src/application/routes/websocket.py b/src/application/routes/websocket.py index d2edebb..7d8b11e 100644 --- a/src/application/routes/websocket.py +++ b/src/application/routes/websocket.py @@ -5,9 +5,9 @@ from src.application.use_cases.stream_message import StreamMessageUseCase from src.dependencies import get_stream_message_use_case -from src.domain.entities.message import Message +from src.domain.entities.stream_event import StreamEvent, StreamEventType -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) router = APIRouter(tags=["websocket"]) @@ -34,16 +34,15 @@ async def websocket_chat( chunk_count = 0 try: async for event in use_case.execute(thread_id, message): - if isinstance(event, str): + if event.type in (StreamEventType.THINKING, StreamEventType.CONTENT): chunk_count += 1 - await websocket.send_text(event) - elif isinstance(event, Message): - await websocket.send_text(json.dumps({"type": "message", **event.model_dump(mode="json")})) + await websocket.send_text(event.model_dump_json()) await websocket.send_text("[END]") logger.info("[thread=%s] WS stream complete, %d chunks", thread_id, chunk_count) - except Exception: + except Exception as exc: logger.exception("[thread=%s] WS stream error after %d chunks", thread_id, chunk_count) - await websocket.send_text(json.dumps({"error": "Agent execution error"})) + error_event = StreamEvent(type=StreamEventType.ERROR, data=str(exc)) + await websocket.send_text(error_event.model_dump_json()) except WebSocketDisconnect: logger.info("[thread=%s] WebSocket disconnected", thread_id) except Exception: diff --git a/src/application/use_cases/create_agent_config.py b/src/application/use_cases/create_agent_config.py index f4b3399..0b4b845 100644 --- a/src/application/use_cases/create_agent_config.py +++ b/src/application/use_cases/create_agent_config.py @@ -8,7 +8,7 @@ from src.domain.ports.agent_config_repository import 
AgentConfigRepository from src.domain.ports.agent_config_store import AgentConfigStore -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class CreateAgentConfigUseCase: diff --git a/src/application/use_cases/create_prompt.py b/src/application/use_cases/create_prompt.py index de57d7d..082b32a 100644 --- a/src/application/use_cases/create_prompt.py +++ b/src/application/use_cases/create_prompt.py @@ -3,7 +3,7 @@ from src.domain.entities.prompt import PromptVersion from src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class CreatePromptUseCase: diff --git a/src/application/use_cases/delete_agent_config.py b/src/application/use_cases/delete_agent_config.py index f16ac8a..635e4c0 100644 --- a/src/application/use_cases/delete_agent_config.py +++ b/src/application/use_cases/delete_agent_config.py @@ -4,7 +4,7 @@ from src.domain.ports.agent_config_store import AgentConfigStore from src.domain.ports.agent_registry import AgentRegistry -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class DeleteAgentConfigUseCase: diff --git a/src/application/use_cases/get_agent_config.py b/src/application/use_cases/get_agent_config.py index 03ae3f1..9130cdd 100644 --- a/src/application/use_cases/get_agent_config.py +++ b/src/application/use_cases/get_agent_config.py @@ -4,7 +4,7 @@ from src.domain.ports.agent_config_loader import AgentConfigLoader from src.domain.ports.agent_config_store import AgentConfigStore -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class GetAgentConfigUseCase: diff --git a/src/application/use_cases/get_prompt.py b/src/application/use_cases/get_prompt.py index 72b20c1..2c2671c 100644 --- a/src/application/use_cases/get_prompt.py +++ b/src/application/use_cases/get_prompt.py @@ -3,7 +3,7 @@ from src.domain.entities.prompt import Prompt from 
src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class GetPromptUseCase: diff --git a/src/application/use_cases/list_agent_configs.py b/src/application/use_cases/list_agent_configs.py index 584f539..dbf1d36 100644 --- a/src/application/use_cases/list_agent_configs.py +++ b/src/application/use_cases/list_agent_configs.py @@ -3,7 +3,7 @@ from src.domain.entities.agent_config_metadata import AgentConfigMetadata from src.domain.ports.agent_config_repository import AgentConfigRepository -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class ListAgentConfigsUseCase: diff --git a/src/application/use_cases/send_message.py b/src/application/use_cases/send_message.py index 2519d96..4d7fcd6 100644 --- a/src/application/use_cases/send_message.py +++ b/src/application/use_cases/send_message.py @@ -6,7 +6,7 @@ from src.domain.ports.agent_registry import AgentRegistry from src.domain.ports.thread_repository import ThreadRepository -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class SendMessageUseCase: diff --git a/src/application/use_cases/stream_message.py b/src/application/use_cases/stream_message.py index c1c0ad1..ad51e2e 100644 --- a/src/application/use_cases/stream_message.py +++ b/src/application/use_cases/stream_message.py @@ -1,12 +1,14 @@ import logging import time +import json from collections.abc import AsyncGenerator from src.domain.entities.message import Message, MessageRole +from src.domain.entities.stream_event import StreamEvent, StreamEventType from src.domain.ports.agent_registry import AgentRegistry from src.domain.ports.thread_repository import ThreadRepository -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class StreamMessageUseCase: @@ -16,7 +18,7 @@ def __init__(self, registry: AgentRegistry, threads: ThreadRepository): self._registry = registry 
self._threads = threads - async def execute(self, thread_id: str, message: str) -> AsyncGenerator[str | Message, None]: + async def execute(self, thread_id: str, message: str) -> AsyncGenerator[StreamEvent, None]: thread = await self._threads.get(thread_id) human_msg = Message(role=MessageRole.HUMAN, content=message) await self._threads.add_message(thread_id, human_msg) @@ -27,11 +29,17 @@ async def execute(self, thread_id: str, message: str) -> AsyncGenerator[str | Me final_message = None try: async for event in runner.stream_with_message(thread_id, message): - if isinstance(event, str): + if event.type in (StreamEventType.THINKING, StreamEventType.CONTENT): chunk_count += 1 yield event - elif isinstance(event, Message): - final_message = event + elif event.type == StreamEventType.MESSAGE: + final_message = Message.model_validate_json(event.data) + if final_message and final_message.structured_response is not None: + yield StreamEvent( + type=StreamEventType.STRUCTURED, + data=json.dumps(final_message.structured_response) + ) + yield event except Exception: logger.exception( "[thread=%s][agent=%s] Stream error after %d chunks", thread_id, thread.agent_name, chunk_count @@ -41,16 +49,15 @@ async def execute(self, thread_id: str, message: str) -> AsyncGenerator[str | Me if final_message is not None: try: await self._threads.add_message(thread_id, final_message) - except Exception: + logger.info( + "[thread=%s][agent=%s] Stream complete, %d chunks, elapsed=%.2fs, message=persisted", + thread_id, + thread.agent_name, + chunk_count, + elapsed, + ) + except Exception as exc: logger.exception( "[thread=%s][agent=%s] Failed to persist AI message after stream", thread_id, thread.agent_name ) - yield final_message - logger.info( - "[thread=%s][agent=%s] Stream complete, %d chunks, elapsed=%.2fs, message=%s", - thread_id, - thread.agent_name, - chunk_count, - elapsed, - "yielded" if final_message else "none", - ) + raise RuntimeError(f"Failed to persist AI message after 
stream: {exc}") from exc diff --git a/src/application/use_cases/thread_management.py b/src/application/use_cases/thread_management.py index 26fe403..02bb461 100644 --- a/src/application/use_cases/thread_management.py +++ b/src/application/use_cases/thread_management.py @@ -5,7 +5,7 @@ from src.domain.ports.agent_registry import AgentRegistry from src.domain.ports.thread_repository import ThreadRepository -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class CreateThreadUseCase: diff --git a/src/application/use_cases/update_agent_config.py b/src/application/use_cases/update_agent_config.py index 3dd44fa..3048241 100644 --- a/src/application/use_cases/update_agent_config.py +++ b/src/application/use_cases/update_agent_config.py @@ -8,7 +8,7 @@ from src.domain.ports.agent_config_store import AgentConfigStore from src.domain.ports.agent_registry import AgentRegistry -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class UpdateAgentConfigUseCase: diff --git a/src/application/use_cases/update_prompt.py b/src/application/use_cases/update_prompt.py index 76b5d81..fc1e594 100644 --- a/src/application/use_cases/update_prompt.py +++ b/src/application/use_cases/update_prompt.py @@ -4,7 +4,7 @@ from src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class UpdatePromptUseCase: diff --git a/src/dependencies.py b/src/dependencies.py index 1796253..bef9a14 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -32,7 +32,7 @@ from src.infrastructure.prompt_management.phoenix_prompt_adapter import PhoenixPromptManagerProvider from src.infrastructure.yaml_config.adapter import YamlAgentConfigLoader -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) # ============= CONFIG ============= diff --git a/src/domain/entities/agent_config.py b/src/domain/entities/agent_config.py index 
aa80fe6..4b09a67 100644 --- a/src/domain/entities/agent_config.py +++ b/src/domain/entities/agent_config.py @@ -6,7 +6,7 @@ from src.domain.entities.mcp_server_config import McpServerConfig -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class MiddlewareType(StrEnum): diff --git a/src/domain/entities/mcp_server_config.py b/src/domain/entities/mcp_server_config.py index 9a41b68..8861177 100644 --- a/src/domain/entities/mcp_server_config.py +++ b/src/domain/entities/mcp_server_config.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, model_validator -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class McpTransportType(StrEnum): diff --git a/src/domain/entities/message.py b/src/domain/entities/message.py index d607b9c..7143166 100644 --- a/src/domain/entities/message.py +++ b/src/domain/entities/message.py @@ -23,3 +23,4 @@ class Message(BaseModel, frozen=True): tool_calls: list[dict] | None = None status: MessageStatus | None = None structured_response: dict | None = None + thinking: str | None = None diff --git a/src/domain/entities/stream_event.py b/src/domain/entities/stream_event.py new file mode 100644 index 0000000..1e5e696 --- /dev/null +++ b/src/domain/entities/stream_event.py @@ -0,0 +1,16 @@ +from enum import StrEnum + +from pydantic import BaseModel + + +class StreamEventType(StrEnum): + THINKING = "thinking" + CONTENT = "content" + MESSAGE = "message" + STRUCTURED = "structured" + ERROR = "error" + + +class StreamEvent(BaseModel, frozen=True): + type: StreamEventType + data: str diff --git a/src/domain/ports/agent_runner.py b/src/domain/ports/agent_runner.py index 2f756ab..2d09742 100644 --- a/src/domain/ports/agent_runner.py +++ b/src/domain/ports/agent_runner.py @@ -2,35 +2,24 @@ from collections.abc import AsyncIterator from src.domain.entities.message import Message +from src.domain.entities.stream_event import StreamEvent class AgentRunner(ABC): @abstractmethod - 
async def invoke(self, thread_id: str, message: str) -> Message: - """Envoie un message et retourne la reponse complete.""" - ... + async def invoke(self, thread_id: str, message: str) -> Message: ... @abstractmethod - async def stream(self, thread_id: str, message: str) -> AsyncIterator[str]: - """Envoie un message et streame la reponse par chunks.""" - ... + async def stream(self, thread_id: str, message: str) -> AsyncIterator[StreamEvent]: ... @abstractmethod - async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterator[str | Message]: - """Streame les chunks puis yield le Message final complet.""" - ... + async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterator[StreamEvent]: ... @abstractmethod - async def approve_hitl(self, thread_id: str, tool_call_id: str) -> Message: - """Approuve une action HITL en attente.""" - ... + async def approve_hitl(self, thread_id: str, tool_call_id: str) -> Message: ... @abstractmethod - async def reject_hitl(self, thread_id: str, tool_call_id: str, reason: str | None = None) -> Message: - """Rejette une action HITL en attente.""" - ... + async def reject_hitl(self, thread_id: str, tool_call_id: str, reason: str | None = None) -> Message: ... @abstractmethod - async def edit_hitl(self, thread_id: str, tool_call_id: str, edits: dict) -> Message: - """Edite et approuve une action HITL en attente.""" - ... + async def edit_hitl(self, thread_id: str, tool_call_id: str, edits: dict) -> Message: ... 
diff --git a/src/infrastructure/database/models/thread.py b/src/infrastructure/database/models/thread.py index aa5334a..7497325 100644 --- a/src/infrastructure/database/models/thread.py +++ b/src/infrastructure/database/models/thread.py @@ -35,5 +35,8 @@ class MessageModel(Base): tool_calls: Mapped[list[dict] | None] = mapped_column(JSONB, nullable=True) status: Mapped[str | None] = mapped_column(String(50), nullable=True) structured_response: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + thinking: Mapped[str | None] = mapped_column( + Text, nullable=True + ) # chain-of-thought / reasoning text from extended-thinking models thread: Mapped["ThreadModel"] = relationship("ThreadModel", back_populates="messages") diff --git a/src/infrastructure/deepagent/adapter.py b/src/infrastructure/deepagent/adapter.py index 28e8d38..2ee6292 100644 --- a/src/infrastructure/deepagent/adapter.py +++ b/src/infrastructure/deepagent/adapter.py @@ -7,33 +7,29 @@ from langgraph.types import Command from src.domain.entities.message import Message, MessageRole, MessageStatus +from src.domain.entities.stream_event import StreamEvent, StreamEventType from src.domain.exceptions import AgentError from src.domain.ports.agent_runner import AgentRunner from src.domain.ports.tracing_provider import TracingProvider -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class DeepAgentRunner(AgentRunner): - """Adapter qui execute un Deep Agent LangGraph.""" - def __init__(self, graph, tracing_provider: TracingProvider | None = None): self._graph = graph self._tracing_provider = tracing_provider @staticmethod def _try_parse_json(content: str) -> dict | None: - """Try to extract a JSON object from content that may contain markdown.""" if not content: return None - # Try direct parse first try: parsed = json.loads(content) if isinstance(parsed, dict): return parsed except (json.JSONDecodeError, TypeError): pass - # Try extracting from ```json ... 
``` blocks match = re.search(r"```(?:json)?\s*\n(.*?)\n```", content, re.DOTALL) if match: try: @@ -44,15 +40,25 @@ def _try_parse_json(content: str) -> dict | None: pass return None - def _build_config(self, thread_id: str) -> dict: - """Build the LangGraph config with optional tracing callbacks. + @staticmethod + def _is_nonblank_str(val: object) -> bool: + return isinstance(val, str) and val.strip() != "" - Args: - thread_id: The conversation thread identifier. + @staticmethod + def _classify_chunk(chunk) -> tuple[StreamEventType, str] | None: + if chunk.type != "AIMessageChunk": + return None + additional = getattr(chunk, "additional_kwargs", {}) + reasoning = additional.get("reasoning_content") + if DeepAgentRunner._is_nonblank_str(reasoning): + return (StreamEventType.THINKING, reasoning) + if additional.get("type") == "thinking" and DeepAgentRunner._is_nonblank_str(chunk.content): + return (StreamEventType.THINKING, chunk.content) + if DeepAgentRunner._is_nonblank_str(chunk.content): + return (StreamEventType.CONTENT, chunk.content) + return None - Returns: - Config dict with thread_id and optional callbacks. 
- """ + def _build_config(self, thread_id: str) -> dict: config: dict = {"configurable": {"thread_id": thread_id}} if self._tracing_provider: callbacks = self._tracing_provider.get_callbacks() @@ -60,18 +66,14 @@ def _build_config(self, thread_id: str) -> dict: config["callbacks"] = callbacks return config - def _build_response(self, result: dict, config: dict) -> Message: - """Build a Message from graph result, detecting interrupts and collecting tool_calls.""" + def _build_response(self, result: dict, config: dict, thinking: str | None) -> Message: messages = result.get("messages", []) if not messages: raise AgentError("Graph completed but no messages were found in the final state.") last_message = messages[-1] - all_tool_calls = getattr(last_message, "tool_calls", None) or [] - state = self._graph.get_state(config) status = MessageStatus.AWAITING_HITL if state.interrupts else MessageStatus.COMPLETED - structured_response = None raw_structured = result.get("structured_response") if raw_structured is not None: @@ -79,16 +81,15 @@ def _build_response(self, result: dict, config: dict) -> Message: structured_response = raw_structured.model_dump() elif isinstance(raw_structured, dict): structured_response = raw_structured - if structured_response is None: structured_response = self._try_parse_json(last_message.content) - return Message( role=MessageRole.AI, content=last_message.content, tool_calls=all_tool_calls or None, status=status, structured_response=structured_response, + thinking=thinking, ) async def invoke(self, thread_id: str, message: str) -> Message: @@ -102,7 +103,7 @@ async def invoke(self, thread_id: str, message: str) -> Message: config=config, ) elapsed = time.monotonic() - start - response = self._build_response(result, config) + response = self._build_response(result, config, None) logger.info("[thread=%s] Invoke complete, status=%s, elapsed=%.2fs", thread_id, response.status, elapsed) return response except Exception as e: @@ -110,13 +111,8 @@ 
async def invoke(self, thread_id: str, message: str) -> Message: raise AgentError(f"Agent execution error: {e}") from e async def _yield_chunks( - self, - thread_id: str, - message: str, - config: dict, - stats: dict, - ) -> AsyncIterator[str]: - """Stream graph chunks and populate *stats* with timing.""" + self, thread_id: str, message: str, config: dict, stats: dict + ) -> AsyncIterator[StreamEvent]: start = time.monotonic() first_chunk = True chunk_count = 0 @@ -125,26 +121,24 @@ async def _yield_chunks( config=config, stream_mode="messages", ): - if hasattr(chunk, "content") and chunk.content and chunk.type == "AIMessageChunk": + classification = self._classify_chunk(chunk) + if classification: + event_type, data = classification if first_chunk: - logger.info( - "[thread=%s] First chunk received, elapsed=%.2fs", - thread_id, - time.monotonic() - start, - ) + logger.info("[thread=%s] First chunk received, elapsed=%.2fs", thread_id, time.monotonic() - start) first_chunk = False chunk_count += 1 - yield chunk.content + yield StreamEvent(type=event_type, data=data) stats["chunk_count"] = chunk_count stats["elapsed"] = time.monotonic() - start - async def stream(self, thread_id: str, message: str) -> AsyncIterator[str]: + async def stream(self, thread_id: str, message: str) -> AsyncIterator[StreamEvent]: config = self._build_config(thread_id) logger.info("[thread=%s] Streaming agent response", thread_id) try: stats: dict = {} - async for chunk in self._yield_chunks(thread_id, message, config, stats): - yield chunk + async for event in self._yield_chunks(thread_id, message, config, stats): + yield event logger.info( "[thread=%s] Stream complete, %d chunks, elapsed=%.2fs", thread_id, @@ -155,20 +149,21 @@ async def stream(self, thread_id: str, message: str) -> AsyncIterator[str]: logger.exception("[thread=%s] Streaming error", thread_id) raise AgentError(f"Streaming error: {e}") from e - async def stream_with_message(self, thread_id: str, message: str) -> 
AsyncIterator[str | Message]: + async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterator[StreamEvent]: config = self._build_config(thread_id) logger.info("[thread=%s] Streaming agent response with final message", thread_id) try: stats: dict = {} - async for chunk in self._yield_chunks(thread_id, message, config, stats): - yield chunk + thinking_parts = [] + async for event in self._yield_chunks(thread_id, message, config, stats): + yield event + if event.type == StreamEventType.THINKING: + thinking_parts.append(event.data) state = self._graph.get_state(config) - values = state.values if state and hasattr(state, "values") else {} - result = { - "messages": values.get("messages", []), - "structured_response": values.get("structured_response"), - } - response = self._build_response(result, config) + values = getattr(state, "values", None) or {} + result = {"messages": values.get("messages", []), "structured_response": values.get("structured_response")} + thinking = "".join(thinking_parts) if thinking_parts else None + response = self._build_response(result, config, thinking) logger.info( "[thread=%s] Stream with message complete, %d chunks, elapsed=%.2fs, status=%s", thread_id, @@ -176,39 +171,35 @@ async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterat stats["elapsed"], response.status, ) - yield response + yield StreamEvent(type=StreamEventType.MESSAGE, data=response.model_dump_json()) except Exception as e: logger.exception("[thread=%s] Streaming error", thread_id) raise AgentError(f"Streaming error: {e}") from e - async def approve_hitl(self, thread_id: str, tool_call_id: str) -> Message: # noqa: ARG002 + async def approve_hitl(self, thread_id: str, _tool_call_id: str) -> Message: config = self._build_config(thread_id) logger.info("[thread=%s] HITL approve", thread_id) try: start = time.monotonic() - result = await self._graph.ainvoke( - Command(resume={"decisions": [{"type": "approve"}]}), - config=config, - ) + 
result = await self._graph.ainvoke(Command(resume={"decisions": [{"type": "approve"}]}), config=config) elapsed = time.monotonic() - start - response = self._build_response(result, config) + response = self._build_response(result, config, None) logger.info("[thread=%s] HITL approve complete, elapsed=%.2fs", thread_id, elapsed) return response except Exception as e: logger.exception("HITL approve error") raise AgentError(f"HITL approve error: {e}") from e - async def reject_hitl(self, thread_id: str, tool_call_id: str, reason: str | None = None) -> Message: # noqa: ARG002 + async def reject_hitl(self, thread_id: str, _tool_call_id: str, reason: str | None = None) -> Message: config = self._build_config(thread_id) logger.info("[thread=%s] HITL reject, reason=%s", thread_id, reason) try: start = time.monotonic() result = await self._graph.ainvoke( - Command(resume={"decisions": [{"type": "reject", "message": reason or ""}]}), - config=config, + Command(resume={"decisions": [{"type": "reject", "message": reason or ""}]}), config=config ) elapsed = time.monotonic() - start - response = self._build_response(result, config) + response = self._build_response(result, config, None) logger.info("[thread=%s] HITL reject complete, elapsed=%.2fs", thread_id, elapsed) return response except Exception as e: @@ -222,18 +213,22 @@ async def edit_hitl(self, thread_id: str, tool_call_id: str, edits: dict) -> Mes start = time.monotonic() state = self._graph.get_state(config) tool_name = tool_call_id - for msg in state.values.get("messages", []): - if hasattr(msg, "tool_calls"): - for tc in msg.tool_calls: - if tc.get("id") == tool_call_id: - tool_name = tc["name"] - break + tool_name = next( + ( + tc["name"] + for msg in state.values.get("messages", []) + if hasattr(msg, "tool_calls") + for tc in msg.tool_calls + if tc.get("id") == tool_call_id + ), + tool_call_id, + ) result = await self._graph.ainvoke( Command(resume={"decisions": [{"type": "edit", "edited_action": {"name": 
tool_name, "args": edits}}]}), config=config, ) elapsed = time.monotonic() - start - response = self._build_response(result, config) + response = self._build_response(result, config, None) logger.info("[thread=%s] HITL edit complete, elapsed=%.2fs", thread_id, elapsed) return response except Exception as e: diff --git a/src/infrastructure/deepagent/factory.py b/src/infrastructure/deepagent/factory.py index 570d2e2..bbbe017 100644 --- a/src/infrastructure/deepagent/factory.py +++ b/src/infrastructure/deepagent/factory.py @@ -17,7 +17,7 @@ from src.domain.ports.mcp_tool_loader import McpToolLoader from src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) STRUCTURED_OUTPUT_INSTRUCTION = ( "\n\nYou MUST use the 'structured_response' tool to return your final answer in the expected structured format." diff --git a/src/infrastructure/mcp/adapter.py b/src/infrastructure/mcp/adapter.py index e330b0b..b979796 100644 --- a/src/infrastructure/mcp/adapter.py +++ b/src/infrastructure/mcp/adapter.py @@ -10,7 +10,7 @@ from src.domain.ports.mcp_tool_loader import McpToolLoader from src.infrastructure.env_utils import resolve_env_vars -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class LangchainMcpToolLoader(McpToolLoader): diff --git a/src/infrastructure/minio_store/adapter.py b/src/infrastructure/minio_store/adapter.py index 7651fa9..3f997e5 100644 --- a/src/infrastructure/minio_store/adapter.py +++ b/src/infrastructure/minio_store/adapter.py @@ -7,7 +7,7 @@ from src.domain.exceptions import AgentNotFoundError from src.domain.ports.agent_config_store import AgentConfigStore -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class MinioAgentConfigStore(AgentConfigStore): diff --git a/src/infrastructure/persistent_registry/adapter.py b/src/infrastructure/persistent_registry/adapter.py index 250e8b0..e6d1982 100644 --- 
a/src/infrastructure/persistent_registry/adapter.py +++ b/src/infrastructure/persistent_registry/adapter.py @@ -12,7 +12,7 @@ from src.infrastructure.deepagent.adapter import DeepAgentRunner from src.infrastructure.deepagent.factory import create_agent_from_config -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class PersistentAgentRegistry(AgentRegistry): diff --git a/src/infrastructure/postgres_repository/adapter.py b/src/infrastructure/postgres_repository/adapter.py index 98cae95..bb2999b 100644 --- a/src/infrastructure/postgres_repository/adapter.py +++ b/src/infrastructure/postgres_repository/adapter.py @@ -9,7 +9,7 @@ from src.domain.ports.agent_config_repository import AgentConfigRepository from src.infrastructure.database.models.agent_config import AgentConfigModel -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) def _model_to_metadata(model: AgentConfigModel) -> AgentConfigMetadata: diff --git a/src/infrastructure/postgres_thread/adapter.py b/src/infrastructure/postgres_thread/adapter.py index 35a3d81..66bce21 100644 --- a/src/infrastructure/postgres_thread/adapter.py +++ b/src/infrastructure/postgres_thread/adapter.py @@ -13,7 +13,12 @@ from src.domain.ports.thread_repository import ThreadRepository from src.infrastructure.database.models.thread import MessageModel, ThreadModel -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) + + +def _safe_str(val: object) -> str | None: + """Return a string if the value is a real string, else None.""" + return val if isinstance(val, str) else None def _model_to_thread(thread_model: ThreadModel) -> Thread: @@ -37,6 +42,7 @@ def _model_to_thread(thread_model: ThreadModel) -> Thread: tool_calls=msg.tool_calls, status=MessageStatus(msg.status) if msg.status else None, structured_response=msg.structured_response, + thinking=_safe_str(msg.thinking), ) for msg in messages_sorted ] @@ -189,6 +195,7 @@ async def 
add_message(self, thread_id: str, message: Message) -> Thread: tool_calls=message.tool_calls, status=message.status.value if message.status else None, structured_response=message.structured_response, + thinking=message.thinking, ) session.add(msg_model) thread_model.updated_at = datetime.now(UTC) diff --git a/src/infrastructure/prompt_management/phoenix_prompt_adapter.py b/src/infrastructure/prompt_management/phoenix_prompt_adapter.py index 399e41c..e71ef6b 100644 --- a/src/infrastructure/prompt_management/phoenix_prompt_adapter.py +++ b/src/infrastructure/prompt_management/phoenix_prompt_adapter.py @@ -14,7 +14,7 @@ from src.domain.entities.prompt import Prompt, PromptVersion from src.domain.ports.prompt_manager import PromptManager -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) # Retry config _RETRY_ATTEMPTS = 3 diff --git a/src/infrastructure/tracing/phoenix_adapter.py b/src/infrastructure/tracing/phoenix_adapter.py index 2fd35b9..cf2ccc7 100644 --- a/src/infrastructure/tracing/phoenix_adapter.py +++ b/src/infrastructure/tracing/phoenix_adapter.py @@ -7,7 +7,7 @@ from src.domain.ports.tracing_provider import TracingProvider -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class PhoenixTracingProvider(TracingProvider): diff --git a/src/infrastructure/yaml_config/adapter.py b/src/infrastructure/yaml_config/adapter.py index 5813cd8..e81e6c4 100644 --- a/src/infrastructure/yaml_config/adapter.py +++ b/src/infrastructure/yaml_config/adapter.py @@ -7,7 +7,7 @@ from src.domain.exceptions import ConfigError, ConfigNotFoundError, ConfigValidationError from src.domain.ports.agent_config_loader import AgentConfigLoader -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) class YamlAgentConfigLoader(AgentConfigLoader): diff --git a/src/main.py b/src/main.py index 5127c47..9b95dcd 100644 --- a/src/main.py +++ b/src/main.py @@ -4,36 +4,15 @@ from contextlib import 
asynccontextmanager from pathlib import Path -log_level_name = "INFO" -log_level = logging.INFO +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + stream=sys.stdout, + force=True, +) from src.config import Settings -_settings = Settings() -log_level_name = _settings.log_level.upper() -log_level = getattr(logging, log_level_name, logging.INFO) - -_handler = logging.StreamHandler(sys.stdout) -_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) - -_root = logging.getLogger() -_root.setLevel(log_level) -_root.handlers.clear() -_root.addHandler(_handler) - -for _name in ( - "langchain", - "langchain_core", - "langchain_community", - "langgraph", - "openai", - "httpx", - "httpcore", - "alembic", - "sqlalchemy", -): - logging.getLogger(_name).setLevel(log_level) - from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -62,7 +41,7 @@ ThreadNotFoundError, ) -logger = logging.getLogger("composable-agents") +logger = logging.getLogger(__name__) settings = Settings() diff --git a/tests/unit/test_deep_agent_runner_stream_with_message.py b/tests/unit/test_deep_agent_runner_stream_with_message.py index e5b01a9..1e103cc 100644 --- a/tests/unit/test_deep_agent_runner_stream_with_message.py +++ b/tests/unit/test_deep_agent_runner_stream_with_message.py @@ -1,10 +1,6 @@ """Tests for DeepAgentRunner.stream_with_message(). Graph is mocked (external LLM boundary via LangGraph). -These tests exercise the new stream_with_message() method which yields -str chunks during streaming and a final Message object after the stream -completes, allowing callers to receive both streaming chunks and a -structured complete response. 
""" from unittest.mock import AsyncMock, MagicMock @@ -12,47 +8,38 @@ import pytest from src.domain.entities.message import Message, MessageRole, MessageStatus +from src.domain.entities.stream_event import StreamEventType from src.domain.exceptions import AgentError from src.infrastructure.deepagent.adapter import DeepAgentRunner def _make_streaming_graph( - chunks: list[str], + chunks: list[tuple[StreamEventType, str]], final_messages: list | None = None, interrupts=(), state_values: dict | None = None, structured_response=None, ): - """Create a mock graph with astream and get_state for stream_with_message. - - Args: - chunks: List of string chunks to yield from astream. - final_messages: Messages to appear in get_state().values["messages"]. - If None, a default AI message is constructed from the chunks. - interrupts: Interrupt tuples for get_state(). - state_values: Additional state values for get_state(). - structured_response: Optional structured_response value in state. - """ mock_graph = AsyncMock() - # Build async generator for astream async def _astream(_input, **_kwargs): - for chunk_text in chunks: + for event_type, chunk_text in chunks: chunk = MagicMock() chunk.content = chunk_text chunk.type = "AIMessageChunk" + chunk.additional_kwargs = {} + if event_type == StreamEventType.THINKING: + chunk.additional_kwargs = {"type": "thinking"} yield chunk, MagicMock() mock_graph.astream = _astream - - # Build state for get_state (called after stream to build Message) state = MagicMock() state.interrupts = interrupts - # Build final AI message from joined chunks if no explicit messages given if final_messages is None: + content = "".join(text for etype, text in chunks if etype == StreamEventType.CONTENT) final_ai = MagicMock() - final_ai.content = "".join(chunks) + final_ai.content = content final_ai.tool_calls = None final_messages = [final_ai] @@ -64,130 +51,123 @@ async def _astream(_input, **_kwargs): state.values = values mock_graph.get_state = 
MagicMock(return_value=state) - - # ainvoke return value (used by get_state context) mock_graph.ainvoke.return_value = { "messages": final_messages, "structured_response": structured_response, } - return mock_graph class TestStreamWithMessage: - """Tests for DeepAgentRunner.stream_with_message().""" - - async def test_stream_with_message_yields_chunks_then_message(self): - """Should yield str chunks during streaming, then a final Message.""" - chunks = ["Hello ", "world!"] + async def test_stream_with_message_yields_content_then_message(self): + chunks = [(StreamEventType.CONTENT, "Hello "), (StreamEventType.CONTENT, "world!")] graph = _make_streaming_graph(chunks) - runner = DeepAgentRunner(graph) collected = [] - async for item in runner.stream_with_message("thread-1", "hi"): - collected.append(item) - - # All items except the last should be str chunks - str_items = collected[:-1] - final_message = collected[-1] + async for event in runner.stream_with_message("thread-1", "hi"): + collected.append(event) + + content_events = collected[:-1] + final_event = collected[-1] + + assert all(e.type == StreamEventType.CONTENT for e in content_events) + assert [e.data for e in content_events] == ["Hello ", "world!"] + + assert final_event.type == StreamEventType.MESSAGE + msg = Message.model_validate_json(final_event.data) + assert msg.role == MessageRole.AI + assert msg.content == "Hello world!" 
+ assert msg.status == MessageStatus.COMPLETED + + async def test_stream_with_message_yields_thinking_then_content(self): + chunks = [ + (StreamEventType.THINKING, "Let me think..."), + (StreamEventType.CONTENT, "Here is the answer."), + ] + graph = _make_streaming_graph(chunks) + runner = DeepAgentRunner(graph) + collected = [] + async for event in runner.stream_with_message("thread-1", "hi"): + collected.append(event) - assert all(isinstance(c, str) for c in str_items) - assert str_items == ["Hello ", "world!"] + events = collected[:-1] + assert events[0].type == StreamEventType.THINKING + assert events[0].data == "Let me think..." + assert events[1].type == StreamEventType.CONTENT + assert events[1].data == "Here is the answer." - assert isinstance(final_message, Message) - assert final_message.role == MessageRole.AI - assert final_message.content == "Hello world!" - assert final_message.status == MessageStatus.COMPLETED + final_event = collected[-1] + msg = Message.model_validate_json(final_event.data) + assert msg.thinking == "Let me think..." + assert msg.content == "Here is the answer." async def test_stream_with_message_final_message_has_tool_calls(self): - """When the last message has tool_calls, the final Message includes them.""" - chunks = ["Processing..."] - + chunks = [(StreamEventType.CONTENT, "Processing...")] ai_msg = MagicMock() ai_msg.content = "Processing..." 
ai_msg.tool_calls = [{"name": "search", "args": {"q": "test"}, "id": "tc-1"}] - - graph = _make_streaming_graph( - chunks, - final_messages=[ai_msg], - structured_response=None, - ) + graph = _make_streaming_graph(chunks, final_messages=[ai_msg]) runner = DeepAgentRunner(graph) collected = [] - async for item in runner.stream_with_message("thread-1", "search for test"): - collected.append(item) + async for event in runner.stream_with_message("thread-1", "search for test"): + collected.append(event) - final_message = collected[-1] - assert isinstance(final_message, Message) - assert final_message.tool_calls is not None - assert len(final_message.tool_calls) == 1 - assert final_message.tool_calls[0]["name"] == "search" + final_event = collected[-1] + msg = Message.model_validate_json(final_event.data) + assert msg.tool_calls is not None + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0]["name"] == "search" async def test_stream_with_message_final_message_has_structured_response(self): - """When result has structured_response, it appears in the final Message.""" - chunks = ["Weather report"] - + chunks = [(StreamEventType.CONTENT, "Weather report")] ai_msg = MagicMock() ai_msg.content = "Weather report" ai_msg.tool_calls = None - graph = _make_streaming_graph( - chunks, - final_messages=[ai_msg], + chunks, final_messages=[ai_msg], structured_response={"temperature": 22, "condition": "sunny"}, ) runner = DeepAgentRunner(graph) collected = [] - async for item in runner.stream_with_message("thread-1", "weather?"): - collected.append(item) + async for event in runner.stream_with_message("thread-1", "weather?"): + collected.append(event) - final_message = collected[-1] - assert isinstance(final_message, Message) - assert final_message.structured_response == {"temperature": 22, "condition": "sunny"} + final_event = collected[-1] + msg = Message.model_validate_json(final_event.data) + assert msg.structured_response == {"temperature": 22, "condition": "sunny"} async 
def test_stream_with_message_detects_hitl_interrupt(self): - """When state has interrupts, final Message has status=awaiting_hitl.""" - chunks = ["Waiting for approval"] - + chunks = [(StreamEventType.CONTENT, "Waiting for approval")] ai_msg = MagicMock() ai_msg.content = "" ai_msg.tool_calls = [{"name": "delete_file", "args": {"path": "/tmp/x"}, "id": "tc-1"}] - interrupt = MagicMock() - graph = _make_streaming_graph( - chunks, - final_messages=[ai_msg], - interrupts=(interrupt,), - ) + graph = _make_streaming_graph(chunks, final_messages=[ai_msg], interrupts=(interrupt,)) runner = DeepAgentRunner(graph) collected = [] - async for item in runner.stream_with_message("thread-1", "delete file"): - collected.append(item) + async for event in runner.stream_with_message("thread-1", "delete file"): + collected.append(event) - final_message = collected[-1] - assert isinstance(final_message, Message) - assert final_message.status == MessageStatus.AWAITING_HITL + final_event = collected[-1] + msg = Message.model_validate_json(final_event.data) + assert msg.status == MessageStatus.AWAITING_HITL async def test_stream_with_message_no_chunks_yields_message(self): - """When stream produces 0 AI chunks but graph completes, still yield a Message.""" - # Empty chunks — the stream yields nothing, but get_state still works graph = _make_streaming_graph([]) - runner = DeepAgentRunner(graph) collected = [] - async for item in runner.stream_with_message("thread-1", "hello"): - collected.append(item) + async for event in runner.stream_with_message("thread-1", "hello"): + collected.append(event) - # Should have exactly one item: the final Message assert len(collected) == 1 - assert isinstance(collected[0], Message) - assert collected[0].role == MessageRole.AI + assert collected[0].type == StreamEventType.MESSAGE + msg = Message.model_validate_json(collected[0].data) + assert msg.role == MessageRole.AI async def test_stream_with_message_raises_on_error(self): - """When astream raises, 
AgentError is raised.""" mock_graph = AsyncMock() async def _astream_error(_input, _config=None, _stream_mode=None): @@ -195,9 +175,8 @@ async def _astream_error(_input, _config=None, _stream_mode=None): yield mock_graph.astream = _astream_error - runner = DeepAgentRunner(mock_graph) with pytest.raises(AgentError, match="Streaming error"): collected = [] - async for item in runner.stream_with_message("thread-1", "hello"): - collected.append(item) + async for event in runner.stream_with_message("thread-1", "hello"): + collected.append(event) diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index 24c726d..4f63d09 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -13,6 +13,7 @@ from src.domain.entities.agent_config_metadata import AgentConfigMetadata from src.domain.entities.message import Message, MessageRole, MessageStatus +from src.domain.entities.stream_event import StreamEvent, StreamEventType from src.domain.exceptions import AgentError from src.infrastructure.persistent_registry.adapter import PersistentAgentRegistry from src.infrastructure.yaml_config.adapter import YamlAgentConfigLoader @@ -58,18 +59,21 @@ def mock_runner(): async def mock_stream(_thread_id, _message): for word in ["I", "am", "a", "mock", "agent."]: - yield word + " " + yield StreamEvent(type=StreamEventType.CONTENT, data=word + " ") runner.stream = mock_stream async def mock_stream_with_message(_thread_id, _message): for word in ["I", "am", "a", "mock", "agent."]: - yield word + " " - yield Message( - role=MessageRole.AI, - content="I am a mock agent.", - status=MessageStatus.COMPLETED, - structured_response={"key": "value"}, + yield StreamEvent(type=StreamEventType.CONTENT, data=word + " ") + yield StreamEvent( + type=StreamEventType.MESSAGE, + data=Message( + role=MessageRole.AI, + content="I am a mock agent.", + status=MessageStatus.COMPLETED, + structured_response={"key": "value"}, + ).model_dump_json(), ) runner.stream_with_message = 
mock_stream_with_message @@ -412,14 +416,14 @@ async def test_stream_ends_with_done(self, client): body = resp.text data_lines = [ - line.strip()[len("data:"):].strip() + line.strip()[len("data:") :].strip() for line in body.replace("\r\n", "\n").split("\n") if line.strip().startswith("data:") ] assert data_lines[-1] == "[DONE]", f"Expected [DONE] as last data line, got: {data_lines[-1]}" async def test_stream_emits_message_json_before_done(self, client): - """Stream emits Message JSON as second-to-last data line, before [DONE].""" + """Stream emits Message JSON before [DONE], even with a preceding structured event.""" create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) thread_id = create_resp.json()["id"] resp = await client.post( @@ -432,13 +436,23 @@ async def test_stream_emits_message_json_before_done(self, client): import json data_lines = [ - line.strip()[len("data:"):].strip() + line.strip()[len("data:") :].strip() for line in body.replace("\r\n", "\n").split("\n") if line.strip().startswith("data:") ] assert data_lines[-1] == "[DONE]" - message_json = json.loads(data_lines[-2]) + + # Find the last message event before [DONE] + message_event = None + for line in reversed(data_lines[:-1]): + event = json.loads(line) + if event.get("type") == "message": + message_event = event + break + + assert message_event is not None, "No message event found in stream" + message_json = json.loads(message_event["data"]) assert message_json["role"] == "ai" assert message_json["structured_response"] == {"key": "value"} @@ -464,12 +478,23 @@ async def test_stream_message_format_matches_sync(self, client): import json data_lines = [ - line.strip()[len("data:"):].strip() + line.strip()[len("data:") :].strip() for line in body.replace("\r\n", "\n").split("\n") if line.strip().startswith("data:") ] - message_json = json.loads(data_lines[-2]) + # Locate the message event (ignore structured events, [DONE], etc.) 
+ message_event = None + for line in data_lines: + if line == "[DONE]": + continue + event = json.loads(line) + if event.get("type") == "message": + message_event = event + break + + assert message_event is not None, "No message event found in stream" + message_json = json.loads(message_event["data"]) for field in ["role", "content", "timestamp", "status"]: assert field in message_json, f"Missing field {field!r} in stream Message: {message_json}" diff --git a/tests/unit/test_runner_tracing.py b/tests/unit/test_runner_tracing.py index 8a1f637..be7e3c2 100644 --- a/tests/unit/test_runner_tracing.py +++ b/tests/unit/test_runner_tracing.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock from src.domain.entities.message import MessageRole +from src.domain.entities.stream_event import StreamEventType from src.infrastructure.deepagent.adapter import DeepAgentRunner @@ -83,12 +84,13 @@ async def mock_astream(*_args, **_kwargs): runner = DeepAgentRunner(mock_graph, tracing_provider=mock_tracing_provider) - chunks = [] - async for chunk in runner.stream("thread-1", "Hi"): - chunks.append(chunk) + events = [] + async for event in runner.stream("thread-1", "Hi"): + events.append(event) - assert len(chunks) == 1 - assert chunks[0] == "chunk" + assert len(events) == 1 + assert events[0].type == StreamEventType.CONTENT + assert events[0].data == "chunk" def test_build_config_with_tracing(self, mock_tracing_provider): mock_callback = MagicMock()