diff --git a/.gitignore b/.gitignore index d01bd1a9..9a96472a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,7 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb - +.vscode/ # RustRover # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore diff --git a/crates/rmcp/tests/test_sampling.rs b/crates/rmcp/tests/test_sampling.rs new file mode 100644 index 00000000..b760796f --- /dev/null +++ b/crates/rmcp/tests/test_sampling.rs @@ -0,0 +1,321 @@ +//cargo test --test test_sampling --features "client server" + +mod common; + +use anyhow::Result; +use common::handlers::{TestClientHandler, TestServer}; +use rmcp::{ + ServiceExt, + model::*, + service::{RequestContext, Service}, +}; +use tokio_util::sync::CancellationToken; + +#[tokio::test] +async fn test_basic_sampling_message_creation() -> Result<()> { + // Test basic sampling message structure + let message = SamplingMessage { + role: Role::User, + content: Content::text("What is the capital of France?"), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(&message)?; + let deserialized: SamplingMessage = serde_json::from_str(&json)?; + assert_eq!(message, deserialized); + assert_eq!(message.role, Role::User); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_request_params() -> Result<()> { + // Test sampling request parameters structure + let params = CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("Hello, world!"), + }], + model_preferences: Some(ModelPreferences { + hints: Some(vec![ModelHint { + name: Some("claude".to_string()), + }]), + cost_priority: Some(0.5), + speed_priority: Some(0.8), + intelligence_priority: Some(0.7), + }), + system_prompt: Some("You are a helpful assistant.".to_string()), + temperature: Some(0.7), + max_tokens: 100, + stop_sequences: Some(vec!["STOP".to_string()]), + include_context: Some(ContextInclusion::None), + metadata: Some(serde_json::json!({"test": "value"})), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(¶ms)?; + let deserialized: CreateMessageRequestParam = serde_json::from_str(&json)?; + assert_eq!(params, deserialized); + + // Verify specific fields + assert_eq!(params.messages.len(), 1); + assert_eq!(params.max_tokens, 100); + assert_eq!(params.temperature, Some(0.7)); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_result_structure() -> Result<()> { + // Test sampling result structure + let result = CreateMessageResult { + message: SamplingMessage { + role: Role::Assistant, + content: Content::text("The capital of France is Paris."), + }, + model: "test-model".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }; + + // Verify serialization/deserialization + let json = serde_json::to_string(&result)?; + let deserialized: CreateMessageResult = serde_json::from_str(&json)?; + assert_eq!(result, deserialized); + + // Verify specific fields + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_context_inclusion_enum() -> Result<()> { + // Test context inclusion enum values + let test_cases = vec![ + (ContextInclusion::None, "none"), + (ContextInclusion::ThisServer, "thisServer"), + (ContextInclusion::AllServers, "allServers"), + ]; + + for (context, expected_json) in test_cases { + let json = serde_json::to_string(&context)?; + assert_eq!(json, format!("\"{}\"", expected_json)); + + let deserialized: ContextInclusion = serde_json::from_str(&json)?; + assert_eq!(context, deserialized); + } + + Ok(()) +} + +#[tokio::test] +async fn test_sampling_integration_with_test_handlers() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors sampling requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling with context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("What is the capital of France?"), + }], + include_context: Some(ContextInclusion::ThisServer), + model_preferences: Some(ModelPreferences { + hints: Some(vec![ModelHint { + name: Some("test-model".to_string()), + }]), + cost_priority: Some(0.5), + speed_priority: Some(0.8), + intelligence_priority: Some(0.7), + }), + system_prompt: Some("You are a helpful assistant.".to_string()), + temperature: Some(0.7), + max_tokens: 100, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(1), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + // Verify the response + if let ClientResult::CreateMessageResult(result) = result { + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + assert_eq!( + result.stop_reason, + Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) + ); + + let response_text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + response_text.contains("test context"), + "Response should include context for ThisServer inclusion" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_sampling_no_context_inclusion() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client that honors sampling requests + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling without context inclusion + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text("Hello"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 50, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(2), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await?; + + // Verify the response + if let ClientResult::CreateMessageResult(result) = result { + assert_eq!(result.message.role, Role::Assistant); + assert_eq!(result.model, "test-model"); + + let response_text = result.message.content.as_text().unwrap().text.as_str(); + assert!( + !response_text.contains("test context"), + "Response should not include context for None inclusion" + ); + } else { + panic!("Expected CreateMessageResult"); + } + + client.cancel().await?; + server_handle.await??; + Ok(()) +} + +#[tokio::test] +async fn test_sampling_error_invalid_message_sequence() -> Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(4096); + + // Start server + let server_handle = tokio::spawn(async move { + let server = TestServer::new().serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + + // Start client + let handler = TestClientHandler::new(true, true); + let client = handler.clone().serve(client_transport).await?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test sampling with no user messages (should fail) + let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::Assistant, + content: Content::text("I'm an assistant message without a user message"), + }], + include_context: Some(ContextInclusion::None), + model_preferences: None, + system_prompt: None, + temperature: None, + max_tokens: 50, + stop_sequences: None, + metadata: None, + }, + extensions: Default::default(), + }); + + let result = handler + .handle_request( + request.clone(), + RequestContext { + peer: client.peer().clone(), + ct: CancellationToken::new(), + id: NumberOrString::Number(3), + meta: Default::default(), + extensions: Default::default(), + }, + ) + .await; + + // This should result in an error + assert!(result.is_err()); + + client.cancel().await?; + server_handle.await??; + Ok(()) +} diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 2461684a..dfb576c1 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -51,4 +51,8 @@ path = "src/collection.rs" [[example]] name = "clients_oauth_client" -path = "src/auth/oauth_client.rs" \ No newline at end of file +path = "src/auth/oauth_client.rs" + +[[example]] +name = "clients_sampling_stdio" +path = "src/sampling_stdio.rs" \ No newline at end of file diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index deef8188..67969ae4 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -26,7 +26,7 @@ async fn main() -> Result<()> { let mut clients_map = HashMap::new(); for idx in 0..10 { - let service = () + let client = () .into_dyn() .serve(TokioChildProcess::new(Command::new("uvx").configure( |cmd| { @@ -34,18 +34,18 @@ async fn main() -> Result<()> { }, ))?) .await?; - clients_map.insert(idx, service); + clients_map.insert(idx, client); } - for (_, service) in clients_map.iter() { + for (_, client) in clients_map.iter() { // Initialize - let _server_info = service.peer_info(); + let _server_info = client.peer_info(); // List tools - let _tools = service.list_tools(Default::default()).await?; + let _tools = client.list_tools(Default::default()).await?; // Call tool 'git_status' with arguments = {"repo_path": "."} - let _tool_result = service + let _tool_result = client .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index a2a7475a..107adc07 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -19,8 +19,7 @@ async fn main() -> Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); - // Start server - let service = () + let client = () .serve(TokioChildProcess::new(Command::new("npx").configure( |cmd| { cmd.arg("-y").arg("@modelcontextprotocol/server-everything"); @@ -29,15 +28,15 @@ async fn main() -> Result<()> { .await?; // Initialize - let server_info = service.peer_info(); + let server_info = client.peer_info(); tracing::info!("Connected to server: {server_info:#?}"); // List tools - let tools = service.list_all_tools().await?; + let tools = client.list_all_tools().await?; tracing::info!("Available tools: {tools:#?}"); // Call tool echo - let tool_result = service + let tool_result = client .call_tool(CallToolRequestParam { name: "echo".into(), arguments: Some(object!({ "message": "hi from rmcp" })), @@ -46,7 +45,7 @@ async fn main() -> Result<()> { tracing::info!("Tool result for echo: {tool_result:#?}"); // Call tool longRunningOperation - let tool_result = service + let tool_result = client .call_tool(CallToolRequestParam { name: "longRunningOperation".into(), arguments: Some(object!({ "duration": 3, "steps": 1 })), @@ -55,11 +54,11 @@ async fn main() -> Result<()> { tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); // List resources - let resources = service.list_all_resources().await?; + let resources = client.list_all_resources().await?; tracing::info!("Available resources: {resources:#?}"); // Read resource - let resource = service + let resource = client .read_resource(ReadResourceRequestParam { uri: "test://static/resource/3".into(), }) @@ -67,11 +66,11 @@ async fn main() -> Result<()> { tracing::info!("Resource: {resource:#?}"); // List prompts - let prompts = service.list_all_prompts().await?; + let prompts = client.list_all_prompts().await?; tracing::info!("Available prompts: {prompts:#?}"); // Get simple prompt - let prompt = service + let prompt = client .get_prompt(GetPromptRequestParam { name: "simple_prompt".into(), arguments: None, @@ -80,7 +79,7 @@ async fn main() -> Result<()> { tracing::info!("Prompt - simple: {prompt:#?}"); // Get complex prompt (returns text & image) - let prompt = service + let prompt = client .get_prompt(GetPromptRequestParam { name: "complex_prompt".into(), arguments: Some(object!({ "temperature": "0.5", "style": "formal" })), @@ -89,10 +88,10 @@ async fn main() -> Result<()> { tracing::info!("Prompt - complex: {prompt:#?}"); // List resource templates - let resource_templates = service.list_all_resource_templates().await?; + let resource_templates = client.list_all_resource_templates().await?; tracing::info!("Available resource templates: {resource_templates:#?}"); - service.cancel().await?; + client.cancel().await?; Ok(()) } diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index 0e04763d..bda978d9 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -17,7 +17,7 @@ async fn main() -> Result<()> { ) .with(tracing_subscriber::fmt::layer()) .init(); - let service = () + let client = () .serve(TokioChildProcess::new(Command::new("uvx").configure( |cmd| { cmd.arg("mcp-server-git"); @@ -28,21 +28,21 @@ async fn main() -> Result<()> { // or serve_client((), TokioChildProcess::new(cmd)?).await?; // Initialize - let server_info = service.peer_info(); + let server_info = client.peer_info(); tracing::info!("Connected to server: {server_info:#?}"); // List tools - let tools = service.list_tools(Default::default()).await?; + let tools = client.list_tools(Default::default()).await?; tracing::info!("Available tools: {tools:#?}"); // Call tool 'git_status' with arguments = {"repo_path": "."} - let tool_result = service + let tool_result = client .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), }) .await?; tracing::info!("Tool result: {tool_result:#?}"); - service.cancel().await?; + client.cancel().await?; Ok(()) } diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs new file mode 100644 index 00000000..ef02604c --- /dev/null +++ b/examples/clients/src/sampling_stdio.rs @@ -0,0 +1,120 @@ +use anyhow::Result; +use rmcp::{ + ClientHandler, ServiceExt, + model::*, + object, + service::{RequestContext, RoleClient}, + transport::{ConfigureCommandExt, TokioChildProcess}, +}; +use tokio::process::Command; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +/// Simple Sampling Demo Client +/// +/// This client demonstrates how to handle sampling requests from servers. +/// It includes a mock LLM that generates simple responses. +/// Run with: cargo run --example clients_sampling_stdio +#[derive(Clone, Debug, Default)] +pub struct SamplingDemoClient; + +impl SamplingDemoClient { + /// Mock LLM function that generates responses based on the input + /// In actual implementation, this would be replaced with a call to an LLM service + fn mock_llm_response( + &self, + _messages: &[SamplingMessage], + _system_prompt: Option<&str>, + ) -> String { + "It just a mock response".to_string() + } +} + +impl ClientHandler for SamplingDemoClient { + async fn create_message( + &self, + params: CreateMessageRequestParam, + _context: RequestContext, + ) -> Result { + tracing::info!("Received sampling request with {:?}", params); + + // Generate mock response using our simple LLM + let response_text = + self.mock_llm_response(¶ms.messages, params.system_prompt.as_deref()); + + Ok(CreateMessageResult { + message: SamplingMessage { + role: Role::Assistant, + content: Content::text(response_text), + }, + model: "mock_llm".to_string(), + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + }) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::info!("Starting Sampling Demo Client"); + + let client = SamplingDemoClient; + + // Start the sampling server as a child process + let client = client + .serve(TokioChildProcess::new(Command::new("cargo").configure( + |cmd| { + cmd.arg("run") + .arg("--example") + .arg("servers_sampling_stdio") + .current_dir("../servers") // This should be correct relative path + ; + }, + ))?) + .await + .inspect_err(|e| { + tracing::error!("client error: {:?}", e); + })?; + + // Wait for initialization + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + // Get server info + let server_info = client.peer_info(); + tracing::info!("Connected to server: {server_info:#?}"); + + // List available tools + match client.list_all_tools().await { + Ok(tools) => { + tracing::info!("Available tools: {tools:#?}"); + + // Test the ask_llm tool + tracing::info!("Testing ask_llm tool..."); + match client + .call_tool(CallToolRequestParam { + name: "ask_llm".into(), + arguments: Some(object!({ + "question": "Hello world" + })), + }) + .await + { + Ok(result) => tracing::info!("Ask LLM result: {result:#?}"), + Err(e) => tracing::error!("Ask LLM error: {e}"), + } + } + Err(e) => tracing::error!("Failed to list tools: {e}"), + } + + tracing::info!("Sampling demo completed successfully!"); + + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + client.cancel().await?; + Ok(()) +} diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 776b5f9e..a6acd646 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -78,3 +78,7 @@ path = "src/simple_auth_sse.rs" [[example]] name = "counter_hyper_streamable_http" path = "src/counter_hyper_streamable_http.rs" + +[[example]] +name = "servers_sampling_stdio" +path = "src/sampling_stdio.rs" diff --git a/examples/servers/src/sampling_stdio.rs b/examples/servers/src/sampling_stdio.rs new file mode 100644 index 00000000..95e8e48f --- /dev/null +++ b/examples/servers/src/sampling_stdio.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use anyhow::Result; +use rmcp::{ + ServerHandler, ServiceExt, + model::*, + service::{RequestContext, RoleServer}, + transport::stdio, +}; +use tracing_subscriber::{self, EnvFilter}; + +/// Simple Sampling Demo Server +/// +/// This server demonstrates how to request LLM sampling from clients. +/// Run with: cargo run --example servers_sampling_stdio +#[derive(Clone, Debug, Default)] +pub struct SamplingDemoServer; + +impl ServerHandler for SamplingDemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some(concat!( + "This is a demo server that requests sampling from clients. It provides tools that use LLM capabilities.\n\n", + "IMPORTANT: This server requires a client that supports the 'sampling/createMessage' method. ", + "Without sampling support, the tools will return errors." + ).into()), + capabilities: ServerCapabilities::builder() + .enable_tools() + .build(), + ..Default::default() + } + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> Result { + match request.name.as_ref() { + "ask_llm" => { + // Get the question from arguments + let question = request + .arguments + .as_ref() + .and_then(|args| args.get("question")) + .and_then(|q| q.as_str()) + .unwrap_or("What is the capital of France?"); + + // Request sampling from the client + let sampling_request = CreateMessageRequest { + method: Default::default(), + params: CreateMessageRequestParam { + messages: vec![SamplingMessage { + role: Role::User, + content: Content::text(question), + }], + model_preferences: Some(ModelPreferences { + hints: Some(vec![ModelHint { + name: Some("claude".to_string()), + }]), + cost_priority: Some(0.3), + speed_priority: Some(0.8), + intelligence_priority: Some(0.7), + }), + system_prompt: Some("You are a helpful assistant.".to_string()), + temperature: Some(0.7), + max_tokens: 150, + stop_sequences: None, + include_context: Some(ContextInclusion::None), + metadata: None, + }, + extensions: Default::default(), + }; + let request = ServerRequest::CreateMessageRequest(sampling_request.clone()); + tracing::info!("Sending request: {:?}", request); + let response = context.peer.send_request(request).await.map_err(|e| { + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Sampling request failed: {}", e), + None, + ) + })?; + if let ClientResult::CreateMessageResult(result) = response { + Ok(CallToolResult::success(vec![Content::text(format!( + "Question: {}\nAnswer: {}", + question, + result + .message + .content + .as_text() + .map(|t| &t.text) + .unwrap_or(&"No text response".to_string()) + ))])) + } else { + Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Unexpected response type", + None, + )) + } + } + + _ => Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + format!("Unknown tool: {}", request.name), + None, + )), + } + } + + async fn list_tools( + &self, + _request: Option, + _context: RequestContext, + ) -> Result { + Ok(ListToolsResult { + tools: vec![Tool { + name: "ask_llm".into(), + description: Some("Ask a question to the LLM through sampling".into()), + input_schema: Arc::new( + serde_json::from_value(serde_json::json!({ + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question to ask the LLM" + } + }, + "required": ["question"] + })) + .unwrap(), + ), + annotations: None, + }], + next_cursor: None, + }) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + tracing::info!("Starting Sampling Demo Server"); + + // Create and serve the sampling demo server + let service = SamplingDemoServer.serve(stdio()).await.inspect_err(|e| { + tracing::error!("Serving error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +}