diff --git a/.changeset/wraphandler-hook.md b/.changeset/wraphandler-hook.md new file mode 100644 index 000000000..935f57658 --- /dev/null +++ b/.changeset/wraphandler-hook.md @@ -0,0 +1,7 @@ +--- +'@modelcontextprotocol/core': patch +'@modelcontextprotocol/client': patch +'@modelcontextprotocol/server': patch +--- + +refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`. diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 21a43bd15..4a279e532 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -6,10 +6,10 @@ import type { ClientContext, ClientNotification, ClientRequest, - ClientResult, CompleteRequest, GetPromptRequest, Implementation, + JSONRPCRequest, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, @@ -26,8 +26,7 @@ import type { ReadResourceRequest, RequestMethod, RequestOptions, - RequestTypeMap, - ResultTypeMap, + Result, ServerCapabilities, SubscribeRequest, TaskManagerOptions, @@ -200,6 +199,28 @@ export type ClientOptions = ProtocolOptions & { * * The client will automatically begin the initialization flow with the server when {@linkcode connect} is called. * + * To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}. + * The client must declare the corresponding capability for the handler to be accepted. For + * `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with + * schema validation for both the incoming request and the returned result. + * + * @example Handling a sampling request + * ```ts source="./client.examples.ts#Client_setRequestHandler_sampling" + * client.setRequestHandler('sampling/createMessage', async request => { + * const lastMessage = request.params.messages.at(-1); + * console.log('Sampling request:', lastMessage); + * + * // In production, send messages to your LLM here + * return { + * model: 'my-model', + * role: 'assistant' as const, + * content: { + * type: 'text' as const, + * text: 'Response from the model' + * } + * }; + * }); + * ``` */ export class Client extends Protocol { private _serverCapabilities?: ServerCapabilities; @@ -308,37 +329,15 @@ export class Client extends Protocol { } /** - * Registers a handler for server-initiated requests (sampling, elicitation, roots). - * The client must declare the corresponding capability for the handler to be accepted. - * Replaces any previously registered handler for the same method. - * - * For `sampling/createMessage` and `elicitation/create`, the handler is automatically - * wrapped with schema validation for both the incoming request and the returned result. - * - * @example Handling a sampling request - * ```ts source="./client.examples.ts#Client_setRequestHandler_sampling" - * client.setRequestHandler('sampling/createMessage', async request => { - * const lastMessage = request.params.messages.at(-1); - * console.log('Sampling request:', lastMessage); - * - * // In production, send messages to your LLM here - * return { - * model: 'my-model', - * role: 'assistant' as const, - * content: { - * type: 'text' as const, - * text: 'Response from the model' - * } - * }; - * }); - * ``` + * Enforces client-side validation for `elicitation/create` and `sampling/createMessage` + * regardless of how the handler was registered. */ - public override setRequestHandler( - method: M, - handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise - ): void { + protected override _wrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise + ): (request: JSONRPCRequest, ctx: ClientContext) => Promise { if (method === 'elicitation/create') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + return async (request, ctx) => { const validatedRequest = parseSchema(ElicitRequestSchema, request); if (!validatedRequest.success) { // Type guard: if success is false, error is guaranteed to exist @@ -359,7 +358,7 @@ export class Client extends Protocol { throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); } - const result = await Promise.resolve(handler(request, ctx)); + const result = await handler(request, ctx); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -402,13 +401,10 @@ export class Client extends Protocol { return validatedResult; }; - - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); } if (method === 'sampling/createMessage') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + return async (request, ctx) => { const validatedRequest = parseSchema(CreateMessageRequestSchema, request); if (!validatedRequest.success) { const errorMessage = @@ -418,7 +414,7 @@ export class Client extends Protocol { const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, ctx)); + const result = await handler(request, ctx); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -445,13 +441,9 @@ export class Client extends Protocol { return validationResult.data; }; - - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); } - // Other handlers use default behavior - return super.setRequestHandler(method, handler); + return handler; } protected assertCapability(capability: keyof ServerCapabilities, method: string): void { diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 57eab6932..799518832 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1007,15 +1007,30 @@ export abstract class Protocol { */ setRequestHandler( method: M, - handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise + handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise ): void { this.assertRequestHandlerCapability(method); const schema = getRequestSchema(method); - this._requestHandlers.set(method, (request, ctx) => { + const stored = (request: JSONRPCRequest, ctx: ContextT): Promise => { const parsed = schema.parse(request) as RequestTypeMap[M]; return Promise.resolve(handler(parsed, ctx)); - }); + }; + this._requestHandlers.set(method, this._wrapHandler(method, stored)); + } + + /** + * Hook for subclasses to wrap a registered request handler with role-specific + * validation or behavior (e.g. `Server` validates `tools/call` results, `Client` + * validates `elicitation/create` mode and result). The default implementation is identity. + * + * Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc. + */ + protected _wrapHandler( + _method: string, + handler: (request: JSONRPCRequest, ctx: ContextT) => Promise + ): (request: JSONRPCRequest, ctx: ContextT) => Promise { + return handler; } /** diff --git a/packages/core/test/shared/wrapHandler.test.ts b/packages/core/test/shared/wrapHandler.test.ts new file mode 100644 index 000000000..6a6e33fb0 --- /dev/null +++ b/packages/core/test/shared/wrapHandler.test.ts @@ -0,0 +1,35 @@ +import { describe, expect, it } from 'vitest'; + +import { Protocol } from '../../src/shared/protocol.js'; +import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js'; + +class TestProtocol extends Protocol { + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} +} + +describe('Protocol._wrapHandler', () => { + it('routes setRequestHandler registration through _wrapHandler', () => { + const seen: string[] = []; + class SpyProtocol extends TestProtocol { + protected override _wrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise + ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { + seen.push(method); + return handler; + } + } + const p = new SpyProtocol(); + seen.length = 0; + p.setRequestHandler('tools/list', () => ({ tools: [] })); + p.setRequestHandler('resources/list', () => ({ resources: [] })); + expect(seen).toEqual(['tools/list', 'resources/list']); + }); +}); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 4361f3e1e..8324c8dc1 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -12,6 +12,7 @@ import type { Implementation, InitializeRequest, InitializeResult, + JSONRPCRequest, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, @@ -23,12 +24,10 @@ import type { ProtocolOptions, RequestMethod, RequestOptions, - RequestTypeMap, ResourceUpdatedNotification, - ResultTypeMap, + Result, ServerCapabilities, ServerContext, - ServerResult, TaskManagerOptions, ToolResultContent, ToolUseContent @@ -220,55 +219,51 @@ export class Server extends Protocol { } /** - * Override request handler registration to enforce server-side validation for `tools/call`. + * Enforces server-side validation for `tools/call` results regardless of how the + * handler was registered. */ - public override setRequestHandler( - method: M, - handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise - ): void { - if (method === 'tools/call') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise => { - const validatedRequest = parseSchema(CallToolRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); - } - - const { params } = validatedRequest.data; + protected override _wrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise + ): (request: JSONRPCRequest, ctx: ServerContext) => Promise { + if (method !== 'tools/call') { + return handler; + } + return async (request, ctx) => { + const validatedRequest = parseSchema(CallToolRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); + } - const result = await Promise.resolve(handler(request, ctx)); + const { params } = validatedRequest.data; - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } + const result = await handler(request, ctx); - // For non-task requests, validate against CallToolResultSchema - const validationResult = parseSchema(CallToolResultSchema, result); - if (!validationResult.success) { + // When task creation is requested, validate and return CreateTaskResult + if (params.task) { + const taskValidationResult = parseSchema(CreateTaskResultSchema, result); + if (!taskValidationResult.success) { const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + taskValidationResult.error instanceof Error + ? taskValidationResult.error.message + : String(taskValidationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); } + return taskValidationResult.data; + } - return validationResult.data; - }; - - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); - } + // For non-task requests, validate against CallToolResultSchema + const validationResult = parseSchema(CallToolResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + } - // Other handlers use default behavior - return super.setRequestHandler(method, handler); + return validationResult.data; + }; } protected assertCapabilityForMethod(method: RequestMethod): void {