Skip to content

fix: add compatibility handling for non-standard notifications #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 132 additions & 5 deletions crates/rmcp/src/transport/async_rw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,68 @@ fn without_carriage_return(s: &[u8]) -> &[u8] {
}
}

/// Check if a notification method is a standard MCP notification
/// should update this when MCP spec is updated about new notifications
fn is_standard_notification(method: &str) -> bool {
matches!(
method,
"notifications/cancelled"
| "notifications/initialized"
| "notifications/message"
| "notifications/progress"
| "notifications/prompts/list_changed"
| "notifications/resources/list_changed"
| "notifications/resources/updated"
| "notifications/roots/list_changed"
| "notifications/tools/list_changed"
)
}

/// Try to parse a message with compatibility handling for non-standard notifications
fn try_parse_with_compatibility<T: serde::de::DeserializeOwned>(
line: &[u8],
context: &str,
) -> Result<Option<T>, JsonRpcMessageCodecError> {
if let Ok(line_str) = std::str::from_utf8(line) {
match serde_json::from_slice(line) {
Ok(item) => Ok(Some(item)),
Err(e) => {
// Check if this is a non-standard notification that should be ignored
if line_str.contains("\"method\":\"notifications/") {
// Extract the method name to check if it's standard
if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(line_str) {
if let Some(method) = json_value.get("method").and_then(|m| m.as_str()) {
if method.starts_with("notifications/")
&& !is_standard_notification(method)
{
tracing::debug!(
"Ignoring non-standard notification {} {}: {}",
method,
context,
line_str
);
return Ok(None); // Skip this message
}
}
}
}

tracing::debug!(
"Failed to parse message {}: {} | Error: {}",
context,
line_str,
e
);
Err(JsonRpcMessageCodecError::Serde(e))
}
}
} else {
serde_json::from_slice(line)
.map(Some)
.map_err(JsonRpcMessageCodecError::Serde)
}
}

#[derive(Debug, Error)]
pub enum JsonRpcMessageCodecError {
#[error("max line length exceeded")]
Expand Down Expand Up @@ -234,8 +296,12 @@ impl<T: DeserializeOwned> Decoder for JsonRpcMessageCodec<T> {
let line = buf.split_to(newline_index + 1);
let line = &line[..line.len() - 1];
let line = without_carriage_return(line);
let item =
serde_json::from_slice(line).map_err(JsonRpcMessageCodecError::Serde)?;

// Use compatibility handling function
let item = match try_parse_with_compatibility(line, "decode")? {
Some(item) => item,
None => return Ok(None), // Skip non-standard message
};
return Ok(Some(item));
}
(false, None) if buf.len() > self.max_length => {
Expand Down Expand Up @@ -266,8 +332,12 @@ impl<T: DeserializeOwned> Decoder for JsonRpcMessageCodec<T> {
} else {
let line = buf.split_to(buf.len());
let line = without_carriage_return(&line);
let item =
serde_json::from_slice(line).map_err(JsonRpcMessageCodecError::Serde)?;

// Use compatibility handling function
let item = match try_parse_with_compatibility(line, "decode_eof")? {
Some(item) => item,
None => return Ok(None), // Skip non-standard message
};
Some(item)
}
}
Expand Down Expand Up @@ -319,7 +389,7 @@ mod test {
{"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":8}
{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":9}
{"jsonrpc":"2.0","method":"subtract","params":[23,42],"id":10}

"#;

let mut cursor = BufReader::new(data.as_bytes());
Expand Down Expand Up @@ -379,4 +449,61 @@ mod test {
// Make sure there are no extra lines
assert!(lines.next().is_none());
}

#[test]
fn test_standard_notification_check() {
// Test that all standard notifications are recognized
assert!(is_standard_notification("notifications/cancelled"));
assert!(is_standard_notification("notifications/initialized"));
assert!(is_standard_notification("notifications/progress"));
assert!(is_standard_notification(
"notifications/resources/list_changed"
));
assert!(is_standard_notification("notifications/resources/updated"));
assert!(is_standard_notification(
"notifications/prompts/list_changed"
));
assert!(is_standard_notification("notifications/tools/list_changed"));
assert!(is_standard_notification("notifications/message"));
assert!(is_standard_notification("notifications/roots/list_changed"));

// Test that non-standard notifications are not recognized
assert!(!is_standard_notification("notifications/stderr"));
assert!(!is_standard_notification("notifications/custom"));
assert!(!is_standard_notification("notifications/debug"));
assert!(!is_standard_notification("some/other/method"));
}

#[test]
fn test_compatibility_function() {
// Test the compatibility function directly
let stderr_message =
r#"{"method":"notifications/stderr","params":{"content":"stderr message"}}"#;
let custom_message = r#"{"method":"notifications/custom","params":{"data":"custom"}}"#;
let standard_message =
r#"{"method":"notifications/message","params":{"level":"info","data":"standard"}}"#;
let progress_message = r#"{"method":"notifications/progress","params":{"progressToken":"token","progress":50}}"#;

// Test with valid JSON - all should parse successfully
let result1 =
try_parse_with_compatibility::<serde_json::Value>(stderr_message.as_bytes(), "test");
let result2 =
try_parse_with_compatibility::<serde_json::Value>(custom_message.as_bytes(), "test");
let result3 =
try_parse_with_compatibility::<serde_json::Value>(standard_message.as_bytes(), "test");
let result4 =
try_parse_with_compatibility::<serde_json::Value>(progress_message.as_bytes(), "test");

// All should parse successfully since they're valid JSON
assert!(result1.is_ok());
assert!(result2.is_ok());
assert!(result3.is_ok());
assert!(result4.is_ok());

// Standard notifications should return Some(value)
assert!(result3.unwrap().is_some());
assert!(result4.unwrap().is_some());

println!("Standard notifications are preserved, non-standard are handled gracefully");
}
}
Loading