diff --git a/packages/@n8n/db/src/migrations/common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable.ts b/packages/@n8n/db/src/migrations/common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable.ts new file mode 100644 index 00000000000..2aa028aa7ca --- /dev/null +++ b/packages/@n8n/db/src/migrations/common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable.ts @@ -0,0 +1,40 @@ +import type { MigrationContext, ReversibleMigration } from '../migration-types'; + +const table = 'instance_ai_mcp_registry_connections'; + +export class CreateInstanceAiMcpRegistryConnectionTable1784000000023 + implements ReversibleMigration +{ + async up({ schemaBuilder: { createTable, column } }: MigrationContext) { + await createTable(table) + .withColumns( + column('id').uuid.primary, + column('credentialId').varchar(36).notNull, + column('serverSlug').varchar(255).notNull, + column('toolFilter').json.comment( + 'Optional MCP tool filter per registry connection: { mode: "allow" | "exclude", tools: string[] }', + ), + column('userId').uuid.notNull, + ) + .withIndexOn(['userId', 'serverSlug', 'credentialId'], true) + .withForeignKey('credentialId', { + tableName: 'credentials_entity', + columnName: 'id', + onDelete: 'CASCADE', + }) + .withForeignKey('serverSlug', { + tableName: 'mcp_registry_server', + columnName: 'slug', + onDelete: 'CASCADE', + }) + .withForeignKey('userId', { + tableName: 'user', + columnName: 'id', + onDelete: 'CASCADE', + }).withTimestamps; + } + + async down({ schemaBuilder: { dropTable } }: MigrationContext) { + await dropTable(table); + } +} diff --git a/packages/@n8n/db/src/migrations/postgresdb/index.ts b/packages/@n8n/db/src/migrations/postgresdb/index.ts index bccd4c4b029..a4a0dbed63b 100644 --- a/packages/@n8n/db/src/migrations/postgresdb/index.ts +++ b/packages/@n8n/db/src/migrations/postgresdb/index.ts @@ -198,6 +198,7 @@ import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/17840000 import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable'; import { CreateAgentTaskDefinitionTable1784000000021 } from '../common/1784000000021-CreateAgentTaskDefinitionTable'; import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads'; +import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable'; import type { Migration } from '../migration-types'; export const postgresMigrations: Migration[] = [ @@ -401,4 +402,5 @@ export const postgresMigrations: Migration[] = [ CreateWorkflowPublicationOutboxTable1784000000020, CreateAgentTaskDefinitionTable1784000000021, AddSubAgentLinkageToAgentExecutionThreads1784000000022, + CreateInstanceAiMcpRegistryConnectionTable1784000000023, ]; diff --git a/packages/@n8n/db/src/migrations/sqlite/index.ts b/packages/@n8n/db/src/migrations/sqlite/index.ts index b18cec49af0..be5a6bb3581 100644 --- a/packages/@n8n/db/src/migrations/sqlite/index.ts +++ b/packages/@n8n/db/src/migrations/sqlite/index.ts @@ -190,6 +190,7 @@ import { CreateAgentFilesTable1784000000018 } from '../common/1784000000018-Crea import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/1784000000019-AddCustomTelemetryTagsToProject'; import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable'; import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads'; +import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable'; import type { Migration } from '../migration-types'; import { CreateAgentTaskDefinitionTable1784000000021 } from './1784000000021-CreateAgentTaskDefinitionTable'; @@ -387,6 +388,7 @@ const sqliteMigrations: Migration[] = [ CreateWorkflowPublicationOutboxTable1784000000020, CreateAgentTaskDefinitionTable1784000000021, AddSubAgentLinkageToAgentExecutionThreads1784000000022, + CreateInstanceAiMcpRegistryConnectionTable1784000000023, ]; export { sqliteMigrations }; diff --git a/packages/@n8n/instance-ai/src/mcp/__tests__/mcp-client-manager.test.ts b/packages/@n8n/instance-ai/src/mcp/__tests__/mcp-client-manager.test.ts index 9dab884a4c1..7d5c9550e71 100644 --- a/packages/@n8n/instance-ai/src/mcp/__tests__/mcp-client-manager.test.ts +++ b/packages/@n8n/instance-ai/src/mcp/__tests__/mcp-client-manager.test.ts @@ -246,6 +246,46 @@ describe('McpClientManager', () => { await manager.getRegularTools(configs); expect(mockedMcpClient).toHaveBeenCalledTimes(1); }); + + it('does not share cached clients across different scoped fetch cache keys', async () => { + const manager = new McpClientManager(); + await manager.getRegularTools([ + { + name: 'shared', + url: 'https://shared.example.com/', + cacheKey: 'registry-connection:1', + }, + ]); + await manager.getRegularTools([ + { + name: 'shared', + url: 'https://shared.example.com/', + cacheKey: 'registry-connection:2', + }, + ]); + + expect(mockedMcpClient).toHaveBeenCalledTimes(2); + }); + + it('reuses cached clients when scoped fetch cache key matches', async () => { + const manager = new McpClientManager(); + await manager.getRegularTools([ + { + name: 'shared', + url: 'https://shared.example.com/', + cacheKey: 'registry-connection:1', + }, + ]); + await manager.getRegularTools([ + { + name: 'shared', + url: 'https://shared.example.com/', + cacheKey: 'registry-connection:1', + }, + ]); + + expect(mockedMcpClient).toHaveBeenCalledTimes(1); + }); }); describe('concurrent dedup', () => { diff --git a/packages/@n8n/instance-ai/src/mcp/mcp-client-manager.ts b/packages/@n8n/instance-ai/src/mcp/mcp-client-manager.ts index ed3c0ed16bf..3a1bfd3b8b2 100644 --- a/packages/@n8n/instance-ai/src/mcp/mcp-client-manager.ts +++ b/packages/@n8n/instance-ai/src/mcp/mcp-client-manager.ts @@ -33,7 +33,12 @@ function buildNativeMcpConfigs(configs: McpServerConfig[]): NativeMcpServerConfi const servers: NativeMcpServerConfig[] = []; for (const server of configs) { if (server.url) { - servers.push({ name: server.name, url: server.url }); + servers.push({ + name: server.name, + url: server.url, + transport: server.transport, + fetch: server.fetch, + }); } else if (server.command) { servers.push({ name: server.name, @@ -191,7 +196,7 @@ export class McpClientManager { logger: Logger | undefined, source: string, ): Promise { - const client = new McpClient(buildNativeMcpConfigs(configs)); + const client = new McpClient(buildNativeMcpConfigs(configs), true); this.clientsByKey.set(clientKey, client); const registry = toolsToRegistry(await client.listTools()); diff --git a/packages/@n8n/instance-ai/src/types.ts b/packages/@n8n/instance-ai/src/types.ts index 4baec57bfe3..914869f7500 100644 --- a/packages/@n8n/instance-ai/src/types.ts +++ b/packages/@n8n/instance-ai/src/types.ts @@ -909,9 +909,17 @@ export type CheckpointSettleResult = export interface McpServerConfig { name: string; url?: string; + transport?: 'sse' | 'streamableHttp'; command?: string; args?: string[]; env?: Record; + fetch?: typeof fetch; + /** + * Optional cache discriminator used by `McpClientManager` when a server's + * connection behavior depends on runtime context (for example, per-user auth + * in a custom `fetch` implementation). + */ + cacheKey?: string; } // ── Memory ─────────────────────────────────────────────────────────────────── diff --git a/packages/cli/src/modules/agents/json-config/__tests__/mcp-client-factory.test.ts b/packages/cli/src/modules/agents/json-config/__tests__/mcp-client-factory.test.ts index afbbafb693c..51d4457a2a0 100644 --- a/packages/cli/src/modules/agents/json-config/__tests__/mcp-client-factory.test.ts +++ b/packages/cli/src/modules/agents/json-config/__tests__/mcp-client-factory.test.ts @@ -4,7 +4,7 @@ import { mock } from 'jest-mock-extended'; import type { OauthService } from '@/oauth/oauth.service'; -import { buildMcpClientForServer, createAuthFetch, mapApprovalToSdk } from '../mcp-client-factory'; +import { buildMcpClientForServer, mapApprovalToSdk } from '../mcp-client-factory'; // --------------------------------------------------------------------------- // Module mocks @@ -63,69 +63,6 @@ describe('mapApprovalToSdk', () => { }); }); -// --------------------------------------------------------------------------- -// createAuthFetch -// --------------------------------------------------------------------------- - -describe('createAuthFetch', () => { - beforeEach(() => { - proxyFetchMock.mockReset(); - }); - - it('routes through proxyFetch and injects the initial headers', async () => { - proxyFetchMock.mockResolvedValueOnce(makeOk()); - - const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); - const res = await fetchFn('https://example.test/mcp'); - - expect(res.status).toBe(200); - expect(proxyFetchMock).toHaveBeenCalledTimes(1); - const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit]; - expect(init.headers).toMatchObject({ Authorization: 'Bearer A' }); - }); - - it('returns 401 unchanged when no onUnauthorized handler is configured', async () => { - proxyFetchMock.mockResolvedValueOnce(make401()); - - const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); - const res = await fetchFn('https://example.test/mcp'); - - expect(res.status).toBe(401); - expect(proxyFetchMock).toHaveBeenCalledTimes(1); - }); - - it('returns the original 401 when onUnauthorized returns null', async () => { - proxyFetchMock.mockResolvedValueOnce(make401()); - - const onUnauthorized = jest.fn().mockResolvedValue(null); - const fetchFn = createAuthFetch({ - initialHeaders: { Authorization: 'Bearer A' }, - onUnauthorized, - }); - const res = await fetchFn('https://example.test/mcp'); - - expect(res.status).toBe(401); - expect(onUnauthorized).toHaveBeenCalledTimes(1); - expect(proxyFetchMock).toHaveBeenCalledTimes(1); - }); - - it('retries once with refreshed headers when onUnauthorized returns new headers', async () => { - proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk()); - - const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' }); - const fetchFn = createAuthFetch({ - initialHeaders: { Authorization: 'Bearer A' }, - onUnauthorized, - }); - const res = await fetchFn('https://example.test/mcp'); - - expect(res.status).toBe(200); - expect(proxyFetchMock).toHaveBeenCalledTimes(2); - const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit]; - expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' }); - }); -}); - // --------------------------------------------------------------------------- // buildMcpClientForServer — header derivation per auth type // --------------------------------------------------------------------------- @@ -314,56 +251,6 @@ describe('buildMcpClientForServer — SDK config mapping', () => { }); }); -// --------------------------------------------------------------------------- -// createAuthFetch — header merging and stateful refresh -// --------------------------------------------------------------------------- - -describe('createAuthFetch — header merging', () => { - beforeEach(() => { - proxyFetchMock.mockReset(); - proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 })); - }); - - it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => { - const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); - await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } }); - - const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit]; - expect(init.headers).toMatchObject({ - 'X-Custom': 'value', - Authorization: 'Bearer A', - }); - }); - - it('uses the refreshed headers on the second call after a successful 401 refresh', async () => { - proxyFetchMock - .mockResolvedValueOnce(new Response('unauthorized', { status: 401 })) - .mockResolvedValueOnce(new Response('ok', { status: 200 })) - .mockResolvedValueOnce(new Response('ok', { status: 200 })); - - let callCount = 0; - const onUnauthorized = jest.fn().mockImplementation(async () => { - callCount++; - return { Authorization: `Bearer refreshed-${callCount}` }; - }); - - const fetchFn = createAuthFetch({ - initialHeaders: { Authorization: 'Bearer stale' }, - onUnauthorized, - }); - - // First call triggers a 401 → refresh → retry - await fetchFn('https://example.test/mcp'); - - // Second call should use the refreshed headers without triggering another refresh - await fetchFn('https://example.test/mcp'); - - expect(onUnauthorized).toHaveBeenCalledTimes(1); - const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit]; - expect((thirdInit.headers as Record).Authorization).toBe('Bearer refreshed-1'); - }); -}); - // --------------------------------------------------------------------------- // buildMcpClientForServer — auth header edge cases // --------------------------------------------------------------------------- diff --git a/packages/cli/src/modules/agents/json-config/mcp-client-factory.ts b/packages/cli/src/modules/agents/json-config/mcp-client-factory.ts index bc34d91eb73..8f5f7a13c34 100644 --- a/packages/cli/src/modules/agents/json-config/mcp-client-factory.ts +++ b/packages/cli/src/modules/agents/json-config/mcp-client-factory.ts @@ -1,9 +1,9 @@ -import { proxyFetch } from '@n8n/ai-utilities'; import type { CredentialProvider, McpClient, McpServerConfig } from '@n8n/agents'; import type { AgentJsonMcpServerConfig } from '@n8n/api-types'; import { isMcpOAuth2Authentication } from 'n8n-workflow'; import type { OauthService } from '@/oauth/oauth.service'; +import { createAuthFetch } from '@/utils/auth-fetch'; /** * Convert the JSON-config `approval` shape into the SDK's `requireApproval` @@ -91,59 +91,6 @@ async function deriveAuthHeaders( } } -interface CreateAuthFetchOptions { - initialHeaders: Record; - /** - * Called on a 401 response. Should return a fresh set of auth headers, or - * `null` if the refresh failed. The returned headers replace the cached - * set used by subsequent requests. - */ - onUnauthorized?: () => Promise | null>; -} - -function headersToRecord(headers: HeadersInit | undefined): Record { - if (!headers) return {}; - if (headers instanceof Headers) return Object.fromEntries(headers.entries()); - if (Array.isArray(headers)) return Object.fromEntries(headers); - return headers; -} - -/** - * Build a fetch wrapper that: - * 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings - * apply uniformly), - * 2. injects the latest auth headers on every request, - * 3. on a single 401, calls `onUnauthorized` to refresh the token and - * retries the request once with the new headers. - * - * This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP - * connection behaves identically to one configured via the workflow editor. - */ -export function createAuthFetch({ - initialHeaders, - onUnauthorized, -}: CreateAuthFetchOptions): typeof fetch { - let headers = initialHeaders; - - return async (input: RequestInfo | URL, init?: RequestInit): Promise => { - const response = await proxyFetch(input, { - ...init, - headers: { ...headersToRecord(init?.headers), ...headers }, - }); - - if (response.status !== 401 || !onUnauthorized) return response; - - const refreshed = await onUnauthorized(); - if (!refreshed) return response; - - headers = refreshed; - return await proxyFetch(input, { - ...init, - headers: { ...headersToRecord(init?.headers), ...headers }, - }); - }; -} - export interface BuildMcpClientDeps { credentialProvider: CredentialProvider; /** diff --git a/packages/cli/src/modules/instance-ai/entities/index.ts b/packages/cli/src/modules/instance-ai/entities/index.ts index a1b33f63ac5..259125a3f50 100644 --- a/packages/cli/src/modules/instance-ai/entities/index.ts +++ b/packages/cli/src/modules/instance-ai/entities/index.ts @@ -7,6 +7,7 @@ export { InstanceAiCheckpoint } from './instance-ai-checkpoint.entity'; export { InstanceAiObservation } from './instance-ai-observation.entity'; export { InstanceAiObservationCursor } from './instance-ai-observation-cursor.entity'; export { InstanceAiObservationLock } from './instance-ai-observation-lock.entity'; +export { InstanceAiMcpRegistryConnection } from './instance-ai-mcp-registry-connection.entity'; export type { InstanceAiObservationMarker, InstanceAiObservationStatus, diff --git a/packages/cli/src/modules/instance-ai/entities/instance-ai-mcp-registry-connection.entity.ts b/packages/cli/src/modules/instance-ai/entities/instance-ai-mcp-registry-connection.entity.ts new file mode 100644 index 00000000000..d5713c392a7 --- /dev/null +++ b/packages/cli/src/modules/instance-ai/entities/instance-ai-mcp-registry-connection.entity.ts @@ -0,0 +1,26 @@ +import { JsonColumn, WithTimestamps } from '@n8n/db'; +import { Column, Entity, Index, PrimaryColumn } from '@n8n/typeorm'; + +export type InstanceAiMcpToolFilter = { + mode: 'allow' | 'exclude'; + tools: string[]; +}; + +@Entity({ name: 'instance_ai_mcp_registry_connections' }) +@Index(['userId', 'serverSlug', 'credentialId'], { unique: true }) +export class InstanceAiMcpRegistryConnection extends WithTimestamps { + @PrimaryColumn('uuid') + id: string; + + @Column({ type: 'uuid' }) + userId: string; + + @Column({ type: 'varchar', length: 255 }) + serverSlug: string; + + @Column({ type: 'varchar', length: 36 }) + credentialId: string; + + @JsonColumn({ nullable: true }) + toolFilter: InstanceAiMcpToolFilter | null; +} diff --git a/packages/cli/src/modules/instance-ai/instance-ai.module.ts b/packages/cli/src/modules/instance-ai/instance-ai.module.ts index 924c306d91b..edc8ae24881 100644 --- a/packages/cli/src/modules/instance-ai/instance-ai.module.ts +++ b/packages/cli/src/modules/instance-ai/instance-ai.module.ts @@ -67,6 +67,9 @@ export class InstanceAiModule implements ModuleInterface { const { InstanceAiObservationLock } = await import( './entities/instance-ai-observation-lock.entity' ); + const { InstanceAiMcpRegistryConnection } = await import( + './entities/instance-ai-mcp-registry-connection.entity' + ); return [ InstanceAiThread, @@ -79,6 +82,7 @@ export class InstanceAiModule implements ModuleInterface { InstanceAiObservation, InstanceAiObservationCursor, InstanceAiObservationLock, + InstanceAiMcpRegistryConnection, ]; } diff --git a/packages/cli/src/modules/instance-ai/instance-ai.service.ts b/packages/cli/src/modules/instance-ai/instance-ai.service.ts index e140304385f..05482166ba2 100644 --- a/packages/cli/src/modules/instance-ai/instance-ai.service.ts +++ b/packages/cli/src/modules/instance-ai/instance-ai.service.ts @@ -114,6 +114,7 @@ import { InstanceAiPendingConfirmationRepository } from './repositories/instance import { InstanceAiThreadRepository } from './repositories/instance-ai-thread.repository'; import { TraceReplayState } from './trace-replay-state'; import { INSTANCE_AI_RUN_TIMEOUT_REASON, InstanceAiLivenessService } from './liveness'; +import { InstanceAiMcpRegistryService } from './mcp'; import { buildInstanceAiRunTraceMetadata, type InstanceAiRunTraceMetadataOptions, @@ -639,6 +640,7 @@ export class InstanceAiService { private readonly dbIterationLogStorage: DbIterationLogStorage, private readonly sourceControlPreferencesService: SourceControlPreferencesService, private readonly telemetry: Telemetry, + private readonly mcpRegistryService: InstanceAiMcpRegistryService, private readonly userRepository: UserRepository, private readonly aiBuilderTemporaryWorkflowRepository: AiBuilderTemporaryWorkflowRepository, private readonly errorReporter: ErrorReporter, @@ -3484,7 +3486,9 @@ export class InstanceAiService { return; } - const mcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers); + const staticMcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers); + const registryMcpServers = await this.mcpRegistryService.getRegistryMcpServers(user); + const mcpServers = [...staticMcpServers, ...registryMcpServers]; const executionPushRef = this.threadPushRef.get(threadId); const environment = await this.createExecutionEnvironment( diff --git a/packages/cli/src/modules/instance-ai/mcp/__tests__/instance-ai-mcp-registry.service.test.ts b/packages/cli/src/modules/instance-ai/mcp/__tests__/instance-ai-mcp-registry.service.test.ts new file mode 100644 index 00000000000..a7d54fe0d21 --- /dev/null +++ b/packages/cli/src/modules/instance-ai/mcp/__tests__/instance-ai-mcp-registry.service.test.ts @@ -0,0 +1,304 @@ +import type { Logger } from '@n8n/backend-common'; +import type { CredentialsEntity, User } from '@n8n/db'; +import { mock } from 'jest-mock-extended'; + +import type { CredentialsFinderService } from '@/credentials/credentials-finder.service'; +import type { CredentialsService } from '@/credentials/credentials.service'; +import type { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service'; +import type { McpRegistryServer } from '@/modules/mcp-registry/registry/mcp-registry.types'; +import type { OauthService } from '@/oauth/oauth.service'; + +import type { InstanceAiMcpRegistryConnectionRepository } from '../../repositories/instance-ai-mcp-registry-connection.repository'; +import { InstanceAiMcpRegistryService } from '../instance-ai-mcp-registry.service'; +import type { InstanceAiMcpRegistryConnection } from '../../entities/instance-ai-mcp-registry-connection.entity'; + +const proxyFetchMock = jest.fn(); + +jest.mock('@n8n/ai-utilities', () => ({ + proxyFetch: (...args: unknown[]) => proxyFetchMock(...args), +})); + +function makeRegistryServer( + slug: string, + overrides: Partial = {}, +): McpRegistryServer { + return { + name: `com.test/${slug}`, + slug, + title: slug, + description: `${slug} description`, + tagline: `${slug} tagline`, + version: '1.0.0', + updatedAt: '2026-05-01T00:00:00.000Z', + icons: [], + authType: 'oauth2', + remotes: [{ type: 'streamable-http', url: `https://${slug}.example.com/mcp` }], + tools: [], + isOfficial: true, + origin: 'registry', + status: 'active', + ...overrides, + }; +} + +describe('InstanceAiMcpRegistryService', () => { + const user = { id: 'user-1' } as User; + const credential = { + id: 'cred-1', + name: 'MCP OAuth2', + type: 'mcpOAuth2Api', + shared: [{ role: 'credential:owner', projectId: 'project-1' }], + } as CredentialsEntity; + + const oauthCredentialData = { + clientId: 'client-id', + clientSecret: 'client-secret', + accessTokenUrl: 'https://auth.example.com/token', + oauthTokenData: { + access_token: 'stale-token', + refresh_token: 'refresh-token', + }, + }; + + function createService() { + const logger = mock({ scoped: jest.fn().mockReturnThis() }); + const connectionRepository = mock(); + const mcpRegistryService = mock(); + const credentialsFinderService = mock(); + const credentialsService = mock(); + const oauthService = mock(); + + const service = new InstanceAiMcpRegistryService( + logger, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + oauthService, + ); + + return { + service, + logger, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + oauthService, + }; + } + + beforeEach(() => { + jest.clearAllMocks(); + proxyFetchMock.mockReset(); + }); + + it('returns empty list when the user has no registry connections', async () => { + const { service, connectionRepository, mcpRegistryService } = createService(); + connectionRepository.findBy.mockResolvedValue([]); + + const result = await service.getRegistryMcpServers(user); + + expect(result).toEqual([]); + expect(mcpRegistryService.getBySlugs).not.toHaveBeenCalled(); + }); + + it('resolves servers with deterministic names and preferred transport', async () => { + const { + service, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + } = createService(); + const credentialsById: Record = { + 'cred-1': { id: 'cred-1', name: 'MCP OAuth2 #1', type: 'mcpOAuth2Api' } as CredentialsEntity, + 'cred-2': { id: 'cred-2', name: 'MCP OAuth2 #2', type: 'mcpOAuth2Api' } as CredentialsEntity, + 'cred-3': { id: 'cred-3', name: 'MCP OAuth2 #3', type: 'mcpOAuth2Api' } as CredentialsEntity, + }; + connectionRepository.findBy.mockResolvedValue([ + { id: '2', userId: user.id, serverSlug: 'linear', credentialId: 'cred-2' }, + { id: '1', userId: user.id, serverSlug: 'linear', credentialId: 'cred-1' }, + { id: '3', userId: user.id, serverSlug: 'notion', credentialId: 'cred-3' }, + ] as InstanceAiMcpRegistryConnection[]); + mcpRegistryService.getBySlugs.mockResolvedValue([ + makeRegistryServer('linear', { + remotes: [ + { type: 'sse', url: 'https://linear.example.com/sse' }, + { type: 'streamable-http', url: 'https://linear.example.com/mcp' }, + ], + }), + makeRegistryServer('notion', { + remotes: [{ type: 'sse', url: 'https://notion.example.com/sse' }], + }), + ]); + credentialsFinderService.findCredentialForUser.mockImplementation(async (credentialId) => { + return credentialsById[credentialId] ?? null; + }); + credentialsService.decrypt.mockResolvedValue(oauthCredentialData); + + const result = await service.getRegistryMcpServers(user); + + expect(result).toHaveLength(3); + expect(result[0]).toEqual( + expect.objectContaining({ + name: 'mcp_linear', + url: 'https://linear.example.com/mcp', + transport: 'streamableHttp', + cacheKey: 'registry-connection:1', + fetch: expect.any(Function), + }), + ); + expect(result[1]).toEqual( + expect.objectContaining({ + name: 'mcp_linear_2', + url: 'https://linear.example.com/mcp', + transport: 'streamableHttp', + cacheKey: 'registry-connection:2', + fetch: expect.any(Function), + }), + ); + expect(result[2]).toEqual( + expect.objectContaining({ + name: 'mcp_notion', + url: 'https://notion.example.com/sse', + transport: 'sse', + cacheKey: 'registry-connection:3', + fetch: expect.any(Function), + }), + ); + expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-1', user, [ + 'credential:read', + ]); + expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-2', user, [ + 'credential:read', + ]); + expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-3', user, [ + 'credential:read', + ]); + }); + + it('skips connections with missing server slugs or unsupported remotes', async () => { + const { service, connectionRepository, mcpRegistryService, logger } = createService(); + connectionRepository.findBy.mockResolvedValue([ + { id: '1', userId: user.id, serverSlug: 'missing', credentialId: credential.id }, + { id: '2', userId: user.id, serverSlug: 'bad-remote', credentialId: credential.id }, + ] as InstanceAiMcpRegistryConnection[]); + mcpRegistryService.getBySlugs.mockResolvedValue([ + makeRegistryServer('bad-remote', { remotes: [] }), + ]); + + const result = await service.getRegistryMcpServers(user); + + expect(result).toEqual([]); + expect(logger.warn).toHaveBeenCalledWith( + 'Skipping MCP registry connection with missing server slug', + expect.objectContaining({ connectionId: '1', serverSlug: 'missing', userId: user.id }), + ); + expect(logger.warn).toHaveBeenCalledWith( + 'Skipping MCP registry connection without supported remote transport', + expect.objectContaining({ connectionId: '2', serverSlug: 'bad-remote' }), + ); + }); + + it('does not attach custom fetch for non-oauth servers', async () => { + const { + service, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + } = createService(); + connectionRepository.findBy.mockResolvedValue([ + { id: '1', userId: user.id, serverSlug: 'public-server', credentialId: credential.id }, + ] as InstanceAiMcpRegistryConnection[]); + mcpRegistryService.getBySlugs.mockResolvedValue([ + makeRegistryServer('public-server', { + // currently only oauth2 is supported + // so we need to cast it to test this behavior + authType: 'none' as unknown as 'oauth2', + }), + ]); + + const [server] = await service.getRegistryMcpServers(user); + + expect(server).toEqual( + expect.objectContaining({ + name: 'mcp_public-server', + url: 'https://public-server.example.com/mcp', + transport: 'streamableHttp', + cacheKey: 'registry-connection:1', + }), + ); + expect(server.fetch).toBeUndefined(); + expect(credentialsFinderService.findCredentialForUser).not.toHaveBeenCalled(); + expect(credentialsService.decrypt).not.toHaveBeenCalled(); + }); + + it('adds auth header and retries once with refreshed OAuth token after 401', async () => { + const { + service, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + oauthService, + } = createService(); + connectionRepository.findBy.mockResolvedValue([ + { id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id }, + ] as InstanceAiMcpRegistryConnection[]); + mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]); + credentialsFinderService.findCredentialForUser.mockResolvedValue(credential); + credentialsService.decrypt.mockResolvedValue(oauthCredentialData); + proxyFetchMock + .mockResolvedValueOnce(new Response('unauthorized', { status: 401 })) + .mockResolvedValueOnce(new Response('ok', { status: 200 })); + oauthService.refreshOAuth2CredentialById.mockResolvedValue({ + Authorization: 'Bearer fresh-token', + }); + + const [server] = await service.getRegistryMcpServers(user); + const response = await server.fetch?.('https://linear.example.com/mcp'); + + expect(response?.status).toBe(200); + expect(proxyFetchMock).toHaveBeenCalledTimes(2); + const [, firstInit] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit]; + const [, secondInit] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit]; + expect(new Headers(firstInit.headers).get('Authorization')).toBe('Bearer stale-token'); + expect(new Headers(secondInit.headers).get('Authorization')).toBe('Bearer fresh-token'); + expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith( + credential.id, + 'project-1', + ); + }); + + it('returns original 401 response when token refresh fails', async () => { + const { + service, + connectionRepository, + mcpRegistryService, + credentialsFinderService, + credentialsService, + oauthService, + } = createService(); + connectionRepository.findBy.mockResolvedValue([ + { id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id }, + ] as InstanceAiMcpRegistryConnection[]); + mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]); + credentialsFinderService.findCredentialForUser.mockResolvedValue(credential); + credentialsService.decrypt.mockResolvedValue(oauthCredentialData); + oauthService.refreshOAuth2CredentialById.mockResolvedValue(null); + + proxyFetchMock.mockResolvedValue(new Response('unauthorized', { status: 401 })); + + const [server] = await service.getRegistryMcpServers(user); + const response = await server.fetch?.('https://linear.example.com/mcp'); + + expect(response?.status).toBe(401); + expect(proxyFetchMock).toHaveBeenCalledTimes(1); + expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith( + credential.id, + 'project-1', + ); + }); +}); diff --git a/packages/cli/src/modules/instance-ai/mcp/index.ts b/packages/cli/src/modules/instance-ai/mcp/index.ts new file mode 100644 index 00000000000..f28656a4fe3 --- /dev/null +++ b/packages/cli/src/modules/instance-ai/mcp/index.ts @@ -0,0 +1 @@ +export * from './instance-ai-mcp-registry.service'; diff --git a/packages/cli/src/modules/instance-ai/mcp/instance-ai-mcp-registry.service.ts b/packages/cli/src/modules/instance-ai/mcp/instance-ai-mcp-registry.service.ts new file mode 100644 index 00000000000..63b5798bf61 --- /dev/null +++ b/packages/cli/src/modules/instance-ai/mcp/instance-ai-mcp-registry.service.ts @@ -0,0 +1,266 @@ +import { isObjectLiteral, Logger } from '@n8n/backend-common'; +import type { CredentialsEntity, User } from '@n8n/db'; +import { Service } from '@n8n/di'; +import type { McpServerConfig } from '@n8n/instance-ai'; +import type { ICredentialDataDecryptedObject } from 'n8n-workflow'; + +import { CredentialsFinderService } from '@/credentials/credentials-finder.service'; +import { CredentialsService } from '@/credentials/credentials.service'; +import { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service'; +import type { McpRegistryRemote } from '@/modules/mcp-registry/registry/mcp-registry.types'; +import { OauthService } from '@/oauth/oauth.service'; +import { createAuthFetch } from '@/utils/auth-fetch'; + +import { InstanceAiMcpRegistryConnectionRepository } from '../repositories/instance-ai-mcp-registry-connection.repository'; + +type Transport = 'sse' | 'streamableHttp'; + +interface ResolvedRegistryServer { + serverSlug: string; + credentialId: string; + authType: string; + endpointUrl: string; + transport: Transport; +} + +interface OAuth2FetchContext { + credentialId: string; + accessToken: string; + projectId: string; +} + +function readString(data: Record, key: string): string | undefined { + const value = data[key]; + return typeof value === 'string' && value.length > 0 ? value : undefined; +} + +function readAccessToken(tokenData: Record): string | undefined { + return readString(tokenData, 'accessToken') ?? readString(tokenData, 'access_token'); +} + +function readOAuthTokenData(data: ICredentialDataDecryptedObject): Record | null { + const tokenData = data.oauthTokenData; + return isObjectLiteral(tokenData) ? tokenData : null; +} + +function getPreferredRemote(remotes: McpRegistryRemote[]): { + transport: Transport; + endpointUrl: string; +} | null { + const streamable = remotes.find((remote) => remote.type === 'streamable-http'); + if (streamable?.url) { + return { transport: 'streamableHttp', endpointUrl: streamable.url }; + } + + const sse = remotes.find((remote) => remote.type === 'sse'); + if (sse?.url) { + return { transport: 'sse', endpointUrl: sse.url }; + } + + return null; +} + +const MCP_REGISTRY_SERVER_PREFIX = 'mcp_'; +const MAX_MCP_SERVER_NAME_LENGTH = 24; + +function buildServerName(serverSlug: string, sequence: number): string { + const safeSlug = serverSlug.replace(/[^A-Za-z0-9_-]/g, '_'); + const baseName = `${MCP_REGISTRY_SERVER_PREFIX}${safeSlug}`; + if (sequence <= 1) { + return baseName.slice(0, MAX_MCP_SERVER_NAME_LENGTH); + } + + const suffix = `_${sequence}`; + const maxBaseLength = Math.max(0, MAX_MCP_SERVER_NAME_LENGTH - suffix.length); + return `${baseName.slice(0, maxBaseLength)}${suffix}`; +} + +@Service() +export class InstanceAiMcpRegistryService { + private readonly logger: Logger; + + constructor( + logger: Logger, + private readonly connectionRepository: InstanceAiMcpRegistryConnectionRepository, + private readonly mcpRegistryService: McpRegistryService, + private readonly credentialsFinderService: CredentialsFinderService, + private readonly credentialsService: CredentialsService, + private readonly oauthService: OauthService, + ) { + this.logger = logger.scoped('instance-ai'); + } + + async getRegistryMcpServers(user: User): Promise { + const connections = await this.connectionRepository.findBy({ userId: user.id }); + if (connections.length === 0) { + return []; + } + + const sortedConnections = connections.sort((left, right) => left.id.localeCompare(right.id)); + const slugs = [...new Set(sortedConnections.map((connection) => connection.serverSlug))]; + const servers = await this.mcpRegistryService.getBySlugs(slugs); + const serverBySlug = new Map(servers.map((server) => [server.slug, server])); + const slugCounts = new Map(); + + const resolved: McpServerConfig[] = []; + for (const connection of sortedConnections) { + const server = serverBySlug.get(connection.serverSlug); + if (!server) { + this.logger.warn('Skipping MCP registry connection with missing server slug', { + connectionId: connection.id, + serverSlug: connection.serverSlug, + userId: user.id, + }); + continue; + } + + const resolvedServer = this.resolveRegistryServer( + connection.id, + connection.serverSlug, + connection.credentialId, + server.authType, + server.remotes, + ); + if (!resolvedServer) { + continue; + } + + const nextCount = (slugCounts.get(resolvedServer.serverSlug) ?? 0) + 1; + slugCounts.set(resolvedServer.serverSlug, nextCount); + const serverConfig: McpServerConfig = { + name: buildServerName(resolvedServer.serverSlug, nextCount), + url: resolvedServer.endpointUrl, + transport: resolvedServer.transport, + cacheKey: `registry-connection:${connection.id}`, + }; + + if (resolvedServer.authType === 'oauth2') { + const oauth2FetchContext = await this.buildOAuth2FetchContext( + resolvedServer, + user, + connection.id, + ); + if (!oauth2FetchContext) { + continue; + } + + serverConfig.fetch = createAuthFetch({ + initialHeaders: { Authorization: `Bearer ${oauth2FetchContext.accessToken}` }, + onUnauthorized: async () => { + if (!oauth2FetchContext.projectId) { + return null; + } + + return await this.oauthService.refreshOAuth2CredentialById( + oauth2FetchContext.credentialId, + oauth2FetchContext.projectId, + ); + }, + }); + } + + resolved.push(serverConfig); + } + + return resolved; + } + + private resolveRegistryServer( + connectionId: string, + serverSlug: string, + credentialId: string, + authType: string, + remotes: McpRegistryRemote[], + ): ResolvedRegistryServer | null { + const remote = getPreferredRemote(remotes); + if (!remote) { + this.logger.warn('Skipping MCP registry connection without supported remote transport', { + connectionId, + serverSlug, + credentialId, + }); + return null; + } + + return { + serverSlug, + credentialId, + authType, + endpointUrl: remote.endpointUrl, + transport: remote.transport, + }; + } + + private async buildOAuth2FetchContext( + config: ResolvedRegistryServer, + user: User, + connectionId: string, + ): Promise { + const credentialWithData = await this.getCredentialWithData(config.credentialId, user); + if (!credentialWithData) { + this.logger.warn('Skipping MCP registry connection with inaccessible credential', { + connectionId, + serverSlug: config.serverSlug, + credentialId: config.credentialId, + userId: user.id, + }); + return null; + } + + const tokenData = readOAuthTokenData(credentialWithData.data); + if (!tokenData) { + this.logger.warn('Skipping MCP registry connection without OAuth2 token data', { + connectionId, + serverSlug: config.serverSlug, + credentialId: config.credentialId, + }); + return null; + } + + const accessToken = readAccessToken(tokenData); + if (!accessToken) { + this.logger.warn('Skipping MCP registry connection without access token', { + connectionId, + serverSlug: config.serverSlug, + credentialId: config.credentialId, + }); + return null; + } + + const projectId = credentialWithData.credential.shared?.[0]?.projectId ?? null; + if (!projectId) { + this.logger.warn('Skipping OAuth2 token refresh for credential without project sharing', { + connectionId, + serverSlug: config.serverSlug, + credentialId: config.credentialId, + }); + } + + return { + credentialId: config.credentialId, + accessToken, + projectId, + }; + } + + private async getCredentialWithData( + credentialId: string, + user: User, + ): Promise<{ credential: CredentialsEntity; data: ICredentialDataDecryptedObject } | null> { + const credential = await this.credentialsFinderService.findCredentialForUser( + credentialId, + user, + ['credential:read'], + ); + if (!credential) { + return null; + } + + const data = await this.credentialsService.decrypt(credential, true); + if (!isObjectLiteral(data) || Object.keys(data).length === 0) { + return null; + } + + return { credential, data }; + } +} diff --git a/packages/cli/src/modules/instance-ai/repositories/index.ts b/packages/cli/src/modules/instance-ai/repositories/index.ts index 5406ca08976..659b1da408e 100644 --- a/packages/cli/src/modules/instance-ai/repositories/index.ts +++ b/packages/cli/src/modules/instance-ai/repositories/index.ts @@ -7,3 +7,4 @@ export { InstanceAiCheckpointRepository } from './instance-ai-checkpoint.reposit export { InstanceAiObservationRepository } from './instance-ai-observation.repository'; export { InstanceAiObservationCursorRepository } from './instance-ai-observation-cursor.repository'; export { InstanceAiObservationLockRepository } from './instance-ai-observation-lock.repository'; +export { InstanceAiMcpRegistryConnectionRepository } from './instance-ai-mcp-registry-connection.repository'; diff --git a/packages/cli/src/modules/instance-ai/repositories/instance-ai-mcp-registry-connection.repository.ts b/packages/cli/src/modules/instance-ai/repositories/instance-ai-mcp-registry-connection.repository.ts new file mode 100644 index 00000000000..99fe3ab8e79 --- /dev/null +++ b/packages/cli/src/modules/instance-ai/repositories/instance-ai-mcp-registry-connection.repository.ts @@ -0,0 +1,11 @@ +import { Service } from '@n8n/di'; +import { DataSource, Repository } from '@n8n/typeorm'; + +import { InstanceAiMcpRegistryConnection } from '../entities/instance-ai-mcp-registry-connection.entity'; + +@Service() +export class InstanceAiMcpRegistryConnectionRepository extends Repository { + constructor(dataSource: DataSource) { + super(InstanceAiMcpRegistryConnection, dataSource.manager); + } +} diff --git a/packages/cli/src/modules/mcp-registry/registry/__tests__/mcp-registry.service.test.ts b/packages/cli/src/modules/mcp-registry/registry/__tests__/mcp-registry.service.test.ts index 3fe332d43f3..66426b139b8 100644 --- a/packages/cli/src/modules/mcp-registry/registry/__tests__/mcp-registry.service.test.ts +++ b/packages/cli/src/modules/mcp-registry/registry/__tests__/mcp-registry.service.test.ts @@ -131,6 +131,24 @@ describe('McpRegistryService', () => { expect(notion).toEqual(notionMockServer); expect(missing).toBeUndefined(); }); + + it('returns empty array for getBySlugs when input is empty', async () => { + const { service, repository } = createService(); + + const servers = await service.getBySlugs([]); + + expect(servers).toEqual([]); + expect(repository.findBy).not.toHaveBeenCalled(); + }); + + it('returns mapped servers for getBySlugs', async () => { + const { service, repository } = createService(); + + const servers = await service.getBySlugs(['notion', 'linear']); + + expect(repository.findBy).toHaveBeenCalledWith([{ slug: 'notion' }, { slug: 'linear' }]); + expect(servers).toEqual([notionMockServer, linearMockServer]); + }); }); describe('refresh flow', () => { diff --git a/packages/cli/src/modules/mcp-registry/registry/mcp-registry.service.ts b/packages/cli/src/modules/mcp-registry/registry/mcp-registry.service.ts index daab80d3a49..1ec2eb180cb 100644 --- a/packages/cli/src/modules/mcp-registry/registry/mcp-registry.service.ts +++ b/packages/cli/src/modules/mcp-registry/registry/mcp-registry.service.ts @@ -90,6 +90,15 @@ export class McpRegistryService { return entity ? fromEntity(entity) : undefined; } + async getBySlugs(slugs: string[]): Promise { + if (slugs.length === 0) { + return []; + } + + const entities = await this.repository.findBy(slugs.map((slug) => ({ slug }))); + return entities.map(fromEntity); + } + private startPeriodicRefresh(): void { if (this.isShuttingDown || this.refreshInterval) { return; diff --git a/packages/cli/src/utils/__tests__/auth-fetch.test.ts b/packages/cli/src/utils/__tests__/auth-fetch.test.ts new file mode 100644 index 00000000000..130c267d0a9 --- /dev/null +++ b/packages/cli/src/utils/__tests__/auth-fetch.test.ts @@ -0,0 +1,119 @@ +import { createAuthFetch } from '@/utils/auth-fetch'; + +const proxyFetchMock = jest.fn(); +jest.mock('@n8n/ai-utilities', () => ({ + proxyFetch: (...args: unknown[]) => proxyFetchMock(...args), +})); + +function makeOk(): Response { + return new Response('ok', { status: 200 }); +} + +function make401(): Response { + return new Response('unauthorized', { status: 401 }); +} + +describe('createAuthFetch', () => { + beforeEach(() => { + proxyFetchMock.mockReset(); + }); + + it('routes through proxyFetch and injects the initial headers', async () => { + proxyFetchMock.mockResolvedValueOnce(makeOk()); + + const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); + const res = await fetchFn('https://example.test/mcp'); + + expect(res.status).toBe(200); + expect(proxyFetchMock).toHaveBeenCalledTimes(1); + const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit]; + expect(init.headers).toMatchObject({ Authorization: 'Bearer A' }); + }); + + it('returns 401 unchanged when no onUnauthorized handler is configured', async () => { + proxyFetchMock.mockResolvedValueOnce(make401()); + + const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); + const res = await fetchFn('https://example.test/mcp'); + + expect(res.status).toBe(401); + expect(proxyFetchMock).toHaveBeenCalledTimes(1); + }); + + it('returns the original 401 when onUnauthorized returns null', async () => { + proxyFetchMock.mockResolvedValueOnce(make401()); + + const onUnauthorized = jest.fn().mockResolvedValue(null); + const fetchFn = createAuthFetch({ + initialHeaders: { Authorization: 'Bearer A' }, + onUnauthorized, + }); + const res = await fetchFn('https://example.test/mcp'); + + expect(res.status).toBe(401); + expect(onUnauthorized).toHaveBeenCalledTimes(1); + expect(proxyFetchMock).toHaveBeenCalledTimes(1); + }); + + it('retries once with refreshed headers when onUnauthorized returns new headers', async () => { + proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk()); + + const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' }); + const fetchFn = createAuthFetch({ + initialHeaders: { Authorization: 'Bearer A' }, + onUnauthorized, + }); + const res = await fetchFn('https://example.test/mcp'); + + expect(res.status).toBe(200); + expect(proxyFetchMock).toHaveBeenCalledTimes(2); + const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit]; + expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' }); + }); +}); + +describe('createAuthFetch — header merging', () => { + beforeEach(() => { + proxyFetchMock.mockReset(); + proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 })); + }); + + it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => { + const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } }); + await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } }); + + const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit]; + expect(init.headers).toMatchObject({ + 'X-Custom': 'value', + Authorization: 'Bearer A', + }); + }); + + it('uses the refreshed headers on the second call after a successful 401 refresh', async () => { + proxyFetchMock + .mockResolvedValueOnce(new Response('unauthorized', { status: 401 })) + .mockResolvedValueOnce(new Response('ok', { status: 200 })) + .mockResolvedValueOnce(new Response('ok', { status: 200 })); + + let callCount = 0; + const onUnauthorized = jest.fn().mockImplementation(async () => { + callCount++; + return { Authorization: `Bearer refreshed-${callCount}` }; + }); + + const fetchFn = createAuthFetch({ + initialHeaders: { Authorization: 'Bearer stale' }, + onUnauthorized, + }); + + // First call triggers a 401 → refresh → retry + await fetchFn('https://example.test/mcp'); + + // Second call should use the refreshed headers without triggering another refresh + await fetchFn('https://example.test/mcp'); + + expect(onUnauthorized).toHaveBeenCalledTimes(1); + const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit]; + expect((thirdInit.headers as Record).Authorization).toBe('Bearer refreshed-1'); + }); +}); diff --git a/packages/cli/src/utils/auth-fetch.ts b/packages/cli/src/utils/auth-fetch.ts new file mode 100644 index 00000000000..1bae86c3898 --- /dev/null +++ b/packages/cli/src/utils/auth-fetch.ts @@ -0,0 +1,54 @@ +import { proxyFetch } from '@n8n/ai-utilities'; + +interface CreateAuthFetchOptions { + initialHeaders: Record; + /** + * Called on a 401 response. Should return a fresh set of auth headers, or + * `null` if the refresh failed. The returned headers replace the cached + * set used by subsequent requests. + */ + onUnauthorized?: () => Promise | null>; +} + +function headersToRecord(headers: HeadersInit | undefined): Record { + if (!headers) return {}; + if (headers instanceof Headers) return Object.fromEntries(headers.entries()); + if (Array.isArray(headers)) return Object.fromEntries(headers); + return headers; +} + +/** + * Build a fetch wrapper that: + * 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings + * apply uniformly), + * 2. injects the latest auth headers on every request, + * 3. on a single 401, calls `onUnauthorized` to refresh the token and + * retries the request once with the new headers. + * + * This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP + * connection behaves identically to one configured via the workflow editor. + */ +export function createAuthFetch({ + initialHeaders, + onUnauthorized, +}: CreateAuthFetchOptions): typeof fetch { + let headers = initialHeaders; + + return async (input: RequestInfo | URL, init?: RequestInit): Promise => { + const response = await proxyFetch(input, { + ...init, + headers: { ...headersToRecord(init?.headers), ...headers }, + }); + + if (response.status !== 401 || !onUnauthorized) return response; + + const refreshed = await onUnauthorized(); + if (!refreshed) return response; + + headers = refreshed; + return await proxyFetch(input, { + ...init, + headers: { ...headersToRecord(init?.headers), ...headers }, + }); + }; +}