Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/wraphandler-hook.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
'@modelcontextprotocol/core': patch
'@modelcontextprotocol/client': patch
'@modelcontextprotocol/server': patch
---

refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`.
78 changes: 35 additions & 43 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import type {
ClientContext,
ClientNotification,
ClientRequest,
ClientResult,
CompleteRequest,
GetPromptRequest,
Implementation,
JSONRPCRequest,
JsonSchemaType,
JsonSchemaValidator,
jsonSchemaValidator,
Expand All @@ -26,8 +26,7 @@ import type {
ReadResourceRequest,
RequestMethod,
RequestOptions,
RequestTypeMap,
ResultTypeMap,
Result,
ServerCapabilities,
SubscribeRequest,
TaskManagerOptions,
Expand Down Expand Up @@ -200,6 +199,28 @@ export type ClientOptions = ProtocolOptions & {
*
* The client will automatically begin the initialization flow with the server when {@linkcode connect} is called.
*
* To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}.
* The client must declare the corresponding capability for the handler to be accepted. For
* `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with
* schema validation for both the incoming request and the returned result.
*
* @example Handling a sampling request
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
* client.setRequestHandler('sampling/createMessage', async request => {
* const lastMessage = request.params.messages.at(-1);
* console.log('Sampling request:', lastMessage);
*
* // In production, send messages to your LLM here
* return {
* model: 'my-model',
* role: 'assistant' as const,
* content: {
* type: 'text' as const,
* text: 'Response from the model'
* }
* };
* });
* ```
*/
export class Client extends Protocol<ClientContext> {
private _serverCapabilities?: ServerCapabilities;
Expand Down Expand Up @@ -308,37 +329,15 @@ export class Client extends Protocol<ClientContext> {
}

/**
* Registers a handler for server-initiated requests (sampling, elicitation, roots).
* The client must declare the corresponding capability for the handler to be accepted.
* Replaces any previously registered handler for the same method.
*
* For `sampling/createMessage` and `elicitation/create`, the handler is automatically
* wrapped with schema validation for both the incoming request and the returned result.
*
* @example Handling a sampling request
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
* client.setRequestHandler('sampling/createMessage', async request => {
* const lastMessage = request.params.messages.at(-1);
* console.log('Sampling request:', lastMessage);
*
* // In production, send messages to your LLM here
* return {
* model: 'my-model',
* role: 'assistant' as const,
* content: {
* type: 'text' as const,
* text: 'Response from the model'
* }
* };
* });
* ```
* Enforces client-side validation for `elicitation/create` and `sampling/createMessage`
* regardless of how the handler was registered.
*/
public override setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result> {
if (method === 'elicitation/create') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
return async (request, ctx) => {
const validatedRequest = parseSchema(ElicitRequestSchema, request);
if (!validatedRequest.success) {
// Type guard: if success is false, error is guaranteed to exist
Expand All @@ -359,7 +358,7 @@ export class Client extends Protocol<ClientContext> {
throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests');
}

const result = await Promise.resolve(handler(request, ctx));
const result = await handler(request, ctx);

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
Expand Down Expand Up @@ -402,13 +401,10 @@ export class Client extends Protocol<ClientContext> {

return validatedResult;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}

if (method === 'sampling/createMessage') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
return async (request, ctx) => {
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
Expand All @@ -418,7 +414,7 @@ export class Client extends Protocol<ClientContext> {

const { params } = validatedRequest.data;

const result = await Promise.resolve(handler(request, ctx));
const result = await handler(request, ctx);

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
Expand All @@ -445,13 +441,9 @@ export class Client extends Protocol<ClientContext> {

return validationResult.data;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}

// Other handlers use default behavior
return super.setRequestHandler(method, handler);
return handler;
}

protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
Expand Down
21 changes: 18 additions & 3 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1007,15 +1007,30 @@ export abstract class Protocol<ContextT extends BaseContext> {
*/
setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise<Result>
handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
this.assertRequestHandlerCapability(method);
const schema = getRequestSchema(method);

this._requestHandlers.set(method, (request, ctx) => {
const stored = (request: JSONRPCRequest, ctx: ContextT): Promise<Result> => {
const parsed = schema.parse(request) as RequestTypeMap[M];
return Promise.resolve(handler(parsed, ctx));
});
};
this._requestHandlers.set(method, this._wrapHandler(method, stored));
Comment thread
claude[bot] marked this conversation as resolved.
}

/**
* Hook for subclasses to wrap a registered request handler with role-specific
* validation or behavior (e.g. `Server` validates `tools/call` results, `Client`
* validates `elicitation/create` mode and result). The default implementation is identity.
*
* Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc.
*/
protected _wrapHandler(
_method: string,
handler: (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>
): (request: JSONRPCRequest, ctx: ContextT) => Promise<Result> {
return handler;
}

/**
Expand Down
35 changes: 35 additions & 0 deletions packages/core/test/shared/wrapHandler.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { describe, expect, it } from 'vitest';

import { Protocol } from '../../src/shared/protocol.js';
import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js';

class TestProtocol extends Protocol<BaseContext> {
protected buildContext(ctx: BaseContext): BaseContext {
return ctx;
}
protected assertCapabilityForMethod(): void {}
protected assertNotificationCapability(): void {}
protected assertRequestHandlerCapability(): void {}
protected assertTaskCapability(): void {}
protected assertTaskHandlerCapability(): void {}
}

describe('Protocol._wrapHandler', () => {
it('routes setRequestHandler registration through _wrapHandler', () => {
const seen: string[] = [];
class SpyProtocol extends TestProtocol {
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result> {
seen.push(method);
return handler;
}
}
const p = new SpyProtocol();
seen.length = 0;
p.setRequestHandler('tools/list', () => ({ tools: [] }));
p.setRequestHandler('resources/list', () => ({ resources: [] }));
expect(seen).toEqual(['tools/list', 'resources/list']);
});
});
83 changes: 39 additions & 44 deletions packages/server/src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type {
Implementation,
InitializeRequest,
InitializeResult,
JSONRPCRequest,
JsonSchemaType,
jsonSchemaValidator,
ListRootsRequest,
Expand All @@ -23,12 +24,10 @@ import type {
ProtocolOptions,
RequestMethod,
RequestOptions,
RequestTypeMap,
ResourceUpdatedNotification,
ResultTypeMap,
Result,
ServerCapabilities,
ServerContext,
ServerResult,
TaskManagerOptions,
ToolResultContent,
ToolUseContent
Expand Down Expand Up @@ -220,55 +219,51 @@ export class Server extends Protocol<ServerContext> {
}

/**
* Override request handler registration to enforce server-side validation for `tools/call`.
* Enforces server-side validation for `tools/call` results regardless of how the
* handler was registered.
*/
public override setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
if (method === 'tools/call') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise<ServerResult> => {
const validatedRequest = parseSchema(CallToolRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
}

const { params } = validatedRequest.data;
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> {
if (method !== 'tools/call') {
return handler;
}
return async (request, ctx) => {
const validatedRequest = parseSchema(CallToolRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
}

const result = await Promise.resolve(handler(request, ctx));
const { params } = validatedRequest.data;

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
if (!taskValidationResult.success) {
const errorMessage =
taskValidationResult.error instanceof Error
? taskValidationResult.error.message
: String(taskValidationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
}
return taskValidationResult.data;
}
const result = await handler(request, ctx);

// For non-task requests, validate against CallToolResultSchema
const validationResult = parseSchema(CallToolResultSchema, result);
if (!validationResult.success) {
// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
if (!taskValidationResult.success) {
const errorMessage =
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
taskValidationResult.error instanceof Error
? taskValidationResult.error.message
: String(taskValidationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
}
return taskValidationResult.data;
}

return validationResult.data;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}
// For non-task requests, validate against CallToolResultSchema
const validationResult = parseSchema(CallToolResultSchema, result);
if (!validationResult.success) {
const errorMessage =
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
}

// Other handlers use default behavior
return super.setRequestHandler(method, handler);
return validationResult.data;
};
}

protected assertCapabilityForMethod(method: RequestMethod): void {
Expand Down
Loading