diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index f6a34f02da..b2266f1a57 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -102,6 +102,8 @@ export class Server extends Protocol { private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; private _experimental?: { tasks: ExperimentalServerTasks }; + private _receivedInitialize = false; + private _initialized = false; /** * Callback for when initialization has fully completed (i.e., the client has sent an `notifications/initialized` notification). @@ -132,7 +134,13 @@ export class Server extends Protocol { } this.setRequestHandler('initialize', request => this._oninitialize(request)); - this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); + this.setNotificationHandler('notifications/initialized', () => { + if (!this._receivedInitialize) { + throw new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Server not initialized'); + } + this._initialized = true; + this.oninitialized?.(); + }); if (this._capabilities.logging) { this._registerLoggingHandler(); @@ -226,8 +234,19 @@ export class Server extends Protocol { method: string, handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise ): (request: JSONRPCRequest, ctx: ServerContext) => Promise { + const lifecycleHandler: (request: JSONRPCRequest, ctx: ServerContext) => Promise = async (request, ctx) => { + if (!ctx.http && ctx.sessionId === undefined && !this._receivedInitialize && method !== 'initialize' && method !== 'ping') { + throw new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Server not initialized'); + } + const result = await handler(request, ctx); + if (method === 'initialize') { + this._receivedInitialize = true; + } + return result; + }; + if (method !== 'tools/call') { - return handler; + return lifecycleHandler; } return async (request, ctx) => { const validatedRequest = parseSchema(CallToolRequestSchema, request); @@ -239,7 +258,7 @@ export class Server extends Protocol { const { params } = validatedRequest.data; - const result = await handler(request, ctx); + const result = await lifecycleHandler(request, ctx); // When task creation is requested, validate and return CreateTaskResult if (params.task) { diff --git a/packages/server/test/server/server.test.ts b/packages/server/test/server/server.test.ts index fdb8214c56..c1dcf1ffde 100644 --- a/packages/server/test/server/server.test.ts +++ b/packages/server/test/server/server.test.ts @@ -1,5 +1,5 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; -import { InMemoryTransport, LATEST_PROTOCOL_VERSION } from '@modelcontextprotocol/core'; +import { InMemoryTransport, LATEST_PROTOCOL_VERSION, ProtocolErrorCode } from '@modelcontextprotocol/core'; import { Server } from '../../src/server/server.js'; describe('Server', () => { @@ -38,5 +38,72 @@ describe('Server', () => { await server.close(); }); + + it('rejects requests before initialize', async () => { + const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: { tools: {} } }); + + server.setRequestHandler('tools/list', async () => ({ tools: [] })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + + const responses: JSONRPCMessage[] = []; + clientTransport.onmessage = message => responses.push(message); + await clientTransport.start(); + + await clientTransport.send({ + jsonrpc: '2.0', + method: 'notifications/initialized' + } as JSONRPCMessage); + + await clientTransport.send({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: {} + } as JSONRPCMessage); + + await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 1)).toBe(true)); + + const rejected = responses.find(message => 'id' in message && message.id === 1); + expect(rejected).toMatchObject({ + error: { + code: ProtocolErrorCode.InvalidRequest, + message: 'Server not initialized' + } + }); + + await clientTransport.send({ + jsonrpc: '2.0', + id: 2, + method: 'initialize', + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } as JSONRPCMessage); + await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 2)).toBe(true)); + + await clientTransport.send({ + jsonrpc: '2.0', + method: 'notifications/initialized' + } as JSONRPCMessage); + + await clientTransport.send({ + jsonrpc: '2.0', + id: 3, + method: 'tools/list', + params: {} + } as JSONRPCMessage); + + await vi.waitFor(() => expect(responses.some(message => 'id' in message && message.id === 3)).toBe(true)); + + expect(responses.find(message => 'id' in message && message.id === 3)).toMatchObject({ + result: { tools: [] } + }); + + await server.close(); + }); }); });