From dc2a285ef2d2f5e6b87290a96b6ffea16bb2d66d Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sun, 24 May 2026 17:47:26 +0800 Subject: [PATCH] fix: reinitialize expired streamable sessions --- packages/client/src/client/client.ts | 91 ++++++++++-------- packages/client/src/client/streamableHttp.ts | 19 +++- .../client/test/client/streamableHttp.test.ts | 93 ++++++++++++++++++- 3 files changed, 160 insertions(+), 43 deletions(-) diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 5fa2e14d94..38a115715e 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -194,6 +194,10 @@ export type ClientOptions = ProtocolOptions & { listChanged?: ListChangedHandlers; }; +type SessionExpiringTransport = Transport & { + onsessionexpired?: () => void | Promise; +}; + /** * An MCP client on top of a pluggable transport. * @@ -483,6 +487,13 @@ export class Client extends Protocol { */ override async connect(transport: Transport, options?: RequestOptions): Promise { await super.connect(transport); + (transport as SessionExpiringTransport).onsessionexpired = async () => { + this._serverCapabilities = undefined; + this._serverVersion = undefined; + this._negotiatedProtocolVersion = undefined; + await this._initialize(transport, options); + }; + // When transport sessionId is already set this means we are trying to reconnect. // Restore the protocol version negotiated during the original initialize handshake // so HTTP transports include the required mcp-protocol-version header, but skip re-init. @@ -493,50 +504,54 @@ export class Client extends Protocol { return; } try { - const result = await this._requestWithSchema( - { - method: 'initialize', - params: { - protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION, - capabilities: this._capabilities, - clientInfo: this._clientInfo - } - }, - InitializeResultSchema, - options - ); + await this._initialize(transport, options); + } catch (error) { + // Disconnect if initialization fails. + void this.close(); + throw error; + } + } - if (result === undefined) { - throw new Error(`Server sent invalid initialize result: ${result}`); - } + private async _initialize(transport: Transport, options?: RequestOptions): Promise { + const result = await this._requestWithSchema( + { + method: 'initialize', + params: { + protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION, + capabilities: this._capabilities, + clientInfo: this._clientInfo + } + }, + InitializeResultSchema, + options + ); - if (!this._supportedProtocolVersions.includes(result.protocolVersion)) { - throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`); - } + if (result === undefined) { + throw new Error(`Server sent invalid initialize result: ${result}`); + } - this._serverCapabilities = result.capabilities; - this._serverVersion = result.serverInfo; - this._negotiatedProtocolVersion = result.protocolVersion; - // HTTP transports must set the protocol version in each header after initialization. - if (transport.setProtocolVersion) { - transport.setProtocolVersion(result.protocolVersion); - } + if (!this._supportedProtocolVersions.includes(result.protocolVersion)) { + throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`); + } - this._instructions = result.instructions; + this._serverCapabilities = result.capabilities; + this._serverVersion = result.serverInfo; + this._negotiatedProtocolVersion = result.protocolVersion; + // HTTP transports must set the protocol version in each header after initialization. + if (transport.setProtocolVersion) { + transport.setProtocolVersion(result.protocolVersion); + } - await this.notification({ - method: 'notifications/initialized' - }); + this._instructions = result.instructions; - // Set up list changed handlers now that we know server capabilities - if (this._pendingListChangedConfig) { - this._setupListChangedHandlers(this._pendingListChangedConfig); - this._pendingListChangedConfig = undefined; - } - } catch (error) { - // Disconnect if initialization fails. - void this.close(); - throw error; + await this.notification({ + method: 'notifications/initialized' + }); + + // Set up list changed handlers now that we know server capabilities + if (this._pendingListChangedConfig) { + this._setupListChangedHandlers(this._pendingListChangedConfig); + this._pendingListChangedConfig = undefined; } } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 3b8ddafe5a..fd3b7edd46 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -189,6 +189,7 @@ export class StreamableHTTPClientTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; + onsessionexpired?: () => void | Promise; constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) { this._url = url; @@ -521,13 +522,14 @@ export class StreamableHTTPClientTransport implements Transport { message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } ): Promise { - return this._send(message, options, false); + return this._send(message, options, false, false); } private async _send( message: JSONRPCMessage | JSONRPCMessage[], options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined, - isAuthRetry: boolean + isAuthRetry: boolean, + isSessionRetry: boolean ): Promise { try { const { resumptionToken, onresumptiontoken } = options || {}; @@ -579,7 +581,7 @@ export class StreamableHTTPClientTransport implements Transport { }); await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice - return this._send(message, options, true); + return this._send(message, options, true, isSessionRetry); } await response.text?.().catch(() => {}); if (isAuthRetry) { @@ -593,6 +595,15 @@ export class StreamableHTTPClientTransport implements Transport { const text = await response.text?.().catch(() => null); + if (response.status === 404 && this._sessionId && !isSessionRetry) { + this._sessionId = undefined; + await this.onsessionexpired?.(); + + if (this._sessionId) { + return this._send(message, options, isAuthRetry, true); + } + } + if (response.status === 403 && this._oauthProvider) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); @@ -629,7 +640,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError(); } - return this._send(message, options, isAuthRetry); + return this._send(message, options, isAuthRetry, isSessionRetry); } } diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 0edf8b75ac..65d9325c1c 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1,9 +1,10 @@ import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core'; -import { OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core'; +import { LATEST_PROTOCOL_VERSION, OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core'; import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; +import { Client } from '../../src/client/client.js'; import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -249,6 +250,96 @@ describe('StreamableHTTPClientTransport', () => { expect(errorSpy).toHaveBeenCalled(); }); + it('reinitializes and retries once when a persisted session expires', async () => { + const client = new Client({ name: 'test-client', version: '1.0.0' }); + const httpTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + let initializeCount = 0; + let firstPing = true; + + (globalThis.fetch as Mock).mockImplementation(async (_url, init) => { + if (init.method === 'GET') { + return { + ok: false, + status: 405, + statusText: 'Method Not Allowed', + headers: new Headers(), + text: async () => '' + }; + } + + const body = JSON.parse(init.body as string) as JSONRPCRequest; + + if (body.method === 'initialize') { + const sessionId = initializeCount++ === 0 ? 'old-session-id' : 'new-session-id'; + + return { + ok: true, + status: 200, + headers: new Headers({ + 'content-type': 'application/json', + 'mcp-session-id': sessionId + }), + json: async () => ({ + jsonrpc: '2.0', + id: body.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { name: 'test-server', version: '1.0.0' } + } + }) + }; + } + + if (body.method === 'notifications/initialized') { + return { + ok: true, + status: 202, + headers: new Headers(), + text: async () => '' + }; + } + + if (body.method === 'ping' && firstPing) { + firstPing = false; + + return { + ok: false, + status: 404, + statusText: 'Not Found', + headers: new Headers(), + text: async () => 'Session not found' + }; + } + + return { + ok: true, + status: 200, + headers: new Headers({ 'content-type': 'application/json' }), + json: async () => ({ + jsonrpc: '2.0', + id: body.id, + result: {} + }) + }; + }); + + try { + await client.connect(httpTransport); + await expect(client.ping()).resolves.toEqual({}); + + const calls = (globalThis.fetch as Mock).mock.calls; + const postCalls = calls.filter(([, init]) => init.method === 'POST'); + expect(postCalls).toHaveLength(6); + expect(postCalls[2]![1].headers.get('mcp-session-id')).toBe('old-session-id'); + expect(postCalls[3]![1].headers.get('mcp-session-id')).toBeNull(); + expect(postCalls[5]![1].headers.get('mcp-session-id')).toBe('new-session-id'); + expect(httpTransport.sessionId).toBe('new-session-id'); + } finally { + await client.close().catch(() => {}); + } + }); + it('should handle non-streaming JSON response', async () => { const message: JSONRPCMessage = { jsonrpc: '2.0',