Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions src/claude_agent_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
import uuid
from collections.abc import AsyncIterable, AsyncIterator
from dataclasses import asdict, replace
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
self._transport: Transport | None = None
self._query: Any | None = None
self._materialized: MaterializedResume | None = None
self._session_id: str = "default"

def _convert_hooks_to_internal_format(
self, hooks: dict[HookEvent, list[HookMatcher]]
Expand Down Expand Up @@ -262,7 +264,7 @@ async def _on_mirror_error(key: Any, error: str) -> None:
"type": "user",
"message": {"role": "user", "content": prompt},
"parent_tool_use_id": None,
"session_id": "default",
"session_id": self._session_id,
}
await self._transport.write(json.dumps(message) + "\n")
elif prompt is not None and isinstance(prompt, AsyncIterable):
Expand All @@ -281,35 +283,76 @@ async def receive_messages(self) -> AsyncIterator[Message]:
yield message

async def query(
self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str = "default"
self, prompt: str | AsyncIterable[dict[str, Any]], session_id: str | None = None
) -> None:
"""
Send a new request in streaming mode.

Args:
prompt: Either a string message or an async iterable of message dictionaries
session_id: Session identifier for the conversation
session_id: Session identifier for the conversation. If not provided,
uses the current session ID (set at construction or via clear()/resume()).
"""
if not self._query or not self._transport:
raise CLIConnectionError("Not connected. Call connect() first.")

effective_session_id = (
session_id if session_id is not None else self._session_id
)

# Handle string prompts
if isinstance(prompt, str):
message = {
"type": "user",
"message": {"role": "user", "content": prompt},
"parent_tool_use_id": None,
"session_id": session_id,
"session_id": effective_session_id,
}
await self._transport.write(json.dumps(message) + "\n")
else:
# Handle AsyncIterable prompts - stream them
async for msg in prompt:
# Ensure session_id is set on each message
if "session_id" not in msg:
msg["session_id"] = session_id
msg["session_id"] = effective_session_id
await self._transport.write(json.dumps(msg) + "\n")

def clear(self, session_id: str | None = None) -> None:
"""Reset conversation state without restarting the subprocess.

Resets the current session ID so the next query() starts a fresh
conversation on the same warm subprocess, avoiding the 60-second MCP
re-handshake cost of a full disconnect/reconnect.

Args:
session_id: Optional new session ID. If not provided, a UUID is generated.

Raises:
CLIConnectionError: If not connected.
"""
if not self._query or not self._transport:
raise CLIConnectionError("Not connected. Call connect() first.")

self._session_id = session_id if session_id is not None else str(uuid.uuid4())

def resume(self, session_id: str) -> None:
"""Switch to an existing session by ID without restarting the subprocess.

Args:
session_id: The session ID to resume.

Raises:
CLIConnectionError: If not connected.
ValueError: If session_id is empty.
"""
if not self._query or not self._transport:
raise CLIConnectionError("Not connected. Call connect() first.")

if not session_id:
raise ValueError("session_id must be a non-empty string")

self._session_id = session_id

async def interrupt(self) -> None:
"""Send interrupt signal (only works with streaming mode)."""
if not self._query:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_clear_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Tests for ClaudeSDKClient.clear() and .resume() warm subprocess reuse."""

import pytest

from claude_agent_sdk import ClaudeSDKClient, CLIConnectionError
from claude_agent_sdk._internal.transport import Transport


class _NullTransport(Transport):
async def connect(self) -> None: ...
async def close(self) -> None: ...
async def write(self, data: str) -> None: ...
async def end_input(self) -> None: ...
def read_messages(self):
async def _r():
return
yield

return _r()

def is_ready(self) -> bool:
return False


def test_clear_raises_when_not_connected():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError):
client.clear()


def test_resume_raises_when_not_connected():
client = ClaudeSDKClient()
with pytest.raises(CLIConnectionError):
client.resume("some-session")


def test_resume_raises_on_empty_session_id():
client = ClaudeSDKClient()
# Bypass connection check by patching internals
client._query = object()
client._transport = object()
with pytest.raises(ValueError, match="non-empty"):
client.resume("")