Skip to content

Commit 03ff8f8

Browse files
committed
example: add simpling example and test
Signed-off-by: jokemanfire <[email protected]>
1 parent b9d7d61 commit 03ff8f8

File tree

9 files changed

+645
-26
lines changed

9 files changed

+645
-26
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Cargo.lock
1212

1313
# MSVC Windows builds of rustc generate these, which store debugging information
1414
*.pdb
15-
15+
.vscode/
1616
# RustRover
1717
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
1818
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore

crates/rmcp/tests/test_sampling.rs

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
//cargo test --test test_sampling --features "client server"
2+
3+
mod common;
4+
5+
use anyhow::Result;
6+
use common::handlers::{TestClientHandler, TestServer};
7+
use rmcp::{
8+
ServiceExt,
9+
model::*,
10+
service::{RequestContext, Service},
11+
};
12+
use tokio_util::sync::CancellationToken;
13+
14+
#[tokio::test]
15+
async fn test_basic_sampling_message_creation() -> Result<()> {
16+
// Test basic sampling message structure
17+
let message = SamplingMessage {
18+
role: Role::User,
19+
content: Content::text("What is the capital of France?"),
20+
};
21+
22+
// Verify serialization/deserialization
23+
let json = serde_json::to_string(&message)?;
24+
let deserialized: SamplingMessage = serde_json::from_str(&json)?;
25+
assert_eq!(message, deserialized);
26+
assert_eq!(message.role, Role::User);
27+
28+
Ok(())
29+
}
30+
31+
#[tokio::test]
32+
async fn test_sampling_request_params() -> Result<()> {
33+
// Test sampling request parameters structure
34+
let params = CreateMessageRequestParam {
35+
messages: vec![SamplingMessage {
36+
role: Role::User,
37+
content: Content::text("Hello, world!"),
38+
}],
39+
model_preferences: Some(ModelPreferences {
40+
hints: Some(vec![ModelHint {
41+
name: Some("claude".to_string()),
42+
}]),
43+
cost_priority: Some(0.5),
44+
speed_priority: Some(0.8),
45+
intelligence_priority: Some(0.7),
46+
}),
47+
system_prompt: Some("You are a helpful assistant.".to_string()),
48+
temperature: Some(0.7),
49+
max_tokens: 100,
50+
stop_sequences: Some(vec!["STOP".to_string()]),
51+
include_context: Some(ContextInclusion::None),
52+
metadata: Some(serde_json::json!({"test": "value"})),
53+
};
54+
55+
// Verify serialization/deserialization
56+
let json = serde_json::to_string(&params)?;
57+
let deserialized: CreateMessageRequestParam = serde_json::from_str(&json)?;
58+
assert_eq!(params, deserialized);
59+
60+
// Verify specific fields
61+
assert_eq!(params.messages.len(), 1);
62+
assert_eq!(params.max_tokens, 100);
63+
assert_eq!(params.temperature, Some(0.7));
64+
65+
Ok(())
66+
}
67+
68+
#[tokio::test]
69+
async fn test_sampling_result_structure() -> Result<()> {
70+
// Test sampling result structure
71+
let result = CreateMessageResult {
72+
message: SamplingMessage {
73+
role: Role::Assistant,
74+
content: Content::text("The capital of France is Paris."),
75+
},
76+
model: "test-model".to_string(),
77+
stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
78+
};
79+
80+
// Verify serialization/deserialization
81+
let json = serde_json::to_string(&result)?;
82+
let deserialized: CreateMessageResult = serde_json::from_str(&json)?;
83+
assert_eq!(result, deserialized);
84+
85+
// Verify specific fields
86+
assert_eq!(result.message.role, Role::Assistant);
87+
assert_eq!(result.model, "test-model");
88+
assert_eq!(
89+
result.stop_reason,
90+
Some(CreateMessageResult::STOP_REASON_END_TURN.to_string())
91+
);
92+
93+
Ok(())
94+
}
95+
96+
#[tokio::test]
97+
async fn test_sampling_context_inclusion_enum() -> Result<()> {
98+
// Test context inclusion enum values
99+
let test_cases = vec![
100+
(ContextInclusion::None, "none"),
101+
(ContextInclusion::ThisServer, "thisServer"),
102+
(ContextInclusion::AllServers, "allServers"),
103+
];
104+
105+
for (context, expected_json) in test_cases {
106+
let json = serde_json::to_string(&context)?;
107+
assert_eq!(json, format!("\"{}\"", expected_json));
108+
109+
let deserialized: ContextInclusion = serde_json::from_str(&json)?;
110+
assert_eq!(context, deserialized);
111+
}
112+
113+
Ok(())
114+
}
115+
116+
#[tokio::test]
117+
async fn test_sampling_integration_with_test_handlers() -> Result<()> {
118+
let (server_transport, client_transport) = tokio::io::duplex(4096);
119+
120+
// Start server
121+
let server_handle = tokio::spawn(async move {
122+
let server = TestServer::new().serve(server_transport).await?;
123+
server.waiting().await?;
124+
anyhow::Ok(())
125+
});
126+
127+
// Start client that honors sampling requests
128+
let handler = TestClientHandler::new(true, true);
129+
let client = handler.clone().serve(client_transport).await?;
130+
131+
// Wait for initialization
132+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
133+
134+
// Test sampling with context inclusion
135+
let request = ServerRequest::CreateMessageRequest(CreateMessageRequest {
136+
method: Default::default(),
137+
params: CreateMessageRequestParam {
138+
messages: vec![SamplingMessage {
139+
role: Role::User,
140+
content: Content::text("What is the capital of France?"),
141+
}],
142+
include_context: Some(ContextInclusion::ThisServer),
143+
model_preferences: Some(ModelPreferences {
144+
hints: Some(vec![ModelHint {
145+
name: Some("test-model".to_string()),
146+
}]),
147+
cost_priority: Some(0.5),
148+
speed_priority: Some(0.8),
149+
intelligence_priority: Some(0.7),
150+
}),
151+
system_prompt: Some("You are a helpful assistant.".to_string()),
152+
temperature: Some(0.7),
153+
max_tokens: 100,
154+
stop_sequences: None,
155+
metadata: None,
156+
},
157+
extensions: Default::default(),
158+
});
159+
160+
let result = handler
161+
.handle_request(
162+
request.clone(),
163+
RequestContext {
164+
peer: client.peer().clone(),
165+
ct: CancellationToken::new(),
166+
id: NumberOrString::Number(1),
167+
meta: Default::default(),
168+
extensions: Default::default(),
169+
},
170+
)
171+
.await?;
172+
173+
// Verify the response
174+
if let ClientResult::CreateMessageResult(result) = result {
175+
assert_eq!(result.message.role, Role::Assistant);
176+
assert_eq!(result.model, "test-model");
177+
assert_eq!(
178+
result.stop_reason,
179+
Some(CreateMessageResult::STOP_REASON_END_TURN.to_string())
180+
);
181+
182+
let response_text = result.message.content.as_text().unwrap().text.as_str();
183+
assert!(
184+
response_text.contains("test context"),
185+
"Response should include context for ThisServer inclusion"
186+
);
187+
} else {
188+
panic!("Expected CreateMessageResult");
189+
}
190+
191+
client.cancel().await?;
192+
server_handle.await??;
193+
Ok(())
194+
}
195+
196+
#[tokio::test]
197+
async fn test_sampling_no_context_inclusion() -> Result<()> {
198+
let (server_transport, client_transport) = tokio::io::duplex(4096);
199+
200+
// Start server
201+
let server_handle = tokio::spawn(async move {
202+
let server = TestServer::new().serve(server_transport).await?;
203+
server.waiting().await?;
204+
anyhow::Ok(())
205+
});
206+
207+
// Start client that honors sampling requests
208+
let handler = TestClientHandler::new(true, true);
209+
let client = handler.clone().serve(client_transport).await?;
210+
211+
// Wait for initialization
212+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
213+
214+
// Test sampling without context inclusion
215+
let request = ServerRequest::CreateMessageRequest(CreateMessageRequest {
216+
method: Default::default(),
217+
params: CreateMessageRequestParam {
218+
messages: vec![SamplingMessage {
219+
role: Role::User,
220+
content: Content::text("Hello"),
221+
}],
222+
include_context: Some(ContextInclusion::None),
223+
model_preferences: None,
224+
system_prompt: None,
225+
temperature: None,
226+
max_tokens: 50,
227+
stop_sequences: None,
228+
metadata: None,
229+
},
230+
extensions: Default::default(),
231+
});
232+
233+
let result = handler
234+
.handle_request(
235+
request.clone(),
236+
RequestContext {
237+
peer: client.peer().clone(),
238+
ct: CancellationToken::new(),
239+
id: NumberOrString::Number(2),
240+
meta: Default::default(),
241+
extensions: Default::default(),
242+
},
243+
)
244+
.await?;
245+
246+
// Verify the response
247+
if let ClientResult::CreateMessageResult(result) = result {
248+
assert_eq!(result.message.role, Role::Assistant);
249+
assert_eq!(result.model, "test-model");
250+
251+
let response_text = result.message.content.as_text().unwrap().text.as_str();
252+
assert!(
253+
!response_text.contains("test context"),
254+
"Response should not include context for None inclusion"
255+
);
256+
} else {
257+
panic!("Expected CreateMessageResult");
258+
}
259+
260+
client.cancel().await?;
261+
server_handle.await??;
262+
Ok(())
263+
}
264+
265+
#[tokio::test]
266+
async fn test_sampling_error_invalid_message_sequence() -> Result<()> {
267+
let (server_transport, client_transport) = tokio::io::duplex(4096);
268+
269+
// Start server
270+
let server_handle = tokio::spawn(async move {
271+
let server = TestServer::new().serve(server_transport).await?;
272+
server.waiting().await?;
273+
anyhow::Ok(())
274+
});
275+
276+
// Start client
277+
let handler = TestClientHandler::new(true, true);
278+
let client = handler.clone().serve(client_transport).await?;
279+
280+
// Wait for initialization
281+
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
282+
283+
// Test sampling with no user messages (should fail)
284+
let request = ServerRequest::CreateMessageRequest(CreateMessageRequest {
285+
method: Default::default(),
286+
params: CreateMessageRequestParam {
287+
messages: vec![SamplingMessage {
288+
role: Role::Assistant,
289+
content: Content::text("I'm an assistant message without a user message"),
290+
}],
291+
include_context: Some(ContextInclusion::None),
292+
model_preferences: None,
293+
system_prompt: None,
294+
temperature: None,
295+
max_tokens: 50,
296+
stop_sequences: None,
297+
metadata: None,
298+
},
299+
extensions: Default::default(),
300+
});
301+
302+
let result = handler
303+
.handle_request(
304+
request.clone(),
305+
RequestContext {
306+
peer: client.peer().clone(),
307+
ct: CancellationToken::new(),
308+
id: NumberOrString::Number(3),
309+
meta: Default::default(),
310+
extensions: Default::default(),
311+
},
312+
)
313+
.await;
314+
315+
// This should result in an error
316+
assert!(result.is_err());
317+
318+
client.cancel().await?;
319+
server_handle.await??;
320+
Ok(())
321+
}

examples/clients/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,8 @@ path = "src/collection.rs"
5151

5252
[[example]]
5353
name = "clients_oauth_client"
54-
path = "src/auth/oauth_client.rs"
54+
path = "src/auth/oauth_client.rs"
55+
56+
[[example]]
57+
name = "clients_sampling_stdio"
58+
path = "src/sampling_stdio.rs"

examples/clients/src/collection.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,26 @@ async fn main() -> Result<()> {
2626

2727
let mut clients_map = HashMap::new();
2828
for idx in 0..10 {
29-
let service = ()
29+
let client = ()
3030
.into_dyn()
3131
.serve(TokioChildProcess::new(Command::new("uvx").configure(
3232
|cmd| {
3333
cmd.arg("mcp-client-git");
3434
},
3535
))?)
3636
.await?;
37-
clients_map.insert(idx, service);
37+
clients_map.insert(idx, client);
3838
}
3939

40-
for (_, service) in clients_map.iter() {
40+
for (_, client) in clients_map.iter() {
4141
// Initialize
42-
let _server_info = service.peer_info();
42+
let _server_info = client.peer_info();
4343

4444
// List tools
45-
let _tools = service.list_tools(Default::default()).await?;
45+
let _tools = client.list_tools(Default::default()).await?;
4646

4747
// Call tool 'git_status' with arguments = {"repo_path": "."}
48-
let _tool_result = service
48+
let _tool_result = client
4949
.call_tool(CallToolRequestParam {
5050
name: "git_status".into(),
5151
arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(),

0 commit comments

Comments
 (0)