diff --git a/.changeset/dispatcher-extraction.md b/.changeset/dispatcher-extraction.md new file mode 100644 index 0000000000..e121016e1c --- /dev/null +++ b/.changeset/dispatcher-extraction.md @@ -0,0 +1,4 @@ +--- +'@modelcontextprotocol/core': major +--- +Extract Dispatcher from Protocol. Protocol composes `protected readonly dispatcher`; setRequestHandler/_onrequest delegate. The protected `_wrapHandler` override hook is replaced by `dispatcher.use(middleware)`. diff --git a/.changeset/wraphandler-hook.md b/.changeset/wraphandler-hook.md deleted file mode 100644 index 935f576588..0000000000 --- a/.changeset/wraphandler-hook.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -'@modelcontextprotocol/core': patch -'@modelcontextprotocol/client': patch -'@modelcontextprotocol/server': patch ---- - -refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`. diff --git a/CLAUDE.md b/CLAUDE.md index d5a188676a..ac88c3fe29 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -161,10 +161,11 @@ When a request arrives from the remote side: 1. **Transport** receives message, calls `transport.onmessage()` 2. **`Protocol.connect()`** routes to `_onrequest()`, `_onresponse()`, or `_onnotification()` 3. **`Protocol._onrequest()`**: - - Looks up handler in `_requestHandlers` map (keyed by method name) + - Checks `dispatcher.canHandle(method)`; sends a `MethodNotFound` error and returns early if no handler (or fallback) is registered - Creates `BaseContext` with `signal`, `sessionId`, `sendNotification`, `sendRequest`, etc. - Calls `buildContext()` to let subclasses enrich the context (e.g., Server adds HTTP request info) - - Invokes handler, sends JSON-RPC response back via transport + - Calls `dispatcher.dispatch()` which looks up the handler (keyed by method name), runs the middleware chain, invokes the handler, and wraps the result as a JSON-RPC response + - Sends the response back via transport 4. **Handler** was registered via `setRequestHandler('method', handler)` ### Handler Registration diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 2044cfc453..84b5d320b6 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -7,7 +7,6 @@ import type { CompleteRequest, GetPromptRequest, Implementation, - JSONRPCRequest, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, @@ -19,12 +18,12 @@ import type { ListToolsRequest, LoggingLevel, MessageExtraInfo, + Middleware, NotificationMethod, ProtocolOptions, ReadResourceRequest, RequestMethod, RequestOptions, - Result, ServerCapabilities, SubscribeRequest, Tool, @@ -229,6 +228,8 @@ export class Client extends Protocol { this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false; + this.dispatcher.use(this._validationMiddleware); + // Store list changed config for setup after connection (when we know server capabilities) if (options?.listChanged) { this._pendingListChangedConfig = options.listChanged; @@ -283,93 +284,86 @@ export class Client extends Protocol { /** * Enforces client-side validation for `elicitation/create` and `sampling/createMessage` - * regardless of how the handler was registered. + * regardless of how the handler was registered. Installed as a {@linkcode Dispatcher} + * middleware so it applies to both the legacy `_onrequest` path and the 2026-06 + * dispatch path. */ - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise - ): (request: JSONRPCRequest, ctx: ClientContext) => Promise { - if (method === 'elicitation/create') { - return async (request, ctx) => { - const validatedRequest = parseSchema(ElicitRequestSchema, request); - if (!validatedRequest.success) { - // Type guard: if success is false, error is guaranteed to exist - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); - } + private readonly _validationMiddleware: Middleware = async (request, _ctx, next) => { + if (request.method === 'elicitation/create') { + const validatedRequest = parseSchema(ElicitRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); + } - const { params } = validatedRequest.data; - params.mode = params.mode ?? 'form'; - const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); + const { params } = validatedRequest.data; + params.mode = params.mode ?? 'form'; + const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); - if (params.mode === 'form' && !supportsFormMode) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); - } + if (params.mode === 'form' && !supportsFormMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); + } - if (params.mode === 'url' && !supportsUrlMode) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); - } + if (params.mode === 'url' && !supportsUrlMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); + } - const result = await handler(request, ctx); + const result = await next(); - const validationResult = parseSchema(ElicitResultSchema, result); - if (!validationResult.success) { - // Type guard: if success is false, error is guaranteed to exist - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); - } + const validationResult = parseSchema(ElicitResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); + } - const validatedResult = validationResult.data; - const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined; - - if ( - params.mode === 'form' && - validatedResult.action === 'accept' && - validatedResult.content && - requestedSchema && - this._capabilities.elicitation?.form?.applyDefaults - ) { - try { - applyElicitationDefaults(requestedSchema, validatedResult.content); - } catch { - // gracefully ignore errors in default application - } + const validatedResult = validationResult.data; + const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined; + + if ( + params.mode === 'form' && + validatedResult.action === 'accept' && + validatedResult.content && + requestedSchema && + this._capabilities.elicitation?.form?.applyDefaults + ) { + try { + applyElicitationDefaults(requestedSchema, validatedResult.content); + } catch { + // gracefully ignore errors in default application } + } - return validatedResult; - }; + return validatedResult; } - if (method === 'sampling/createMessage') { - return async (request, ctx) => { - const validatedRequest = parseSchema(CreateMessageRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); - } + if (request.method === 'sampling/createMessage') { + const validatedRequest = parseSchema(CreateMessageRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + } - const { params } = validatedRequest.data; + const { params } = validatedRequest.data; - const result = await handler(request, ctx); + const result = await next(); - const hasTools = params.tools || params.toolChoice; - const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; - const validationResult = parseSchema(resultSchema, result); - if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); - } + const hasTools = params.tools || params.toolChoice; + const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; + const validationResult = parseSchema(resultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + } - return validationResult.data; - }; + return validationResult.data; } - return handler; - } + return next(); + }; protected assertCapability(capability: keyof ServerCapabilities, method: string): void { if (!this._serverCapabilities?.[capability]) { diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 28c36538e0..ef30915ba6 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -38,6 +38,9 @@ export { checkResourceAllowed, resourceUrlFromServerUrl } from '../../shared/aut // Metadata utilities export { getDisplayName } from '../../shared/metadataUtils.js'; +// Dispatcher types (handler registry; consumed by Protocol) +export type { RequestHandlerSchemas } from '../../shared/dispatcher.js'; + // Protocol types (NOT the Protocol class itself or mergeCapabilities) export type { BaseContext, @@ -45,7 +48,6 @@ export type { NotificationOptions, ProgressCallback, ProtocolOptions, - RequestHandlerSchemas, RequestOptions, ServerContext } from '../../shared/protocol.js'; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 0c34b64915..ba787aefa1 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,6 +2,7 @@ export * from './auth/errors.js'; export * from './errors/sdkErrors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/dispatcher.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/stdio.js'; diff --git a/packages/core/src/shared/dispatcher.ts b/packages/core/src/shared/dispatcher.ts new file mode 100644 index 0000000000..72a8a8f76b --- /dev/null +++ b/packages/core/src/shared/dispatcher.ts @@ -0,0 +1,202 @@ +import { ProtocolErrorCode } from '../types/enums.js'; +import { ProtocolError } from '../types/errors.js'; +import type { + JSONRPCErrorResponse, + JSONRPCRequest, + JSONRPCResponse, + RequestMethod, + RequestTypeMap, + Result, + ResultTypeMap +} from '../types/index.js'; +import { getRequestSchema, JSONRPC_VERSION } from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { validateStandardSchema } from '../util/standardSchema.js'; + +/** + * A request handler stored in {@linkcode Dispatcher}. Receives the raw JSON-RPC + * request and a caller-supplied context, returns a `Result` (the success + * payload). Throw {@linkcode ProtocolError} to surface a structured error. + */ +export type Handler = (request: JSONRPCRequest, ctx: C) => Promise; + +/** + * Onion-style middleware around handler invocation. Receives `next` to call + * the remaining chain (and ultimately the handler). May short-circuit by + * returning a `Result` without calling `next`, or transform the result/error. + * + * Installed via {@linkcode Dispatcher.use}; runs for every request that + * routes through {@linkcode Dispatcher.dispatch}. + */ +export type Middleware = (request: JSONRPCRequest, ctx: C, next: () => Promise) => Promise; + +/** + * Schema bundle accepted by `setRequestHandler`'s 3-arg form. + * + * `params` is required and validates the inbound `request.params`. `result` is optional; + * when supplied it types the handler's return value (no runtime validation is performed + * on the result). + */ +export interface RequestHandlerSchemas< + P extends StandardSchemaV1 = StandardSchemaV1, + R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined +> { + params: P; + result?: R; +} + +/** Infers the handler's return type from a `RequestHandlerSchemas.result` schema (or `Result` when absent). */ +export type InferHandlerResult = R extends StandardSchemaV1 + ? StandardSchemaV1.InferOutput + : Result; + +/** + * Method-keyed request handler registry plus invocation. Both the legacy + * connect/_onrequest path and the 2026-06 stateless dispatch path route + * through {@linkcode dispatch}. + * + * `dispatch()` looks up the handler, runs the middleware chain, wraps the + * result/error into a JSON-RPC response. It writes no instance state and is + * safe to call concurrently. + */ +export class Dispatcher { + private readonly _handlers = new Map>(); + private readonly _middleware: Middleware[] = []; + + /** Called when no specific handler matches. Not wrapped by middleware. */ + fallbackHandler?: Handler; + + /** + * Appends a middleware. Middlewares run in registration order, with the + * registered handler as the innermost call. + */ + use(middleware: Middleware): void { + this._middleware.push(middleware); + } + + /** + * Registers a handler to invoke when this dispatcher receives a request with the given method. + * + * Note that this will replace any previous request handler for the same method. + * + * For spec methods, pass `(method, handler)`; the request is parsed with the spec + * schema and the handler receives the typed `Request`. For custom (non-spec) + * methods, pass `(method, schemas, handler)`; `params` are validated against + * `schemas.params` and the handler receives the parsed params object directly. + * Supplying `schemas.result` types the handler's return value. + */ + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ContextT) => InferHandlerResult | Promise> + ): void; + setRequestHandler( + method: string, + schemasOrHandler: RequestHandlerSchemas | ((request: unknown, ctx: ContextT) => Result | Promise), + maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise + ): void { + let stored: Handler; + + if (typeof schemasOrHandler === 'function') { + const schema = getRequestSchema(method); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` + ); + } + stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); + } else if (maybeHandler) { + stored = async (request, ctx) => { + const userParams = { ...request.params }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); + } + return maybeHandler(parsed.data, ctx); + }; + } else { + throw new TypeError('setRequestHandler: handler is required'); + } + + this._handlers.set(method, stored); + } + + /** + * Removes the request handler for the given method. + */ + removeRequestHandler(method: RequestMethod | string): void { + this._handlers.delete(method); + } + + /** + * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. + */ + assertCanSetRequestHandler(method: RequestMethod | string): void { + if (this._handlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + + /** + * Returns true if {@linkcode dispatch} would route this method to a handler + * (registered or fallback) rather than returning MethodNotFound. + */ + canHandle(method: string): boolean { + return this._handlers.has(method) || this.fallbackHandler !== undefined; + } + + /** + * Dispatches one JSON-RPC request through the middleware chain to its + * handler and wraps the outcome as a JSON-RPC response. + * + * Thrown errors are surfaced with their `code` (if a safe integer), + * `message`, and `data` properties. This matches the behavior of + * `Protocol._onrequest` prior to extraction. + */ + async dispatch(request: JSONRPCRequest, ctx: ContextT): Promise { + const id = request.id; + const handler = this._handlers.get(request.method); + let chain: () => Promise; + if (handler !== undefined) { + chain = () => handler(request, ctx); + for (let i = this._middleware.length - 1; i >= 0; i--) { + // Loop bounds guarantee a defined element (noUncheckedIndexedAccess). + const mw = this._middleware[i] as Middleware; + const next = chain; + chain = () => mw(request, ctx, next); + } + } else if (this.fallbackHandler === undefined) { + return errorResponse(id, ProtocolErrorCode.MethodNotFound, 'Method not found'); + } else { + // Preserve pre-extraction behavior: fallback bypasses middleware. + const fb = this.fallbackHandler; + chain = () => fb(request, ctx); + } + try { + return okResponse(id, await chain()); + } catch (error) { + const e = error as { code?: unknown; message?: string; data?: unknown }; + return errorResponse( + id, + Number.isSafeInteger(e.code) ? (e.code as number) : ProtocolErrorCode.InternalError, + e.message ?? 'Internal error', + e.data + ); + } + } +} + +/** Builds a JSON-RPC success response. */ +export function okResponse(id: JSONRPCRequest['id'], result: Result): JSONRPCResponse { + return { jsonrpc: JSONRPC_VERSION, id, result }; +} + +/** Builds a JSON-RPC error response. */ +export function errorResponse(id: JSONRPCRequest['id'], code: number, message: string, data?: unknown): JSONRPCErrorResponse { + return { jsonrpc: JSONRPC_VERSION, id, error: { code, message, ...(data === undefined ? {} : { data }) } }; +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index ed78cc68d0..77a2222a14 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -32,7 +32,6 @@ import type { } from '../types/index.js'; import { getNotificationSchema, - getRequestSchema, getResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, @@ -44,6 +43,8 @@ import { } from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; import { isStandardSchema, validateStandardSchema } from '../util/standardSchema.js'; +import type { Handler, InferHandlerResult, RequestHandlerSchemas } from './dispatcher.js'; +import { Dispatcher, errorResponse } from './dispatcher.js'; import type { Transport, TransportSendOptions } from './transport.js'; /** @@ -57,7 +58,7 @@ export type ProgressCallback = (progress: Progress) => void; export type ProtocolOptions = { /** * Protocol versions supported. First version is preferred (sent by client, - * used as fallback by server). Passed to transport during {@linkcode Protocol.connect | connect()}. + * used as fallback by server). Passed to transport during `connect()`. * * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} */ @@ -275,7 +276,8 @@ type TimeoutInfo = { export abstract class Protocol { private _transport?: Transport; private _requestMessageId = 0; - private _requestHandlers: Map Promise> = new Map(); + /** The handler registry. Both `_onrequest` and `_dispatchStateless` route through it. */ + protected readonly dispatcher = new Dispatcher(); private _requestHandlerAbortControllers: Map = new Map(); private _notificationHandlers: Map Promise> = new Map(); private _responseHandlers: Map void> = new Map(); @@ -302,7 +304,12 @@ export abstract class Protocol { /** * A handler to invoke for any request types that do not have their own handler installed. */ - fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; + get fallbackRequestHandler(): Handler | undefined { + return this.dispatcher.fallbackHandler; + } + set fallbackRequestHandler(handler: Handler | undefined) { + this.dispatcher.fallbackHandler = handler; + } /** * A handler to invoke for any notification types that do not have their own handler installed. @@ -320,7 +327,10 @@ export abstract class Protocol { this._onprogress(notification); }); - this.setRequestHandler( + // Register directly on the dispatcher (not via setRequestHandler) so + // the abstract assertRequestHandlerCapability is not called before + // subclass fields are initialised. + this.dispatcher.setRequestHandler( 'ping', // Automatic pong by default. _request => ({}) as Result @@ -462,7 +472,7 @@ export abstract class Protocol { this.onerror?.(error); } - private _onnotification(notification: JSONRPCNotification): void { + protected _onnotification(notification: JSONRPCNotification): void { const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; // Ignore notifications not being subscribed to. @@ -477,29 +487,21 @@ export abstract class Protocol { } private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; - // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; + if (!this.dispatcher.canHandle(request.method)) { + capturedTransport + ?.send(errorResponse(request.id, ProtocolErrorCode.MethodNotFound, 'Method not found')) + .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))); + return; + } + const sendNotification = (notification: Notification, options?: NotificationOptions) => this.notification(notification, { ...options, relatedRequestId: request.id }); const sendRequest = (r: Request, resultSchema: U, options?: RequestOptions) => this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }); - if (handler === undefined) { - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: ProtocolErrorCode.MethodNotFound, - message: 'Method not found' - } - }; - capturedTransport?.send(errorResponse).catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - return; - } - const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); @@ -532,41 +534,14 @@ export abstract class Protocol { }; const ctx = this.buildContext(baseCtx, extra); - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(request, ctx)) - .then( - async result => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const response: JSONRPCResponse = { - result, - jsonrpc: '2.0', - id: request.id - }; - await capturedTransport?.send(response); - }, - async error => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: Number.isSafeInteger(error['code']) ? error['code'] : ProtocolErrorCode.InternalError, - message: error.message ?? 'Internal error', - ...(error['data'] !== undefined && { data: error['data'] }) - } - }; - await capturedTransport?.send(errorResponse); + this.dispatcher + .dispatch(request, ctx) + .then(async response => { + if (abortController.signal.aborted) { + return; } - ) + await capturedTransport?.send(response); + }) .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) .finally(() => { if (this._requestHandlerAbortControllers.get(request.id) === abortController) { @@ -902,63 +877,23 @@ export abstract class Protocol { maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise ): void { this.assertRequestHandlerCapability(method); - - let stored: (request: JSONRPCRequest, ctx: ContextT) => Promise; - - if (typeof schemasOrHandler === 'function') { - const schema = getRequestSchema(method); - if (!schema) { - throw new TypeError( - `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` - ); - } - stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); - } else if (maybeHandler) { - stored = async (request, ctx) => { - const userParams = { ...request.params }; - delete userParams._meta; - const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); - if (!parsed.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); - } - return maybeHandler(parsed.data, ctx); - }; - } else { - throw new TypeError('setRequestHandler: handler is required'); - } - - this._requestHandlers.set(method, this._wrapHandler(method, stored)); - } - - /** - * Hook for subclasses to wrap a registered request handler with role-specific - * validation or behavior (e.g. `Server` validates `tools/call` results, `Client` - * validates `elicitation/create` mode and result). Runs for both the 2-arg and - * 3-arg registration paths. The default implementation is identity. - * - * Subclasses overriding this hook avoid redeclaring `setRequestHandler`'s overload set. - */ - protected _wrapHandler( - _method: string, - handler: (request: JSONRPCRequest, ctx: ContextT) => Promise - ): (request: JSONRPCRequest, ctx: ContextT) => Promise { - return handler; + // Unsound only at the impl signature; the public overloads are sound, and + // Dispatcher.setRequestHandler has the identical impl signature. + this.dispatcher.setRequestHandler(method, schemasOrHandler as never, maybeHandler as never); } /** * Removes the request handler for the given method. */ removeRequestHandler(method: RequestMethod | string): void { - this._requestHandlers.delete(method); + this.dispatcher.removeRequestHandler(method); } /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ assertCanSetRequestHandler(method: RequestMethod | string): void { - if (this._requestHandlers.has(method)) { - throw new Error(`A request handler for ${method} already exists, which would be overridden`); - } + this.dispatcher.assertCanSetRequestHandler(method); } /** @@ -1019,23 +954,6 @@ export abstract class Protocol { } } -/** - * Schema bundle accepted by {@linkcode Protocol.setRequestHandler | setRequestHandler}'s 3-arg form. - * - * `params` is required and validates the inbound `request.params`. `result` is optional; - * when supplied it types the handler's return value (no runtime validation is performed - * on the result). - */ -export interface RequestHandlerSchemas< - P extends StandardSchemaV1 = StandardSchemaV1, - R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined -> { - params: P; - result?: R; -} - -type InferHandlerResult = R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result; - function isPlainObject(value: unknown): value is Record { return value !== null && typeof value === 'object' && !Array.isArray(value); } diff --git a/packages/core/test/shared/customMethods.test.ts b/packages/core/test/shared/customMethods.test.ts index ffee5b9a7d..ca98667197 100644 --- a/packages/core/test/shared/customMethods.test.ts +++ b/packages/core/test/shared/customMethods.test.ts @@ -79,20 +79,20 @@ describe('Protocol custom-method support', () => { expect(() => p.setRequestHandler('acme/unknown' as never, () => ({}) as never)).toThrow(TypeError); }); - it('routes both 2-arg and 3-arg registration through _wrapHandler', () => { + it('runs Dispatcher middleware for both 2-arg and 3-arg registration', async () => { + const [a, b] = await pair(); const seen: string[] = []; - class SpyProtocol extends TestProtocol { - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise - ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { - seen.push(method); - return handler; + // dispatcher is protected; reach it via prototype access for test purposes. + (b as unknown as { dispatcher: { use: (mw: unknown) => void } }).dispatcher.use( + async (request: JSONRPCRequest, _ctx: BaseContext, next: () => Promise) => { + seen.push(request.method); + return next(); } - } - const p = new SpyProtocol(); - p.setRequestHandler('tools/list', () => ({ tools: [] })); - p.setRequestHandler('acme/custom', { params: z.object({}) }, () => ({})); + ); + b.setRequestHandler('tools/list', () => ({ tools: [] })); + b.setRequestHandler('acme/custom', { params: z.object({}) }, () => ({})); + await a.request({ method: 'tools/list' }); + await a.request({ method: 'acme/custom' }, z.unknown()); expect(seen).toContain('tools/list'); expect(seen).toContain('acme/custom'); }); diff --git a/packages/core/test/shared/dispatcher.test.ts b/packages/core/test/shared/dispatcher.test.ts new file mode 100644 index 0000000000..73bfacdf15 --- /dev/null +++ b/packages/core/test/shared/dispatcher.test.ts @@ -0,0 +1,163 @@ +import { describe, expect, it } from 'vitest'; + +import { Dispatcher, errorResponse, type Middleware, okResponse } from '../../src/shared/dispatcher.js'; +import { ProtocolErrorCode } from '../../src/types/enums.js'; +import { ProtocolError } from '../../src/types/errors.js'; +import type { JSONRPCRequest, Result } from '../../src/types/index.js'; + +type Ctx = { tag: string }; + +function req(method: string, id = 1): JSONRPCRequest { + return { jsonrpc: '2.0', id, method }; +} + +// Test helper: register a raw handler directly (bypasses schema-wrap). +// Dispatcher's public registration is setRequestHandler (schema-wrapped); these +// unit tests target dispatch/middleware mechanics, not the schema layer. +function setRaw(d: Dispatcher, method: string, handler: (r: JSONRPCRequest, ctx: C) => Promise): void { + (d as unknown as { _handlers: Map })._handlers.set(method, handler); +} + +describe('Dispatcher', () => { + it('dispatches to a registered handler and wraps the result', async () => { + const d = new Dispatcher(); + setRaw(d, 'foo', async (r, ctx) => ({ value: `${ctx.tag}:${r.method}` })); + const res = await d.dispatch(req('foo'), { tag: 't' }); + expect(res).toEqual(okResponse(1, { value: 't:foo' })); + }); + + it('returns MethodNotFound when no handler matches', async () => { + const d = new Dispatcher(); + const res = await d.dispatch(req('nope'), { tag: 't' }); + expect(res).toEqual(errorResponse(1, ProtocolErrorCode.MethodNotFound, 'Method not found')); + }); + + it('falls back to fallbackHandler when set', async () => { + const d = new Dispatcher(); + d.fallbackHandler = async r => ({ fallback: r.method }); + const res = await d.dispatch(req('nope'), { tag: 't' }); + expect(res).toEqual(okResponse(1, { fallback: 'nope' })); + }); + + it('assertCanSetRequestHandler reflects registration only (not fallback)', () => { + const d = new Dispatcher(); + d.fallbackHandler = async () => ({}); + expect(() => d.assertCanSetRequestHandler('foo')).not.toThrow(); + setRaw(d, 'foo', async () => ({})); + expect(() => d.assertCanSetRequestHandler('foo')).toThrow(); + d.removeRequestHandler('foo'); + expect(() => d.assertCanSetRequestHandler('foo')).not.toThrow(); + }); + + it('fallbackHandler bypasses middleware (preserves Protocol._onrequest behavior)', async () => { + const d = new Dispatcher(); + let mwRan = false; + d.use(async (_r, _c, next) => { + mwRan = true; + return next(); + }); + d.fallbackHandler = async r => ({ fallback: r.method }); + const res = await d.dispatch(req('nope'), { tag: 't' }); + expect(res).toEqual(okResponse(1, { fallback: 'nope' })); + expect(mwRan).toBe(false); + }); + + it('surfaces ProtocolError code/message/data', async () => { + const d = new Dispatcher(); + setRaw(d, 'foo', async () => { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'bad', { detail: 1 }); + }); + const res = await d.dispatch(req('foo'), { tag: 't' }); + expect(res).toEqual(errorResponse(1, ProtocolErrorCode.InvalidParams, 'bad', { detail: 1 })); + }); + + it('preserves thrown error message and numeric code (matches Protocol._onrequest behavior)', async () => { + const d = new Dispatcher(); + setRaw(d, 'foo', async () => { + throw new Error('handler error message'); + }); + const res = await d.dispatch(req('foo'), { tag: 't' }); + expect(res).toEqual(errorResponse(1, ProtocolErrorCode.InternalError, 'handler error message')); + + setRaw(d, 'coded', async () => { + throw Object.assign(new Error('coded message'), { code: -31999, data: { x: 1 } }); + }); + const res2 = await d.dispatch(req('coded'), { tag: 't' }); + expect(res2).toEqual(errorResponse(1, -31999, 'coded message', { x: 1 })); + }); + + it('runs middleware in registration order around the handler', async () => { + const d = new Dispatcher(); + const order: string[] = []; + const mk = + (name: string): Middleware => + async (_r, _c, next) => { + order.push(`${name}:pre`); + const result = await next(); + order.push(`${name}:post`); + return result; + }; + d.use(mk('a')); + d.use(mk('b')); + setRaw(d, 'foo', async () => { + order.push('handler'); + return {}; + }); + await d.dispatch(req('foo'), { tag: 't' }); + expect(order).toEqual(['a:pre', 'b:pre', 'handler', 'b:post', 'a:post']); + }); + + it('lets middleware short-circuit without calling next', async () => { + const d = new Dispatcher(); + let handlerRan = false; + d.use(async () => ({ short: true })); + setRaw(d, 'foo', async () => { + handlerRan = true; + return {}; + }); + const res = await d.dispatch(req('foo'), { tag: 't' }); + expect(res).toEqual(okResponse(1, { short: true })); + expect(handlerRan).toBe(false); + }); + + it('lets middleware transform a thrown error into a result', async () => { + const d = new Dispatcher(); + d.use(async (_r, _c, next) => { + try { + return await next(); + } catch { + return { recovered: true }; + } + }); + setRaw(d, 'foo', async () => { + throw new Error('boom'); + }); + const res = await d.dispatch(req('foo'), { tag: 't' }); + expect(res).toEqual(okResponse(1, { recovered: true })); + }); + + it('does not run middleware when no handler matches', async () => { + const d = new Dispatcher(); + let ran = false; + d.use(async (_r, _c, next) => { + ran = true; + return next(); + }); + await d.dispatch(req('nope'), { tag: 't' }); + expect(ran).toBe(false); + }); + + it('supports concurrent dispatch on a shared instance', async () => { + const d = new Dispatcher(); + setRaw(d, 'foo', async (_r, ctx) => { + await new Promise(r => setTimeout(r, 5)); + return { tag: ctx.tag }; + }); + const results = await Promise.all([ + d.dispatch(req('foo', 1), { tag: 'a' }), + d.dispatch(req('foo', 2), { tag: 'b' }), + d.dispatch(req('foo', 3), { tag: 'c' }) + ]); + expect(results.map(r => 'result' in r && (r.result as Result & { tag: string }).tag)).toEqual(['a', 'b', 'c']); + }); +}); diff --git a/packages/core/test/shared/wrapHandler.test.ts b/packages/core/test/shared/wrapHandler.test.ts deleted file mode 100644 index 452b58194f..0000000000 --- a/packages/core/test/shared/wrapHandler.test.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { describe, expect, it } from 'vitest'; - -import { Protocol } from '../../src/shared/protocol.js'; -import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js'; - -class TestProtocol extends Protocol { - protected buildContext(ctx: BaseContext): BaseContext { - return ctx; - } - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} -} - -describe('Protocol._wrapHandler', () => { - it('routes setRequestHandler registration through _wrapHandler', () => { - const seen: string[] = []; - class SpyProtocol extends TestProtocol { - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise - ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { - seen.push(method); - return handler; - } - } - const p = new SpyProtocol(); - seen.length = 0; - p.setRequestHandler('tools/list', () => ({ tools: [] })); - p.setRequestHandler('resources/list', () => ({ resources: [] })); - expect(seen).toEqual(['tools/list', 'resources/list']); - }); -}); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 89e1de1817..55d8a44b17 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -12,20 +12,19 @@ import type { Implementation, InitializeRequest, InitializeResult, - JSONRPCRequest, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, LoggingLevel, LoggingMessageNotification, MessageExtraInfo, + Middleware, NotificationMethod, NotificationOptions, ProtocolOptions, RequestMethod, RequestOptions, ResourceUpdatedNotification, - Result, ServerCapabilities, ServerContext, ToolResultContent, @@ -104,6 +103,8 @@ export class Server extends Protocol { this._instructions = options?.instructions; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); + this.dispatcher.use(Server._callToolResultMiddleware); + this.setRequestHandler('initialize', request => this._oninitialize(request)); this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); @@ -177,35 +178,26 @@ export class Server extends Protocol { /** * Enforces server-side validation for `tools/call` results regardless of how the - * handler was registered. + * handler was registered. Installed as a {@linkcode Dispatcher} middleware so + * it applies to both the legacy `_onrequest` path and the 2026-06 dispatch path. */ - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise - ): (request: JSONRPCRequest, ctx: ServerContext) => Promise { - if (method !== 'tools/call') { - return handler; + private static readonly _callToolResultMiddleware: Middleware = async (request, _ctx, next) => { + if (request.method !== 'tools/call') { + return next(); } - return async (request, ctx) => { - const validatedRequest = parseSchema(CallToolRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); - } - - const result = await handler(request, ctx); - - const validationResult = parseSchema(CallToolResultSchema, result); - if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); - } - - return validationResult.data; - }; - } + const validatedRequest = parseSchema(CallToolRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + } + const result = await next(); + const validationResult = parseSchema(CallToolResultSchema, result); + if (!validationResult.success) { + const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + } + return validationResult.data; + }; protected assertCapabilityForMethod(method: RequestMethod | string): void { switch (method) {