Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 53 additions & 38 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ export type ClientOptions = ProtocolOptions & {
listChanged?: ListChangedHandlers;
};

type SessionExpiringTransport = Transport & {
onsessionexpired?: () => void | Promise<void>;
};

/**
* An MCP client on top of a pluggable transport.
*
Expand Down Expand Up @@ -483,6 +487,13 @@ export class Client extends Protocol<ClientContext> {
*/
override async connect(transport: Transport, options?: RequestOptions): Promise<void> {
await super.connect(transport);
(transport as SessionExpiringTransport).onsessionexpired = async () => {
this._serverCapabilities = undefined;
this._serverVersion = undefined;
this._negotiatedProtocolVersion = undefined;
await this._initialize(transport, options);
};

// When transport sessionId is already set this means we are trying to reconnect.
// Restore the protocol version negotiated during the original initialize handshake
// so HTTP transports include the required mcp-protocol-version header, but skip re-init.
Expand All @@ -493,50 +504,54 @@ export class Client extends Protocol<ClientContext> {
return;
}
try {
const result = await this._requestWithSchema(
{
method: 'initialize',
params: {
protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION,
capabilities: this._capabilities,
clientInfo: this._clientInfo
}
},
InitializeResultSchema,
options
);
await this._initialize(transport, options);
} catch (error) {
// Disconnect if initialization fails.
void this.close();
throw error;
}
}

if (result === undefined) {
throw new Error(`Server sent invalid initialize result: ${result}`);
}
private async _initialize(transport: Transport, options?: RequestOptions): Promise<void> {
const result = await this._requestWithSchema(
{
method: 'initialize',
params: {
protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION,
capabilities: this._capabilities,
clientInfo: this._clientInfo
}
},
InitializeResultSchema,
options
);

if (!this._supportedProtocolVersions.includes(result.protocolVersion)) {
throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`);
}
if (result === undefined) {
throw new Error(`Server sent invalid initialize result: ${result}`);
}

this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
this._negotiatedProtocolVersion = result.protocolVersion;
// HTTP transports must set the protocol version in each header after initialization.
if (transport.setProtocolVersion) {
transport.setProtocolVersion(result.protocolVersion);
}
if (!this._supportedProtocolVersions.includes(result.protocolVersion)) {
throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`);
}

this._instructions = result.instructions;
this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
this._negotiatedProtocolVersion = result.protocolVersion;
// HTTP transports must set the protocol version in each header after initialization.
if (transport.setProtocolVersion) {
transport.setProtocolVersion(result.protocolVersion);
}

await this.notification({
method: 'notifications/initialized'
});
this._instructions = result.instructions;

// Set up list changed handlers now that we know server capabilities
if (this._pendingListChangedConfig) {
this._setupListChangedHandlers(this._pendingListChangedConfig);
this._pendingListChangedConfig = undefined;
}
} catch (error) {
// Disconnect if initialization fails.
void this.close();
throw error;
await this.notification({
method: 'notifications/initialized'
});

// Set up list changed handlers now that we know server capabilities
if (this._pendingListChangedConfig) {
this._setupListChangedHandlers(this._pendingListChangedConfig);
this._pendingListChangedConfig = undefined;
}
}

Expand Down
19 changes: 15 additions & 4 deletions packages/client/src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ export class StreamableHTTPClientTransport implements Transport {
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
onsessionexpired?: () => void | Promise<void>;

constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) {
this._url = url;
Expand Down Expand Up @@ -521,13 +522,14 @@ export class StreamableHTTPClientTransport implements Transport {
message: JSONRPCMessage | JSONRPCMessage[],
options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void }
): Promise<void> {
return this._send(message, options, false);
return this._send(message, options, false, false);
}

private async _send(
message: JSONRPCMessage | JSONRPCMessage[],
options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined,
isAuthRetry: boolean
isAuthRetry: boolean,
isSessionRetry: boolean
): Promise<void> {
try {
const { resumptionToken, onresumptiontoken } = options || {};
Expand Down Expand Up @@ -579,7 +581,7 @@ export class StreamableHTTPClientTransport implements Transport {
});
await response.text?.().catch(() => {});
// Purposely _not_ awaited, so we don't call onerror twice
return this._send(message, options, true);
return this._send(message, options, true, isSessionRetry);
}
await response.text?.().catch(() => {});
if (isAuthRetry) {
Expand All @@ -593,6 +595,15 @@ export class StreamableHTTPClientTransport implements Transport {

const text = await response.text?.().catch(() => null);

if (response.status === 404 && this._sessionId && !isSessionRetry) {
this._sessionId = undefined;
await this.onsessionexpired?.();

if (this._sessionId) {
return this._send(message, options, isAuthRetry, true);
}
}

if (response.status === 403 && this._oauthProvider) {
const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response);

Expand Down Expand Up @@ -629,7 +640,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError();
}

return this._send(message, options, isAuthRetry);
return this._send(message, options, isAuthRetry, isSessionRetry);
}
}

Expand Down
93 changes: 92 additions & 1 deletion packages/client/test/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core';
import { OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core';
import { LATEST_PROTOCOL_VERSION, OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core';
import type { Mock, Mocked } from 'vitest';

import type { OAuthClientProvider } from '../../src/client/auth.js';
import { UnauthorizedError } from '../../src/client/auth.js';
import { Client } from '../../src/client/client.js';
import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js';
import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js';

Expand Down Expand Up @@ -249,6 +250,96 @@ describe('StreamableHTTPClientTransport', () => {
expect(errorSpy).toHaveBeenCalled();
});

it('reinitializes and retries once when a persisted session expires', async () => {
const client = new Client({ name: 'test-client', version: '1.0.0' });
const httpTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'));
let initializeCount = 0;
let firstPing = true;

(globalThis.fetch as Mock).mockImplementation(async (_url, init) => {
if (init.method === 'GET') {
return {
ok: false,
status: 405,
statusText: 'Method Not Allowed',
headers: new Headers(),
text: async () => ''
};
}

const body = JSON.parse(init.body as string) as JSONRPCRequest;

if (body.method === 'initialize') {
const sessionId = initializeCount++ === 0 ? 'old-session-id' : 'new-session-id';

return {
ok: true,
status: 200,
headers: new Headers({
'content-type': 'application/json',
'mcp-session-id': sessionId
}),
json: async () => ({
jsonrpc: '2.0',
id: body.id,
result: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: {},
serverInfo: { name: 'test-server', version: '1.0.0' }
}
})
};
}

if (body.method === 'notifications/initialized') {
return {
ok: true,
status: 202,
headers: new Headers(),
text: async () => ''
};
}

if (body.method === 'ping' && firstPing) {
firstPing = false;

return {
ok: false,
status: 404,
statusText: 'Not Found',
headers: new Headers(),
text: async () => 'Session not found'
};
}

return {
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'application/json' }),
json: async () => ({
jsonrpc: '2.0',
id: body.id,
result: {}
})
};
});

try {
await client.connect(httpTransport);
await expect(client.ping()).resolves.toEqual({});

const calls = (globalThis.fetch as Mock).mock.calls;
const postCalls = calls.filter(([, init]) => init.method === 'POST');
expect(postCalls).toHaveLength(6);
expect(postCalls[2]![1].headers.get('mcp-session-id')).toBe('old-session-id');
expect(postCalls[3]![1].headers.get('mcp-session-id')).toBeNull();
expect(postCalls[5]![1].headers.get('mcp-session-id')).toBe('new-session-id');
expect(httpTransport.sessionId).toBe('new-session-id');
} finally {
await client.close().catch(() => {});
}
});

it('should handle non-streaming JSON response', async () => {
const message: JSONRPCMessage = {
jsonrpc: '2.0',
Expand Down
Loading