anthropic: add include_response_headers parameter to ChatAnthropic #31579

Draft · wants to merge 2 commits into master
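For context, here is a minimal usage sketch of the behavior this PR adds. Method and field names come from the diff below; the model name and header keys are illustrative:

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-3-5-sonnet-latest",  # illustrative model name
    include_response_headers=True,  # new parameter added in this PR
)

# invoke: HTTP response headers land in response_metadata["headers"]
result = llm.invoke("Hello")
print(result.response_metadata["headers"].get("request-id"))

# stream: chunks carry the headers captured when the stream opened
for chunk in llm.stream("Hello"):
    headers = chunk.response_metadata.get("headers", {})
```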
69 changes: 64 additions & 5 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -1197,6 +1197,9 @@ def get_weather(location: str) -> str:
"name": "example-mcp"}]``
"""

include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata."""

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@@ -1315,12 +1318,26 @@ def _create(self, payload: dict) -> Any:
else:
return self._client.messages.create(**payload)

def _create_with_raw_response(self, payload: dict) -> Any:
if "betas" in payload:
return self._client.beta.messages.with_raw_response.create(**payload)
else:
return self._client.messages.with_raw_response.create(**payload)

async def _acreate(self, payload: dict) -> Any:
if "betas" in payload:
return await self._async_client.beta.messages.create(**payload)
else:
return await self._async_client.messages.create(**payload)

async def _acreate_with_raw_response(self, payload: dict) -> Any:
if "betas" in payload:
return await self._async_client.beta.messages.with_raw_response.create(
**payload
)
else:
return await self._async_client.messages.with_raw_response.create(**payload)

def _stream(
self,
messages: list[BaseMessage],
@@ -1341,6 +1358,10 @@ def _stream(
and not _documents_in_params(payload)
and not _thinking_in_params(payload)
)
headers = {}
if self.include_response_headers and hasattr(stream, "response"):
headers = dict(stream.response.headers)

block_start_event = None
for event in stream:
msg, block_start_event = _make_message_chunk_from_anthropic_event(
@@ -1350,6 +1371,10 @@ def _stream(
block_start_event=block_start_event,
)
if msg is not None:
if headers and msg.response_metadata is not None:
msg.response_metadata["headers"] = headers
elif headers:
msg.response_metadata = {"headers": headers}
chunk = ChatGenerationChunk(message=msg)
if run_manager and isinstance(msg.content, str):
run_manager.on_llm_new_token(msg.content, chunk=chunk)
@@ -1377,6 +1402,10 @@ async def _astream(
and not _documents_in_params(payload)
and not _thinking_in_params(payload)
)
headers = {}
if self.include_response_headers and hasattr(stream, "response"):
headers = dict(stream.response.headers)

block_start_event = None
async for event in stream:
msg, block_start_event = _make_message_chunk_from_anthropic_event(
@@ -1386,14 +1415,20 @@ async def _astream(
block_start_event=block_start_event,
)
if msg is not None:
if headers and msg.response_metadata is not None:
msg.response_metadata["headers"] = headers
elif headers:
msg.response_metadata = {"headers": headers}
chunk = ChatGenerationChunk(message=msg)
if run_manager and isinstance(msg.content, str):
await run_manager.on_llm_new_token(msg.content, chunk=chunk)
yield chunk
except anthropic.BadRequestError as e:
_handle_anthropic_bad_request(e)

def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
def _format_output(
self, data: Any, headers: Optional[dict] = None, **kwargs: Any
) -> ChatResult:
data_dict = data.model_dump()
content = data_dict["content"]

@@ -1418,6 +1453,13 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
}
if "model" in llm_output and "model_name" not in llm_output:
llm_output["model_name"] = llm_output["model"]

# Only include response_metadata when headers are present
response_metadata = {}
if headers:
response_metadata = llm_output.copy()
response_metadata["headers"] = headers

if (
len(content) == 1
and content[0]["type"] == "text"
@@ -1432,6 +1474,11 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
)
else:
msg = AIMessage(content=content)

# Set response metadata if headers are present
if response_metadata:
msg.response_metadata = response_metadata

msg.usage_metadata = _create_usage_metadata(data.usage)
return ChatResult(
generations=[ChatGeneration(message=msg)],
@@ -1452,10 +1499,16 @@ def _generate(
return generate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
try:
data = self._create(payload)
if self.include_response_headers:
raw_response = self._create_with_raw_response(payload)
data = raw_response.parse()
headers = dict(raw_response.headers)
else:
data = self._create(payload)
headers = {}
except anthropic.BadRequestError as e:
_handle_anthropic_bad_request(e)
return self._format_output(data, **kwargs)
return self._format_output(data, headers=headers, **kwargs)

async def _agenerate(
self,
@@ -1471,10 +1524,16 @@ async def _agenerate(
return await agenerate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
try:
data = await self._acreate(payload)
if self.include_response_headers:
raw_response = await self._acreate_with_raw_response(payload)
data = raw_response.parse()
headers = dict(raw_response.headers)
else:
data = await self._acreate(payload)
headers = {}
except anthropic.BadRequestError as e:
_handle_anthropic_bad_request(e)
return self._format_output(data, **kwargs)
return self._format_output(data, headers=headers, **kwargs)

def _get_llm_for_structured_output_when_thinking_is_enabled(
self,
64 changes: 64 additions & 0 deletions libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -1082,3 +1082,67 @@ def test_files_api_pdf(block_format: str) -> None:
],
}
_ = llm.invoke([input_message])


def test_anthropic_response_headers() -> None:
"""Test ChatAnthropic response headers."""
chat_anthropic = ChatAnthropic(model=MODEL_NAME, include_response_headers=True)
query = "I'm Pickle Rick"
result = chat_anthropic.invoke(query)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
# Check for common HTTP headers
assert any(
key.lower() in ["content-type", "request-id", "x-request-id"]
for key in headers.keys()
)

# Stream
full: Optional[BaseMessageChunk] = None
for chunk in chat_anthropic.stream(query):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert any(
key.lower() in ["content-type", "request-id", "x-request-id"]
for key in headers.keys()
)


async def test_anthropic_response_headers_async() -> None:
"""Test ChatAnthropic response headers for async methods."""
chat_anthropic = ChatAnthropic(model=MODEL_NAME, include_response_headers=True)
query = "I'm Pickle Rick"
result = await chat_anthropic.ainvoke(query)
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert any(
key.lower() in ["content-type", "request-id", "x-request-id"]
for key in headers.keys()
)

# Stream
full: Optional[BaseMessageChunk] = None
async for chunk in chat_anthropic.astream(query):
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessage)
headers = full.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert any(
key.lower() in ["content-type", "request-id", "x-request-id"]
for key in headers.keys()
)


def test_anthropic_no_response_headers_by_default() -> None:
"""Test that headers are not included by default."""
chat_anthropic = ChatAnthropic(model=MODEL_NAME)
query = "I'm Pickle Rick"
result = chat_anthropic.invoke(query)
# assert no response headers if include_response_headers is not set
assert "headers" not in result.response_metadata
122 changes: 122 additions & 0 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -1056,3 +1056,125 @@ def mock_create(*args: Any, **kwargs: Any) -> Message:
# Test headers are correctly propagated to request
payload = llm._get_request_payload([input_message])
assert payload["mcp_servers"][0]["authorization_token"] == "PLACEHOLDER"


def test_chat_anthropic_include_response_headers_initialization() -> None:
"""Test ChatAnthropic include_response_headers initialization."""
# Default should be False
llm = ChatAnthropic(model="claude-3-sonnet-20240229")
assert llm.include_response_headers is False

# Explicit setting should work
llm_with_headers = ChatAnthropic(
model="claude-3-sonnet-20240229", include_response_headers=True
)
assert llm_with_headers.include_response_headers is True


def test_chat_anthropic_invoke_without_response_headers() -> None:
"""Test that headers are not included when include_response_headers=False."""
llm = ChatAnthropic(model="claude-3-sonnet-20240229")

mock_response = Message(
id="msg_123",
content=[TextBlock(type="text", text="Hello")],
model="claude-3-sonnet-20240229",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=Usage(input_tokens=10, output_tokens=5),
)

with patch.object(llm, "_client") as mock_client:
mock_client.messages.create.return_value = mock_response

result = llm.invoke("Hello")

        # headers should not be in response_metadata if include_response_headers is not set
assert "headers" not in result.response_metadata

# Verify client was called without raw_response
assert mock_client.messages.create.called
assert not mock_client.messages.with_raw_response.create.called


def test_chat_anthropic_invoke_with_response_headers() -> None:
"""Test that headers are included when include_response_headers=True."""
llm = ChatAnthropic(model="claude-3-sonnet-20240229", include_response_headers=True)

mock_response = Message(
id="msg_123",
content=[TextBlock(type="text", text="Hello")],
model="claude-3-sonnet-20240229",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=Usage(input_tokens=10, output_tokens=5),
)

# Mock raw response with headers
mock_raw_response = MagicMock()
mock_raw_response.parse.return_value = mock_response
mock_raw_response.headers = {
"content-type": "application/json",
"request-id": "req_123",
}

with patch.object(llm, "_client") as mock_client:
mock_client.messages.with_raw_response.create.return_value = mock_raw_response

result = llm.invoke("Hello")

# headers should be in response_metadata if include_response_headers is True
assert "headers" in result.response_metadata
headers = result.response_metadata["headers"]
assert headers["content-type"] == "application/json"
assert headers["request-id"] == "req_123"

# Verify client was called with raw_response
assert mock_client.messages.with_raw_response.create.called


async def test_chat_anthropic_ainvoke_with_response_headers() -> None:
"""Test headers included in async invoke when include_response_headers=True."""
llm = ChatAnthropic(model="claude-3-sonnet-20240229", include_response_headers=True)

mock_response = Message(
id="msg_123",
content=[TextBlock(type="text", text="Hello")],
model="claude-3-sonnet-20240229",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=Usage(input_tokens=10, output_tokens=5),
)

# Mock raw response with headers
mock_raw_response = MagicMock()
mock_raw_response.parse.return_value = mock_response
mock_raw_response.headers = {
"content-type": "application/json",
"request-id": "req_456",
}

with patch.object(llm, "_async_client") as mock_client:
# Create an async mock for the return value
from unittest.mock import AsyncMock

mock_client.messages.with_raw_response.create = AsyncMock(
return_value=mock_raw_response
)

result = await llm.ainvoke("Hello")

# headers should be in response_metadata if include_response_headers is True
assert "headers" in result.response_metadata
headers = result.response_metadata["headers"]
assert headers["content-type"] == "application/json"
assert headers["request-id"] == "req_456"

# Verify async client was called with raw_response
assert mock_client.messages.with_raw_response.create.called
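For reference, the SDK pattern the implementation wraps is the anthropic client's `with_raw_response` accessor, which returns a wrapper exposing the HTTP headers alongside `parse()` for the typed `Message`. A standalone sketch (assumes `ANTHROPIC_API_KEY` is set; the model name is illustrative):

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

raw = client.messages.with_raw_response.create(
    model="claude-3-5-sonnet-latest",
    max_tokens=64,
    messages=[{"role": "user", "content": "Hello"}],
)

headers = dict(raw.headers)  # HTTP response headers, e.g. request-id
message = raw.parse()        # the usual typed Message object
```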
11 changes: 6 additions & 5 deletions libs/partners/anthropic/uv.lock

Some generated files are not rendered by default.