diff --git a/.changeset/fix-adapt-oauth-provider-expiry.md b/.changeset/fix-adapt-oauth-provider-expiry.md new file mode 100644 index 000000000..665f1cefa --- /dev/null +++ b/.changeset/fix-adapt-oauth-provider-expiry.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/client': patch +--- + +Fix `adaptOAuthProvider` returning expired tokens on long-running StreamableHTTP connections. The adapter now intercepts `saveTokens` to record when each token was issued, then checks elapsed time against `expires_in` (with a 60-second buffer) before returning the token. Expired or near-expiry tokens return `undefined`, causing the transport to omit the `Authorization` header and trigger a 401 → `onUnauthorized` → refresh flow. diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index 93a03ece6..696d8efe9 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -118,14 +118,39 @@ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx * transports consume. Called once at transport construction — the transport stores * the adapted provider for `_commonHeaders()` and 401 handling, while keeping the * original `OAuthClientProvider` for OAuth-specific paths (`finishAuth()`, 403 upscoping). + * + * Uses a Proxy to intercept `saveTokens` calls made by the auth flow so that + * `token()` can detect expiry even when the provider stores the raw `expires_in` + * from the server response (e.g. 3600) rather than recomputing remaining seconds. + * The original provider is never mutated; spy assertions in tests remain valid. */ export function adaptOAuthProvider(provider: OAuthClientProvider): AuthProvider { + let issuedAt: number | undefined; + + // Proxy intercepts saveTokens to record the issue timestamp without mutating the provider + const proxied: OAuthClientProvider = new Proxy(provider, { + get(target, prop, receiver) { + if (prop === 'saveTokens') { + return async (tokens: OAuthTokens) => { + issuedAt = Math.floor(Date.now() / 1000); + return target.saveTokens(tokens); + }; + } + return Reflect.get(target, prop, receiver); + } + }); + return { token: async () => { const tokens = await provider.tokens(); - return tokens?.access_token; + if (!tokens?.access_token) return; + if (tokens.expires_in !== undefined && issuedAt !== undefined) { + const elapsed = Math.floor(Date.now() / 1000) - issuedAt; + if (elapsed >= tokens.expires_in - 60) return; + } + return tokens.access_token; }, - onUnauthorized: async ctx => handleOAuthUnauthorized(provider, ctx) + onUnauthorized: async ctx => handleOAuthUnauthorized(proxied, ctx) }; } diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 53263ad8c..70e46fefe 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -5,6 +5,7 @@ import { expect, vi } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; import { + adaptOAuthProvider, auth, buildDiscoveryUrls, determineScope, @@ -4055,4 +4056,204 @@ describe('OAuth Authorization', () => { }); }); }); + + describe('adaptOAuthProvider', () => { + // Helper: mock fetch for a minimal OAuth token refresh flow. + // Handles discovery (404 for PRM, returns auth-server metadata) and + // the token endpoint (returns newTokens). + function mockRefreshFlow(newTokens: OAuthTokens): void { + mockFetch.mockImplementation((url: string | URL) => { + const urlStr = url.toString(); + if (urlStr.includes('oauth-protected-resource')) { + return Promise.resolve({ ok: false, status: 404, text: () => Promise.resolve('') }); + } + if (urlStr.includes('oauth-authorization-server') || urlStr.includes('openid-configuration')) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => + Promise.resolve({ + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }) + }); + } + if (urlStr.includes('/token')) { + return Promise.resolve({ + ok: true, + status: 200, + json: () => Promise.resolve(newTokens) + }); + } + return Promise.reject(new Error(`Unexpected fetch: ${urlStr}`)); + }); + } + + // Helper: trigger onUnauthorized so that auth() calls proxied.saveTokens, recording issuedAt. + async function triggerSaveTokensViaUnauthorized(adapted: ReturnType): Promise { + const mockResponse = { ok: false, status: 401, headers: new Headers() } as unknown as Response; + await adapted.onUnauthorized!({ + response: mockResponse, + serverUrl: new URL('https://example.com/'), + fetchFn: mockFetch + }); + } + + it('returns access_token when no expires_in is set', async () => { + const provider: OAuthClientProvider = { + get redirectUrl() { + return 'https://example.com/callback'; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: ['https://example.com/callback'] }; + }, + clientInformation: () => ({ client_id: 'test-client' }), + tokens: () => ({ access_token: 'tok', token_type: 'bearer' }), + saveTokens: vi.fn(), + redirectToAuthorization: () => {}, + saveCodeVerifier: () => {}, + codeVerifier: () => 'verifier' + }; + const adapted = adaptOAuthProvider(provider); + expect(await adapted.token()).toBe('tok'); + }); + + it('returns access_token when issuedAt is unknown and expires_in is present', async () => { + // Token loaded from storage before saveTokens was intercepted — no issuedAt, so no expiry check + const provider: OAuthClientProvider = { + get redirectUrl() { + return 'https://example.com/callback'; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: ['https://example.com/callback'] }; + }, + clientInformation: () => ({ client_id: 'test-client' }), + tokens: () => ({ access_token: 'tok', token_type: 'bearer', expires_in: 3600 }), + saveTokens: vi.fn(), + redirectToAuthorization: () => {}, + saveCodeVerifier: () => {}, + codeVerifier: () => 'verifier' + }; + const adapted = adaptOAuthProvider(provider); + expect(await adapted.token()).toBe('tok'); + }); + + it('returns undefined when no tokens', async () => { + const provider: OAuthClientProvider = { + get redirectUrl() { + return 'https://example.com/callback'; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: ['https://example.com/callback'] }; + }, + clientInformation: () => ({ client_id: 'test-client' }), + tokens: () => undefined, + saveTokens: vi.fn(), + redirectToAuthorization: () => {}, + saveCodeVerifier: () => {}, + codeVerifier: () => 'verifier' + }; + const adapted = adaptOAuthProvider(provider); + expect(await adapted.token()).toBeUndefined(); + }); + + // Helper: create a minimal client_credentials provider for expiry tests. + // prepareTokenRequest is required for the non-interactive flow. + function makeClientCredentialsProvider(): { provider: OAuthClientProvider; getCurrentTokens: () => OAuthTokens | undefined } { + let currentTokens: OAuthTokens | undefined; + const provider: OAuthClientProvider = { + get redirectUrl() { + return undefined; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: [], grant_types: ['client_credentials'] }; + }, + clientInformation: () => ({ client_id: 'test-client', client_secret: 'secret' }), + tokens: () => currentTokens, + saveTokens: vi.fn(tokens => { + currentTokens = tokens; + }), + redirectToAuthorization: () => {}, + saveCodeVerifier: () => {}, + codeVerifier: () => 'verifier', + prepareTokenRequest: () => new URLSearchParams({ grant_type: 'client_credentials' }) + }; + return { provider, getCurrentTokens: () => currentTokens }; + } + + it('returns access_token for a freshly saved token (issuedAt just set)', async () => { + const { provider } = makeClientCredentialsProvider(); + const adapted = adaptOAuthProvider(provider); + + mockRefreshFlow({ access_token: 'fresh', token_type: 'Bearer', expires_in: 3600 }); + await triggerSaveTokensViaUnauthorized(adapted); + + expect(await adapted.token()).toBe('fresh'); + }); + + it('returns undefined for a token saved more than (expires_in - 60)s ago', async () => { + vi.useFakeTimers(); + try { + const { provider } = makeClientCredentialsProvider(); + const adapted = adaptOAuthProvider(provider); + + // Trigger auth flow at t=0; saves token with expires_in=3600 + mockRefreshFlow({ access_token: 'tok', token_type: 'Bearer', expires_in: 3600 }); + await triggerSaveTokensViaUnauthorized(adapted); + + // Fast-forward to just before the 60-second buffer (3539 s elapsed, 61 s remaining) + vi.advanceTimersByTime(3539 * 1000); + expect(await adapted.token()).toBe('tok'); + + // Fast-forward 2 more seconds (3541 s elapsed, 59 s remaining — inside buffer) + vi.advanceTimersByTime(2000); + expect(await adapted.token()).toBeUndefined(); + } finally { + vi.useRealTimers(); + } + }); + + it('returns undefined for an already-expired token (elapsed >= expires_in)', async () => { + vi.useFakeTimers(); + try { + const { provider } = makeClientCredentialsProvider(); + const adapted = adaptOAuthProvider(provider); + + mockRefreshFlow({ access_token: 'expired', token_type: 'Bearer', expires_in: 3600 }); + await triggerSaveTokensViaUnauthorized(adapted); + + // Fast-forward past full expiry + vi.advanceTimersByTime(4000 * 1000); + expect(await adapted.token()).toBeUndefined(); + } finally { + vi.useRealTimers(); + } + }); + + it('resets issuedAt after a second saveTokens call, making refreshed token valid again', async () => { + vi.useFakeTimers(); + try { + const { provider } = makeClientCredentialsProvider(); + const adapted = adaptOAuthProvider(provider); + + // First auth at t=0 + mockRefreshFlow({ access_token: 'first', token_type: 'Bearer', expires_in: 3600 }); + await triggerSaveTokensViaUnauthorized(adapted); + expect(await adapted.token()).toBe('first'); + + // Advance past expiry buffer — token should be stale + vi.advanceTimersByTime(3600 * 1000); + expect(await adapted.token()).toBeUndefined(); + + // Second auth at t=3600 — new token saved, issuedAt reset + mockRefreshFlow({ access_token: 'refreshed', token_type: 'Bearer', expires_in: 3600 }); + await triggerSaveTokensViaUnauthorized(adapted); + expect(await adapted.token()).toBe('refreshed'); + } finally { + vi.useRealTimers(); + } + }); + }); });