diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index b8677cb27..9c6133989 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -219,6 +219,11 @@ name = "test_streamable_http_json_response" required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] path = "tests/test_streamable_http_json_response.rs" +[[test]] +name = "test_streamable_http_4xx_error_body" +required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"] +path = "tests/test_streamable_http_4xx_error_body.rs" + [[test]] name = "test_custom_request" diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index 8fca86fbc..fc37414e7 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -6,7 +6,7 @@ use reqwest::header::ACCEPT; use sse_stream::{Sse, SseStream}; use crate::{ - model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + model::{ClientJsonRpcMessage, JsonRpcMessage, ServerJsonRpcMessage}, transport::{ common::http_header::{ EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, @@ -59,6 +59,15 @@ fn apply_custom_headers( Ok(builder) } +/// Attempts to parse `body` as a JSON-RPC error message. +/// Returns `None` if the body is not parseable or is not a `JsonRpcMessage::Error`. +fn parse_json_rpc_error(body: &str) -> Option { + match serde_json::from_str::(body) { + Ok(message @ JsonRpcMessage::Error(_)) => Some(message), + _ => None, + } +} + impl StreamableHttpClient for reqwest::Client { type Error = reqwest::Error; @@ -190,21 +199,40 @@ impl StreamableHttpClient for reqwest::Client { if status == reqwest::StatusCode::NOT_FOUND && session_was_attached { return Err(StreamableHttpError::SessionExpired); } + let content_type = response + .headers() + .get(reqwest::header::CONTENT_TYPE) + .map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()); + let session_id = response + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + // Non-success responses may carry valid JSON-RPC error payloads that + // should be surfaced as McpError rather than lost in TransportSend. if !status.is_success() { let body = response .text() .await .unwrap_or_else(|_| "".to_owned()); + if content_type + .as_deref() + .is_some_and(|ct| ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes())) + { + match parse_json_rpc_error(&body) { + Some(message) => { + return Ok(StreamableHttpPostResponse::Json(message, session_id)); + } + None => tracing::warn!( + "HTTP {status}: could not parse JSON body as a JSON-RPC error" + ), + } + } return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( format!("HTTP {status}: {body}"), ))); } - let content_type = response.headers().get(reqwest::header::CONTENT_TYPE); - let session_id = response.headers().get(HEADER_SESSION_ID); - let session_id = session_id - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - match content_type { + match content_type.as_deref() { Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) @@ -226,9 +254,7 @@ impl StreamableHttpClient for reqwest::Client { _ => { // unexpected content type tracing::error!("unexpected content type: {:?}", content_type); - Err(StreamableHttpError::UnexpectedContentType( - content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()), - )) + Err(StreamableHttpError::UnexpectedContentType(content_type)) } } } @@ -308,8 +334,8 @@ fn extract_scope_from_header(header: &str) -> Option { #[cfg(test)] mod tests { - use super::extract_scope_from_header; - use crate::transport::streamable_http_client::InsufficientScopeError; + use super::{extract_scope_from_header, parse_json_rpc_error}; + use crate::{model::JsonRpcMessage, transport::streamable_http_client::InsufficientScopeError}; #[test] fn extract_scope_quoted() { @@ -356,4 +382,36 @@ mod tests { assert!(!without_scope.can_upgrade()); assert_eq!(without_scope.get_required_scope(), None); } + + #[test] + fn parse_json_rpc_error_returns_error_variant() { + let body = + r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; + assert!(matches!( + parse_json_rpc_error(body), + Some(JsonRpcMessage::Error(_)) + )); + } + + #[test] + fn parse_json_rpc_error_rejects_non_error_request() { + // A valid JSON-RPC request (method + id) must not be accepted as an error. + let body = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + assert!(parse_json_rpc_error(body).is_none()); + } + + #[test] + fn parse_json_rpc_error_rejects_notification() { + // A notification (method, no id) must not be accepted as an error. + let body = + r#"{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}"#; + assert!(parse_json_rpc_error(body).is_none()); + } + + #[test] + fn parse_json_rpc_error_rejects_malformed_json() { + assert!(parse_json_rpc_error("not json at all").is_none()); + assert!(parse_json_rpc_error("").is_none()); + assert!(parse_json_rpc_error(r#"{"broken":"#).is_none()); + } } diff --git a/crates/rmcp/tests/test_streamable_http_4xx_error_body.rs b/crates/rmcp/tests/test_streamable_http_4xx_error_body.rs new file mode 100644 index 000000000..ea49a4172 --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_4xx_error_body.rs @@ -0,0 +1,121 @@ +#![cfg(all( + feature = "transport-streamable-http-client", + feature = "transport-streamable-http-client-reqwest", + not(feature = "local") +))] + +use std::{collections::HashMap, sync::Arc}; + +use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{ + StreamableHttpClient, StreamableHttpError, StreamableHttpPostResponse, + }, +}; + +/// Spin up a minimal axum server that always responds with the given status, +/// content-type, and body — no MCP logic involved. +async fn spawn_mock_server(status: u16, content_type: &'static str, body: &'static str) -> String { + use axum::{Router, body::Body, http::Response, routing::post}; + + let router = Router::new().route( + "/mcp", + post(move || async move { + Response::builder() + .status(status) + .header("content-type", content_type) + .body(Body::from(body)) + .unwrap() + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, router).await.unwrap(); + }); + + format!("http://{addr}/mcp") +} + +fn ping_message() -> ClientJsonRpcMessage { + ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ) +} + +/// HTTP 4xx with Content-Type: application/json and a valid JSON-RPC error body +/// must be surfaced as `StreamableHttpPostResponse::Json`, not swallowed as a +/// transport error. +#[tokio::test] +async fn http_4xx_json_rpc_error_body_is_surfaced_as_json_response() { + let body = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; + let url = spawn_mock_server(400, "application/json", body).await; + + let client = reqwest::Client::new(); + let result = client + .post_message( + Arc::from(url.as_str()), + ping_message(), + None, + None, + HashMap::new(), + ) + .await; + + match result { + Ok(StreamableHttpPostResponse::Json(msg, _)) => { + let json = serde_json::to_value(&msg).unwrap(); + assert_eq!(json["error"]["code"], -32600); + assert_eq!(json["error"]["message"], "Invalid Request"); + } + other => panic!("expected Json response, got: {other:?}"), + } +} + +/// HTTP 4xx with non-JSON content-type must still return `UnexpectedServerResponse` +/// (no regression on the original error path). +#[tokio::test] +async fn http_4xx_non_json_body_returns_unexpected_server_response() { + let url = spawn_mock_server(400, "text/plain", "Bad Request").await; + + let client = reqwest::Client::new(); + let result = client + .post_message( + Arc::from(url.as_str()), + ping_message(), + None, + None, + HashMap::new(), + ) + .await; + + match result { + Err(StreamableHttpError::UnexpectedServerResponse(_)) => {} + other => panic!("expected UnexpectedServerResponse, got: {other:?}"), + } +} + +/// HTTP 4xx with Content-Type: application/json but a body that is NOT a valid +/// JSON-RPC message must fall back to `UnexpectedServerResponse`. +#[tokio::test] +async fn http_4xx_malformed_json_body_falls_back_to_unexpected_server_response() { + let url = spawn_mock_server(400, "application/json", r#"{"error":"not jsonrpc"}"#).await; + + let client = reqwest::Client::new(); + let result = client + .post_message( + Arc::from(url.as_str()), + ping_message(), + None, + None, + HashMap::new(), + ) + .await; + + match result { + Err(StreamableHttpError::UnexpectedServerResponse(_)) => {} + other => panic!("expected UnexpectedServerResponse, got: {other:?}"), + } +}