-
Notifications
You must be signed in to change notification settings - Fork 545
feat: expose streamable HTTP session ID #893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,4 +1,9 @@ | ||||||
| use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; | ||||||
| use std::{ | ||||||
| borrow::Cow, | ||||||
| collections::HashMap, | ||||||
| sync::{Arc, RwLock}, | ||||||
| time::Duration, | ||||||
| }; | ||||||
|
|
||||||
| use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; | ||||||
| use http::{HeaderName, HeaderValue}; | ||||||
|
|
@@ -16,13 +21,38 @@ use crate::{ | |||||
| ServerResult, | ||||||
| }, | ||||||
| transport::{ | ||||||
| TransportSessionIdHandle, TransportSessionIdProvider, | ||||||
| common::client_side_sse::SseAutoReconnectStream, | ||||||
| worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, | ||||||
| }, | ||||||
| }; | ||||||
|
|
||||||
| type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>; | ||||||
|
|
||||||
| /// Cloneable read-only handle for the negotiated streamable HTTP session ID. | ||||||
| #[derive(Debug, Clone, Default)] | ||||||
| pub struct StreamableHttpClientSession { | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like an implementation detail.
Suggested change
|
||||||
| session_id: Arc<RwLock<Option<Arc<str>>>>, | ||||||
| } | ||||||
|
|
||||||
| impl StreamableHttpClientSession { | ||||||
| pub fn session_id(&self) -> Option<Arc<str>> { | ||||||
| self.session_id.read().ok().and_then(|guard| guard.clone()) | ||||||
| } | ||||||
|
|
||||||
| fn set_session_id(&self, session_id: Option<Arc<str>>) { | ||||||
| if let Ok(mut guard) = self.session_id.write() { | ||||||
| *guard = session_id; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| impl TransportSessionIdProvider for StreamableHttpClientSession { | ||||||
| fn session_id(&self) -> Option<Arc<str>> { | ||||||
| self.session_id() | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| #[derive(Debug)] | ||||||
| #[non_exhaustive] | ||||||
| pub struct AuthRequiredError { | ||||||
|
|
@@ -277,6 +307,7 @@ struct SessionCleanupInfo<C> { | |||||
| pub struct StreamableHttpClientWorker<C: StreamableHttpClient> { | ||||||
| pub client: C, | ||||||
| pub config: StreamableHttpClientTransportConfig, | ||||||
| session: StreamableHttpClientSession, | ||||||
| } | ||||||
|
|
||||||
| impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> { | ||||||
|
|
@@ -287,13 +318,18 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> { | |||||
| uri: url.into(), | ||||||
| ..Default::default() | ||||||
| }, | ||||||
| session: StreamableHttpClientSession::default(), | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> { | ||||||
| pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self { | ||||||
| Self { client, config } | ||||||
| Self { | ||||||
| client, | ||||||
| config, | ||||||
| session: StreamableHttpClientSession::default(), | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -447,6 +483,11 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> { | |||||
| fn err_join(e: tokio::task::JoinError) -> Self::Error { | ||||||
| StreamableHttpError::TokioJoinError(e) | ||||||
| } | ||||||
| fn session_id_handle(&self) -> Option<TransportSessionIdHandle> { | ||||||
| Some(TransportSessionIdHandle::new(Arc::new( | ||||||
| self.session.clone(), | ||||||
| ))) | ||||||
| } | ||||||
| fn config(&self) -> super::worker::WorkerConfig { | ||||||
| super::worker::WorkerConfig { | ||||||
| name: Some("StreamableHttpClientWorker".into()), | ||||||
|
|
@@ -505,6 +546,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> { | |||||
| } | ||||||
| None | ||||||
| }; | ||||||
| self.session.set_session_id(session_id.clone()); | ||||||
| // Extract the negotiated protocol version from the init response | ||||||
| // and build a custom headers map that includes MCP-Protocol-Version | ||||||
| // for all subsequent HTTP requests (per MCP 2025-06-18 spec). | ||||||
|
|
@@ -684,6 +726,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> { | |||||
| streams.abort_all(); | ||||||
|
|
||||||
| session_id = new_session_id; | ||||||
| self.session.set_session_id(session_id.clone()); | ||||||
| protocol_headers = new_protocol_headers; | ||||||
| session_cleanup_info = | ||||||
| session_id.as_ref().map(|sid| SessionCleanupInfo { | ||||||
|
|
@@ -872,6 +915,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> { | |||||
| } | ||||||
| } | ||||||
| } | ||||||
| self.session.set_session_id(None); | ||||||
|
|
||||||
| loop_result | ||||||
| } | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ use std::{borrow::Cow, time::Duration}; | |
| use tokio_util::sync::CancellationToken; | ||
| use tracing::{Instrument, Level}; | ||
|
|
||
| use super::{IntoTransport, Transport}; | ||
| use super::{IntoTransport, Transport, TransportSessionIdHandle}; | ||
| use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; | ||
|
|
||
| #[derive(Debug, thiserror::Error)] | ||
|
|
@@ -46,6 +46,9 @@ pub trait Worker: Sized + Send + 'static { | |
| type Role: ServiceRole; | ||
| fn err_closed() -> Self::Error; | ||
| fn err_join(e: tokio::task::JoinError) -> Self::Error; | ||
| fn session_id_handle(&self) -> Option<TransportSessionIdHandle> { | ||
| None | ||
| } | ||
| fn run( | ||
| self, | ||
| context: WorkerContext<Self>, | ||
|
|
@@ -67,6 +70,7 @@ pub struct WorkerTransport<W: Worker> { | |
| join_handle: Option<tokio::task::JoinHandle<Result<(), WorkerQuitReason<W::Error>>>>, | ||
| _drop_guard: tokio_util::sync::DropGuard, | ||
| ct: CancellationToken, | ||
| session_id_handle: Option<TransportSessionIdHandle>, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The SemVer check flags this as a breaking change because it removes the auto traits. The simplest fix is probably to require |
||
| } | ||
|
|
||
| #[non_exhaustive] | ||
|
|
@@ -96,12 +100,21 @@ impl<W: Worker> WorkerTransport<W> { | |
| pub fn cancel_token(&self) -> CancellationToken { | ||
| self.ct.clone() | ||
| } | ||
| pub fn session_id_handle(&self) -> Option<TransportSessionIdHandle> { | ||
| self.session_id_handle.clone() | ||
| } | ||
| pub fn session_id(&self) -> Option<std::sync::Arc<str>> { | ||
| self.session_id_handle | ||
| .as_ref() | ||
| .and_then(|handle| handle.session_id()) | ||
| } | ||
| pub fn spawn(worker: W) -> Self { | ||
| Self::spawn_with_ct(worker, CancellationToken::new()) | ||
| } | ||
| pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self { | ||
| let config = worker.config(); | ||
| let worker_name = config.name; | ||
| let session_id_handle = worker.session_id_handle(); | ||
| let (to_transport_tx, from_handler_rx) = | ||
| tokio::sync::mpsc::channel::<WorkerSendRequest<W>>(config.channel_buffer_capacity); | ||
| let (to_handler_tx, from_transport_rx) = | ||
|
|
@@ -145,6 +158,7 @@ impl<W: Worker> WorkerTransport<W> { | |
| join_handle: Some(join_handle), | ||
| ct: transport_task_ct.clone(), | ||
| _drop_guard: transport_task_ct.drop_guard(), | ||
| session_id_handle, | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -214,4 +228,8 @@ impl<W: Worker> Transport<W::Role> for WorkerTransport<W> { | |
| Ok(()) | ||
| } | ||
| } | ||
|
|
||
| fn session_id_handle(&self) -> Option<TransportSessionIdHandle> { | ||
| WorkerTransport::session_id_handle(self) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of the existing
handlenames in this project refer to task or runtime handles. How about calling itsession_id_observer()orsession_id_provider()instead so that the API can communicate that this value is read-only more clearly?