diff --git a/README.md b/README.md index 850319f..54349d0 100644 --- a/README.md +++ b/README.md @@ -449,8 +449,20 @@ data: of data: code data: align data: ... +event: message +data: {"role":"ai","content":"Lines of code align...","timestamp":"2025-04-24T10:30:05.000000Z","tool_calls":null,"status":"completed","structured_response":null} + +event: done +data: ``` +The stream emits: +1. **`data: `** — text chunks as they are generated (keepalive for proxies like Cloudflare) +2. **`event: message`** — the complete `Message` JSON with all fields (role, content, timestamp, tool_calls, status, structured_response), identical in format to the synchronous `POST /chat/{thread_id}` response +3. **`event: done`** — signals the stream is complete + +This design prevents Cloudflare timeout issues (~100s on idle connections) because chunks and SSE pings (every 15s) keep the connection active. Clients that need the full structured response can read the `event: message` data. + ### 7. List All Threads ```bash @@ -700,11 +712,20 @@ ws.onmessage = (event) => { if (event.data === "[END]") { console.log("Response complete"); } else { - process.stdout.write(event.data); + try { + const data = JSON.parse(event.data); + if (data.type === "message") { + console.log("Final message:", data); + } + } catch { + process.stdout.write(event.data); + } } }; ``` +The WebSocket stream emits text chunks, then a JSON object with `type: "message"` containing the full `Message` fields, followed by `[END]`. + --- ## Prompt Management Setup diff --git a/src/application/routes/chat.py b/src/application/routes/chat.py index e5b8fd3..1ebaf06 100644 --- a/src/application/routes/chat.py +++ b/src/application/routes/chat.py @@ -45,10 +45,13 @@ async def stream_message( async def event_generator(): chunk_count = 0 try: - async for chunk in use_case.execute(thread_id, body.message): - chunk_count += 1 - yield {"data": chunk} - yield {"event": "done", "data": ""} + async for event in use_case.execute(thread_id, body.message): + if isinstance(event, str): + chunk_count += 1 + yield {"data": event} + elif isinstance(event, Message): + yield {"data": event.model_dump_json()} + yield {"data": "[DONE]"} logger.info("[thread=%s] Stream complete, %d chunks", thread_id, chunk_count) except Exception: logger.exception("[thread=%s] Stream error after %d chunks", thread_id, chunk_count) diff --git a/src/application/routes/websocket.py b/src/application/routes/websocket.py index fdf8306..d2edebb 100644 --- a/src/application/routes/websocket.py +++ b/src/application/routes/websocket.py @@ -5,6 +5,7 @@ from src.application.use_cases.stream_message import StreamMessageUseCase from src.dependencies import get_stream_message_use_case +from src.domain.entities.message import Message logger = logging.getLogger("composable-agents") @@ -32,9 +33,12 @@ async def websocket_chat( logger.info("[thread=%s] WS message received: %s", thread_id, message[:80]) chunk_count = 0 try: - async for chunk in use_case.execute(thread_id, message): - chunk_count += 1 - await websocket.send_text(chunk) + async for event in use_case.execute(thread_id, message): + if isinstance(event, str): + chunk_count += 1 + await websocket.send_text(event) + elif isinstance(event, Message): + await websocket.send_text(json.dumps({"type": "message", **event.model_dump(mode="json")})) await websocket.send_text("[END]") logger.info("[thread=%s] WS stream complete, %d chunks", thread_id, chunk_count) except Exception: diff --git a/src/application/use_cases/stream_message.py b/src/application/use_cases/stream_message.py index d87591c..c1c0ad1 100644 --- a/src/application/use_cases/stream_message.py +++ b/src/application/use_cases/stream_message.py @@ -10,44 +10,47 @@ class StreamMessageUseCase: - """Envoie un message a l'agent et streame la reponse.""" + """Envoie un message a l'agent et streame la reponse avec le Message final.""" def __init__(self, registry: AgentRegistry, threads: ThreadRepository): self._registry = registry self._threads = threads - async def execute(self, thread_id: str, message: str) -> AsyncGenerator[str, None]: + async def execute(self, thread_id: str, message: str) -> AsyncGenerator[str | Message, None]: thread = await self._threads.get(thread_id) human_msg = Message(role=MessageRole.HUMAN, content=message) await self._threads.add_message(thread_id, human_msg) runner = await self._registry.get_runner(thread.agent_name) start = time.monotonic() logger.info("[thread=%s][agent=%s] Stream started", thread_id, thread.agent_name) - full_response = [] chunk_count = 0 + final_message = None try: - async for chunk in runner.stream(thread_id, message): - chunk_count += 1 - full_response.append(chunk) - yield chunk + async for event in runner.stream_with_message(thread_id, message): + if isinstance(event, str): + chunk_count += 1 + yield event + elif isinstance(event, Message): + final_message = event except Exception: logger.exception( "[thread=%s][agent=%s] Stream error after %d chunks", thread_id, thread.agent_name, chunk_count ) raise elapsed = time.monotonic() - start - ai_msg = Message(role=MessageRole.AI, content="".join(full_response)) - try: - await self._threads.add_message(thread_id, ai_msg) - except Exception: - logger.exception( - "[thread=%s][agent=%s] Failed to persist AI message after stream", thread_id, thread.agent_name - ) + if final_message is not None: + try: + await self._threads.add_message(thread_id, final_message) + except Exception: + logger.exception( + "[thread=%s][agent=%s] Failed to persist AI message after stream", thread_id, thread.agent_name + ) + yield final_message logger.info( - "[thread=%s][agent=%s] Stream complete, %d chunks, %d chars, elapsed=%.2fs", + "[thread=%s][agent=%s] Stream complete, %d chunks, elapsed=%.2fs, message=%s", thread_id, thread.agent_name, chunk_count, - len(ai_msg.content), elapsed, + "yielded" if final_message else "none", ) diff --git a/src/domain/ports/agent_runner.py b/src/domain/ports/agent_runner.py index 953dd77..2f756ab 100644 --- a/src/domain/ports/agent_runner.py +++ b/src/domain/ports/agent_runner.py @@ -15,6 +15,11 @@ async def stream(self, thread_id: str, message: str) -> AsyncIterator[str]: """Envoie un message et streame la reponse par chunks.""" ... + @abstractmethod + async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterator[str | Message]: + """Streame les chunks puis yield le Message final complet.""" + ... + @abstractmethod async def approve_hitl(self, thread_id: str, tool_call_id: str) -> Message: """Approuve une action HITL en attente.""" diff --git a/src/infrastructure/deepagent/adapter.py b/src/infrastructure/deepagent/adapter.py index 496d629..28e8d38 100644 --- a/src/infrastructure/deepagent/adapter.py +++ b/src/infrastructure/deepagent/adapter.py @@ -62,7 +62,9 @@ def _build_config(self, thread_id: str) -> dict: def _build_response(self, result: dict, config: dict) -> Message: """Build a Message from graph result, detecting interrupts and collecting tool_calls.""" - messages = result["messages"] + messages = result.get("messages", []) + if not messages: + raise AgentError("Graph completed but no messages were found in the final state.") last_message = messages[-1] all_tool_calls = getattr(last_message, "tool_calls", None) or [] @@ -107,32 +109,74 @@ async def invoke(self, thread_id: str, message: str) -> Message: logger.exception("[thread=%s] Agent execution error", thread_id) raise AgentError(f"Agent execution error: {e}") from e + async def _yield_chunks( + self, + thread_id: str, + message: str, + config: dict, + stats: dict, + ) -> AsyncIterator[str]: + """Stream graph chunks and populate *stats* with timing.""" + start = time.monotonic() + first_chunk = True + chunk_count = 0 + async for chunk, _metadata in self._graph.astream( + {"messages": [{"role": "human", "content": message}]}, + config=config, + stream_mode="messages", + ): + if hasattr(chunk, "content") and chunk.content and chunk.type == "AIMessageChunk": + if first_chunk: + logger.info( + "[thread=%s] First chunk received, elapsed=%.2fs", + thread_id, + time.monotonic() - start, + ) + first_chunk = False + chunk_count += 1 + yield chunk.content + stats["chunk_count"] = chunk_count + stats["elapsed"] = time.monotonic() - start + async def stream(self, thread_id: str, message: str) -> AsyncIterator[str]: config = self._build_config(thread_id) logger.info("[thread=%s] Streaming agent response", thread_id) try: - start = time.monotonic() - first_chunk = True - chunk_count = 0 - async for chunk, _metadata in self._graph.astream( - {"messages": [{"role": "human", "content": message}]}, - config=config, - stream_mode="messages", - ): - if hasattr(chunk, "content") and chunk.content and chunk.type == "AIMessageChunk": - if first_chunk: - logger.info( - "[thread=%s] First chunk received, elapsed=%.2fs", thread_id, time.monotonic() - start - ) - first_chunk = False - chunk_count += 1 - yield chunk.content + stats: dict = {} + async for chunk in self._yield_chunks(thread_id, message, config, stats): + yield chunk logger.info( "[thread=%s] Stream complete, %d chunks, elapsed=%.2fs", thread_id, - chunk_count, - time.monotonic() - start, + stats["chunk_count"], + stats["elapsed"], + ) + except Exception as e: + logger.exception("[thread=%s] Streaming error", thread_id) + raise AgentError(f"Streaming error: {e}") from e + + async def stream_with_message(self, thread_id: str, message: str) -> AsyncIterator[str | Message]: + config = self._build_config(thread_id) + logger.info("[thread=%s] Streaming agent response with final message", thread_id) + try: + stats: dict = {} + async for chunk in self._yield_chunks(thread_id, message, config, stats): + yield chunk + state = self._graph.get_state(config) + values = state.values if state and hasattr(state, "values") else {} + result = { + "messages": values.get("messages", []), + "structured_response": values.get("structured_response"), + } + response = self._build_response(result, config) + logger.info( + "[thread=%s] Stream with message complete, %d chunks, elapsed=%.2fs, status=%s", + thread_id, + stats["chunk_count"], + stats["elapsed"], + response.status, ) + yield response except Exception as e: logger.exception("[thread=%s] Streaming error", thread_id) raise AgentError(f"Streaming error: {e}") from e diff --git a/tests/unit/test_deep_agent_runner_stream_with_message.py b/tests/unit/test_deep_agent_runner_stream_with_message.py new file mode 100644 index 0000000..e5b01a9 --- /dev/null +++ b/tests/unit/test_deep_agent_runner_stream_with_message.py @@ -0,0 +1,203 @@ +"""Tests for DeepAgentRunner.stream_with_message(). + +Graph is mocked (external LLM boundary via LangGraph). +These tests exercise the new stream_with_message() method which yields +str chunks during streaming and a final Message object after the stream +completes, allowing callers to receive both streaming chunks and a +structured complete response. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.domain.entities.message import Message, MessageRole, MessageStatus +from src.domain.exceptions import AgentError +from src.infrastructure.deepagent.adapter import DeepAgentRunner + + +def _make_streaming_graph( + chunks: list[str], + final_messages: list | None = None, + interrupts=(), + state_values: dict | None = None, + structured_response=None, +): + """Create a mock graph with astream and get_state for stream_with_message. + + Args: + chunks: List of string chunks to yield from astream. + final_messages: Messages to appear in get_state().values["messages"]. + If None, a default AI message is constructed from the chunks. + interrupts: Interrupt tuples for get_state(). + state_values: Additional state values for get_state(). + structured_response: Optional structured_response value in state. + """ + mock_graph = AsyncMock() + + # Build async generator for astream + async def _astream(_input, **_kwargs): + for chunk_text in chunks: + chunk = MagicMock() + chunk.content = chunk_text + chunk.type = "AIMessageChunk" + yield chunk, MagicMock() + + mock_graph.astream = _astream + + # Build state for get_state (called after stream to build Message) + state = MagicMock() + state.interrupts = interrupts + + # Build final AI message from joined chunks if no explicit messages given + if final_messages is None: + final_ai = MagicMock() + final_ai.content = "".join(chunks) + final_ai.tool_calls = None + final_messages = [final_ai] + + values = state_values or {} + if "messages" not in values: + values["messages"] = final_messages + if structured_response is not None: + values["structured_response"] = structured_response + state.values = values + + mock_graph.get_state = MagicMock(return_value=state) + + # ainvoke return value (used by get_state context) + mock_graph.ainvoke.return_value = { + "messages": final_messages, + "structured_response": structured_response, + } + + return mock_graph + + +class TestStreamWithMessage: + """Tests for DeepAgentRunner.stream_with_message().""" + + async def test_stream_with_message_yields_chunks_then_message(self): + """Should yield str chunks during streaming, then a final Message.""" + chunks = ["Hello ", "world!"] + graph = _make_streaming_graph(chunks) + + runner = DeepAgentRunner(graph) + collected = [] + async for item in runner.stream_with_message("thread-1", "hi"): + collected.append(item) + + # All items except the last should be str chunks + str_items = collected[:-1] + final_message = collected[-1] + + assert all(isinstance(c, str) for c in str_items) + assert str_items == ["Hello ", "world!"] + + assert isinstance(final_message, Message) + assert final_message.role == MessageRole.AI + assert final_message.content == "Hello world!" + assert final_message.status == MessageStatus.COMPLETED + + async def test_stream_with_message_final_message_has_tool_calls(self): + """When the last message has tool_calls, the final Message includes them.""" + chunks = ["Processing..."] + + ai_msg = MagicMock() + ai_msg.content = "Processing..." + ai_msg.tool_calls = [{"name": "search", "args": {"q": "test"}, "id": "tc-1"}] + + graph = _make_streaming_graph( + chunks, + final_messages=[ai_msg], + structured_response=None, + ) + + runner = DeepAgentRunner(graph) + collected = [] + async for item in runner.stream_with_message("thread-1", "search for test"): + collected.append(item) + + final_message = collected[-1] + assert isinstance(final_message, Message) + assert final_message.tool_calls is not None + assert len(final_message.tool_calls) == 1 + assert final_message.tool_calls[0]["name"] == "search" + + async def test_stream_with_message_final_message_has_structured_response(self): + """When result has structured_response, it appears in the final Message.""" + chunks = ["Weather report"] + + ai_msg = MagicMock() + ai_msg.content = "Weather report" + ai_msg.tool_calls = None + + graph = _make_streaming_graph( + chunks, + final_messages=[ai_msg], + structured_response={"temperature": 22, "condition": "sunny"}, + ) + + runner = DeepAgentRunner(graph) + collected = [] + async for item in runner.stream_with_message("thread-1", "weather?"): + collected.append(item) + + final_message = collected[-1] + assert isinstance(final_message, Message) + assert final_message.structured_response == {"temperature": 22, "condition": "sunny"} + + async def test_stream_with_message_detects_hitl_interrupt(self): + """When state has interrupts, final Message has status=awaiting_hitl.""" + chunks = ["Waiting for approval"] + + ai_msg = MagicMock() + ai_msg.content = "" + ai_msg.tool_calls = [{"name": "delete_file", "args": {"path": "/tmp/x"}, "id": "tc-1"}] + + interrupt = MagicMock() + graph = _make_streaming_graph( + chunks, + final_messages=[ai_msg], + interrupts=(interrupt,), + ) + + runner = DeepAgentRunner(graph) + collected = [] + async for item in runner.stream_with_message("thread-1", "delete file"): + collected.append(item) + + final_message = collected[-1] + assert isinstance(final_message, Message) + assert final_message.status == MessageStatus.AWAITING_HITL + + async def test_stream_with_message_no_chunks_yields_message(self): + """When stream produces 0 AI chunks but graph completes, still yield a Message.""" + # Empty chunks — the stream yields nothing, but get_state still works + graph = _make_streaming_graph([]) + + runner = DeepAgentRunner(graph) + collected = [] + async for item in runner.stream_with_message("thread-1", "hello"): + collected.append(item) + + # Should have exactly one item: the final Message + assert len(collected) == 1 + assert isinstance(collected[0], Message) + assert collected[0].role == MessageRole.AI + + async def test_stream_with_message_raises_on_error(self): + """When astream raises, AgentError is raised.""" + mock_graph = AsyncMock() + + async def _astream_error(_input, _config=None, _stream_mode=None): + raise RuntimeError("LLM streaming error") + yield + + mock_graph.astream = _astream_error + + runner = DeepAgentRunner(mock_graph) + with pytest.raises(AgentError, match="Streaming error"): + collected = [] + async for item in runner.stream_with_message("thread-1", "hello"): + collected.append(item) diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py index 98c4145..24c726d 100644 --- a/tests/unit/test_routes.py +++ b/tests/unit/test_routes.py @@ -19,15 +19,18 @@ from src.main import app from tests.fixtures.in_memory_thread_repository import InMemoryThreadRepository -VALID_YAML = ( - "name: {name}\n" - "model: test-model\n" - 'system_prompt: "Test prompt."\n' - "tools: []\n" - "debug: false\n" -) +VALID_YAML = 'name: {name}\nmodel: test-model\nsystem_prompt: "Test prompt."\ntools: []\ndebug: false\n' -AGENTS = ["my-agent", "agent-1", "agent-2", "example-agent", "code-reviewer", "minimal-agent", "research-assistant", "mcp-agent"] +AGENTS = [ + "my-agent", + "agent-1", + "agent-2", + "example-agent", + "code-reviewer", + "minimal-agent", + "research-assistant", + "mcp-agent", +] @pytest.fixture @@ -58,6 +61,18 @@ async def mock_stream(_thread_id, _message): yield word + " " runner.stream = mock_stream + + async def mock_stream_with_message(_thread_id, _message): + for word in ["I", "am", "a", "mock", "agent."]: + yield word + " " + yield Message( + role=MessageRole.AI, + content="I am a mock agent.", + status=MessageStatus.COMPLETED, + structured_response={"key": "value"}, + ) + + runner.stream_with_message = mock_stream_with_message return runner @@ -79,6 +94,7 @@ def mock_config_store(yaml_store): async def _get(name): if name not in yaml_store: from src.domain.exceptions import AgentNotFoundError + raise AgentNotFoundError(f"Agent config not found: {name}") return yaml_store[name] @@ -376,3 +392,86 @@ async def test_agent_error_returns_502(self, client, mock_runner): mock_runner.invoke.return_value = Message( role=MessageRole.AI, content="I am a mock agent.", status=MessageStatus.COMPLETED ) + + +# -- Stream Message Event ------------------------------------------------------ + + +class TestStreamMessageEvent: + """Tests for SSE stream: JSON message then [DONE] terminator.""" + + async def test_stream_ends_with_done(self, client): + """Stream always ends with data: [DONE].""" + create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) + thread_id = create_resp.json()["id"] + resp = await client.post( + f"/api/v1/chat/{thread_id}/stream", + json={"message": "Hello agent"}, + ) + assert resp.status_code == 200 + body = resp.text + + data_lines = [ + line.strip()[len("data:"):].strip() + for line in body.replace("\r\n", "\n").split("\n") + if line.strip().startswith("data:") + ] + assert data_lines[-1] == "[DONE]", f"Expected [DONE] as last data line, got: {data_lines[-1]}" + + async def test_stream_emits_message_json_before_done(self, client): + """Stream emits Message JSON as second-to-last data line, before [DONE].""" + create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) + thread_id = create_resp.json()["id"] + resp = await client.post( + f"/api/v1/chat/{thread_id}/stream", + json={"message": "Hello agent"}, + ) + assert resp.status_code == 200 + body = resp.text + + import json + + data_lines = [ + line.strip()[len("data:"):].strip() + for line in body.replace("\r\n", "\n").split("\n") + if line.strip().startswith("data:") + ] + + assert data_lines[-1] == "[DONE]" + message_json = json.loads(data_lines[-2]) + assert message_json["role"] == "ai" + assert message_json["structured_response"] == {"key": "value"} + + async def test_stream_message_format_matches_sync(self, client): + """The Message JSON from stream has the same fields as the sync endpoint.""" + create_resp = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) + thread_id = create_resp.json()["id"] + + sync_resp = await client.post(f"/api/v1/chat/{thread_id}", json={"message": "Compare me"}) + assert sync_resp.status_code == 200 + sync_data = sync_resp.json() + + create_resp2 = await client.post("/api/v1/threads", json={"agent_name": "my-agent"}) + thread_id2 = create_resp2.json()["id"] + + stream_resp = await client.post( + f"/api/v1/chat/{thread_id2}/stream", + json={"message": "Hello agent"}, + ) + assert stream_resp.status_code == 200 + body = stream_resp.text + + import json + + data_lines = [ + line.strip()[len("data:"):].strip() + for line in body.replace("\r\n", "\n").split("\n") + if line.strip().startswith("data:") + ] + + message_json = json.loads(data_lines[-2]) + + for field in ["role", "content", "timestamp", "status"]: + assert field in message_json, f"Missing field {field!r} in stream Message: {message_json}" + + assert message_json["role"] == sync_data["role"] diff --git a/tests/unit/test_runner_tracing.py b/tests/unit/test_runner_tracing.py index 54b43d4..8a1f637 100644 --- a/tests/unit/test_runner_tracing.py +++ b/tests/unit/test_runner_tracing.py @@ -77,6 +77,9 @@ async def mock_astream(*_args, **_kwargs): yield (mock_msg, {"langgraph_node": "agent"}) mock_graph.astream = mock_astream + mock_graph.get_state = MagicMock( + return_value=MagicMock(values={"messages": [MagicMock(content="chunk", tool_calls=None)]}, interrupts=()) + ) runner = DeepAgentRunner(mock_graph, tracing_provider=mock_tracing_provider)