diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index 3ddf4c9f9..b77fe0651 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -2,6 +2,7 @@ import json import os +import uuid from collections.abc import AsyncIterable, AsyncIterator from dataclasses import asdict, replace from typing import TYPE_CHECKING, Any @@ -77,6 +78,7 @@ def __init__( self._transport: Transport | None = None self._query: Any | None = None self._materialized: MaterializedResume | None = None + self._session_id: str = "default" def _convert_hooks_to_internal_format( self, hooks: dict[HookEvent, list[HookMatcher]] @@ -262,7 +264,7 @@ async def _on_mirror_error(key: Any, error: str) -> None: "type": "user", "message": {"role": "user", "content": prompt}, "parent_tool_use_id": None, - "session_id": "default", + "session_id": self._session_id, } await self._transport.write(json.dumps(message) + "\n") elif prompt is not None and isinstance(prompt, AsyncIterable): @@ -281,25 +283,30 @@ async def receive_messages(self) -> AsyncIterator[Message]: yield message async def query( - self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default" + self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str | None = None ) -> None: """ Send a new request in streaming mode. Args: prompt: Either a string message or an async iterable of message dictionaries - session_id: Session identifier for the conversation + session_id: Session identifier for the conversation. If not provided, + uses the current session ID (set at construction or via clear()/resume()). """ if not self._query or not self._transport: raise CLIConnectionError("Not connected. Call connect() first.") + effective_session_id = ( + session_id if session_id is not None else self._session_id + ) + # Handle string prompts if isinstance(prompt, str): message = { "type": "user", "message": {"role": "user", "content": prompt}, "parent_tool_use_id": None, - "session_id": session_id, + "session_id": effective_session_id, } await self._transport.write(json.dumps(message) + "\n") else: @@ -307,9 +314,45 @@ async def query( async for msg in prompt: # Ensure session_id is set on each message if "session_id" not in msg: - msg["session_id"] = session_id + msg["session_id"] = effective_session_id await self._transport.write(json.dumps(msg) + "\n") + def clear(self, session_id: str | None = None) -> None: + """Reset conversation state without restarting the subprocess. + + Resets the current session ID so the next query() starts a fresh + conversation on the same warm subprocess, avoiding the 60-second MCP + re-handshake cost of a full disconnect/reconnect. + + Args: + session_id: Optional new session ID. If not provided, a UUID is generated. + + Raises: + CLIConnectionError: If not connected. + """ + if not self._query or not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + self._session_id = session_id if session_id is not None else str(uuid.uuid4()) + + def resume(self, session_id: str) -> None: + """Switch to an existing session by ID without restarting the subprocess. + + Args: + session_id: The session ID to resume. + + Raises: + CLIConnectionError: If not connected. + ValueError: If session_id is empty. + """ + if not self._query or not self._transport: + raise CLIConnectionError("Not connected. Call connect() first.") + + if not session_id: + raise ValueError("session_id must be a non-empty string") + + self._session_id = session_id + async def interrupt(self) -> None: """Send interrupt signal (only works with streaming mode).""" if not self._query: diff --git a/tests/test_clear_resume.py b/tests/test_clear_resume.py new file mode 100644 index 000000000..478f8e527 --- /dev/null +++ b/tests/test_clear_resume.py @@ -0,0 +1,43 @@ +"""Tests for ClaudeSDKClient.clear() and .resume() warm subprocess reuse.""" + +import pytest + +from claude_agent_sdk import ClaudeSDKClient, CLIConnectionError +from claude_agent_sdk._internal.transport import Transport + + +class _NullTransport(Transport): + async def connect(self) -> None: ... + async def close(self) -> None: ... + async def write(self, data: str) -> None: ... + async def end_input(self) -> None: ... + def read_messages(self): + async def _r(): + return + yield + + return _r() + + def is_ready(self) -> bool: + return False + + +def test_clear_raises_when_not_connected(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError): + client.clear() + + +def test_resume_raises_when_not_connected(): + client = ClaudeSDKClient() + with pytest.raises(CLIConnectionError): + client.resume("some-session") + + +def test_resume_raises_on_empty_session_id(): + client = ClaudeSDKClient() + # Bypass connection check by patching internals + client._query = object() + client._transport = object() + with pytest.raises(ValueError, match="non-empty"): + client.resume("")