feat(core): Add support for per-user connections to MCP servers from the registry in instance AI (#31325)
Some checks are pending
Build: Benchmark Image / build (push) Waiting to run
CI: Master (Build, Test, Lint) / Build for Github Cache (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (22.22.3) (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (24.15.0) (push) Waiting to run
CI: Master (Build, Test, Lint) / Lint (push) Waiting to run
CI: Master (Build, Test, Lint) / Performance (push) Waiting to run
CI: Master (Build, Test, Lint) / Notify Slack on failure (push) Blocked by required conditions
Util: Sync API Docs / sync-public-api (push) Waiting to run

This commit is contained in:
RomanDavydchuk 2026-06-02 18:27:14 +03:00 committed by GitHub
parent 29b1220a90
commit ee3b277ff0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 920 additions and 171 deletions

View File

@ -0,0 +1,40 @@
import type { MigrationContext, ReversibleMigration } from '../migration-types';
const table = 'instance_ai_mcp_registry_connections';
export class CreateInstanceAiMcpRegistryConnectionTable1784000000023
implements ReversibleMigration
{
async up({ schemaBuilder: { createTable, column } }: MigrationContext) {
await createTable(table)
.withColumns(
column('id').uuid.primary,
column('credentialId').varchar(36).notNull,
column('serverSlug').varchar(255).notNull,
column('toolFilter').json.comment(
'Optional MCP tool filter per registry connection: { mode: "allow" | "exclude", tools: string[] }',
),
column('userId').uuid.notNull,
)
.withIndexOn(['userId', 'serverSlug', 'credentialId'], true)
.withForeignKey('credentialId', {
tableName: 'credentials_entity',
columnName: 'id',
onDelete: 'CASCADE',
})
.withForeignKey('serverSlug', {
tableName: 'mcp_registry_server',
columnName: 'slug',
onDelete: 'CASCADE',
})
.withForeignKey('userId', {
tableName: 'user',
columnName: 'id',
onDelete: 'CASCADE',
}).withTimestamps;
}
async down({ schemaBuilder: { dropTable } }: MigrationContext) {
await dropTable(table);
}
}

View File

@ -198,6 +198,7 @@ import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/17840000
import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable';
import { CreateAgentTaskDefinitionTable1784000000021 } from '../common/1784000000021-CreateAgentTaskDefinitionTable';
import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads';
import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable';
import type { Migration } from '../migration-types';
export const postgresMigrations: Migration[] = [
@ -401,4 +402,5 @@ export const postgresMigrations: Migration[] = [
CreateWorkflowPublicationOutboxTable1784000000020,
CreateAgentTaskDefinitionTable1784000000021,
AddSubAgentLinkageToAgentExecutionThreads1784000000022,
CreateInstanceAiMcpRegistryConnectionTable1784000000023,
];

View File

@ -190,6 +190,7 @@ import { CreateAgentFilesTable1784000000018 } from '../common/1784000000018-Crea
import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/1784000000019-AddCustomTelemetryTagsToProject';
import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable';
import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads';
import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable';
import type { Migration } from '../migration-types';
import { CreateAgentTaskDefinitionTable1784000000021 } from './1784000000021-CreateAgentTaskDefinitionTable';
@ -387,6 +388,7 @@ const sqliteMigrations: Migration[] = [
CreateWorkflowPublicationOutboxTable1784000000020,
CreateAgentTaskDefinitionTable1784000000021,
AddSubAgentLinkageToAgentExecutionThreads1784000000022,
CreateInstanceAiMcpRegistryConnectionTable1784000000023,
];
export { sqliteMigrations };

View File

@ -246,6 +246,46 @@ describe('McpClientManager', () => {
await manager.getRegularTools(configs);
expect(mockedMcpClient).toHaveBeenCalledTimes(1);
});
it('does not share cached clients across different scoped fetch cache keys', async () => {
const manager = new McpClientManager();
await manager.getRegularTools([
{
name: 'shared',
url: 'https://shared.example.com/',
cacheKey: 'registry-connection:1',
},
]);
await manager.getRegularTools([
{
name: 'shared',
url: 'https://shared.example.com/',
cacheKey: 'registry-connection:2',
},
]);
expect(mockedMcpClient).toHaveBeenCalledTimes(2);
});
it('reuses cached clients when scoped fetch cache key matches', async () => {
const manager = new McpClientManager();
await manager.getRegularTools([
{
name: 'shared',
url: 'https://shared.example.com/',
cacheKey: 'registry-connection:1',
},
]);
await manager.getRegularTools([
{
name: 'shared',
url: 'https://shared.example.com/',
cacheKey: 'registry-connection:1',
},
]);
expect(mockedMcpClient).toHaveBeenCalledTimes(1);
});
});
describe('concurrent dedup', () => {

View File

@ -33,7 +33,12 @@ function buildNativeMcpConfigs(configs: McpServerConfig[]): NativeMcpServerConfi
const servers: NativeMcpServerConfig[] = [];
for (const server of configs) {
if (server.url) {
servers.push({ name: server.name, url: server.url });
servers.push({
name: server.name,
url: server.url,
transport: server.transport,
fetch: server.fetch,
});
} else if (server.command) {
servers.push({
name: server.name,
@ -191,7 +196,7 @@ export class McpClientManager {
logger: Logger | undefined,
source: string,
): Promise<McpToolRegistry> {
const client = new McpClient(buildNativeMcpConfigs(configs));
const client = new McpClient(buildNativeMcpConfigs(configs), true);
this.clientsByKey.set(clientKey, client);
const registry = toolsToRegistry(await client.listTools());

View File

@ -909,9 +909,17 @@ export type CheckpointSettleResult =
export interface McpServerConfig {
name: string;
url?: string;
transport?: 'sse' | 'streamableHttp';
command?: string;
args?: string[];
env?: Record<string, string>;
fetch?: typeof fetch;
/**
* Optional cache discriminator used by `McpClientManager` when a server's
* connection behavior depends on runtime context (for example, per-user auth
* in a custom `fetch` implementation).
*/
cacheKey?: string;
}
// ── Memory ───────────────────────────────────────────────────────────────────

View File

@ -4,7 +4,7 @@ import { mock } from 'jest-mock-extended';
import type { OauthService } from '@/oauth/oauth.service';
import { buildMcpClientForServer, createAuthFetch, mapApprovalToSdk } from '../mcp-client-factory';
import { buildMcpClientForServer, mapApprovalToSdk } from '../mcp-client-factory';
// ---------------------------------------------------------------------------
// Module mocks
@ -63,69 +63,6 @@ describe('mapApprovalToSdk', () => {
});
});
// ---------------------------------------------------------------------------
// createAuthFetch
// ---------------------------------------------------------------------------
describe('createAuthFetch', () => {
beforeEach(() => {
proxyFetchMock.mockReset();
});
it('routes through proxyFetch and injects the initial headers', async () => {
proxyFetchMock.mockResolvedValueOnce(makeOk());
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(200);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
expect(init.headers).toMatchObject({ Authorization: 'Bearer A' });
});
it('returns 401 unchanged when no onUnauthorized handler is configured', async () => {
proxyFetchMock.mockResolvedValueOnce(make401());
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(401);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
});
it('returns the original 401 when onUnauthorized returns null', async () => {
proxyFetchMock.mockResolvedValueOnce(make401());
const onUnauthorized = jest.fn().mockResolvedValue(null);
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer A' },
onUnauthorized,
});
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(401);
expect(onUnauthorized).toHaveBeenCalledTimes(1);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
});
it('retries once with refreshed headers when onUnauthorized returns new headers', async () => {
proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk());
const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' });
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer A' },
onUnauthorized,
});
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(200);
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' });
});
});
// ---------------------------------------------------------------------------
// buildMcpClientForServer — header derivation per auth type
// ---------------------------------------------------------------------------
@ -314,56 +251,6 @@ describe('buildMcpClientForServer — SDK config mapping', () => {
});
});
// ---------------------------------------------------------------------------
// createAuthFetch — header merging and stateful refresh
// ---------------------------------------------------------------------------
describe('createAuthFetch — header merging', () => {
beforeEach(() => {
proxyFetchMock.mockReset();
proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 }));
});
it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => {
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } });
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
expect(init.headers).toMatchObject({
'X-Custom': 'value',
Authorization: 'Bearer A',
});
});
it('uses the refreshed headers on the second call after a successful 401 refresh', async () => {
proxyFetchMock
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
.mockResolvedValueOnce(new Response('ok', { status: 200 }))
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
let callCount = 0;
const onUnauthorized = jest.fn().mockImplementation(async () => {
callCount++;
return { Authorization: `Bearer refreshed-${callCount}` };
});
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer stale' },
onUnauthorized,
});
// First call triggers a 401 → refresh → retry
await fetchFn('https://example.test/mcp');
// Second call should use the refreshed headers without triggering another refresh
await fetchFn('https://example.test/mcp');
expect(onUnauthorized).toHaveBeenCalledTimes(1);
const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit];
expect((thirdInit.headers as Record<string, string>).Authorization).toBe('Bearer refreshed-1');
});
});
// ---------------------------------------------------------------------------
// buildMcpClientForServer — auth header edge cases
// ---------------------------------------------------------------------------

View File

@ -1,9 +1,9 @@
import { proxyFetch } from '@n8n/ai-utilities';
import type { CredentialProvider, McpClient, McpServerConfig } from '@n8n/agents';
import type { AgentJsonMcpServerConfig } from '@n8n/api-types';
import { isMcpOAuth2Authentication } from 'n8n-workflow';
import type { OauthService } from '@/oauth/oauth.service';
import { createAuthFetch } from '@/utils/auth-fetch';
/**
* Convert the JSON-config `approval` shape into the SDK's `requireApproval`
@ -91,59 +91,6 @@ async function deriveAuthHeaders(
}
}
interface CreateAuthFetchOptions {
initialHeaders: Record<string, string>;
/**
* Called on a 401 response. Should return a fresh set of auth headers, or
* `null` if the refresh failed. The returned headers replace the cached
* set used by subsequent requests.
*/
onUnauthorized?: () => Promise<Record<string, string> | null>;
}
function headersToRecord(headers: HeadersInit | undefined): Record<string, string> {
if (!headers) return {};
if (headers instanceof Headers) return Object.fromEntries(headers.entries());
if (Array.isArray(headers)) return Object.fromEntries(headers);
return headers;
}
/**
* Build a fetch wrapper that:
* 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings
* apply uniformly),
* 2. injects the latest auth headers on every request,
* 3. on a single 401, calls `onUnauthorized` to refresh the token and
* retries the request once with the new headers.
*
* This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP
* connection behaves identically to one configured via the workflow editor.
*/
export function createAuthFetch({
initialHeaders,
onUnauthorized,
}: CreateAuthFetchOptions): typeof fetch {
let headers = initialHeaders;
return async (input: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
const response = await proxyFetch(input, {
...init,
headers: { ...headersToRecord(init?.headers), ...headers },
});
if (response.status !== 401 || !onUnauthorized) return response;
const refreshed = await onUnauthorized();
if (!refreshed) return response;
headers = refreshed;
return await proxyFetch(input, {
...init,
headers: { ...headersToRecord(init?.headers), ...headers },
});
};
}
export interface BuildMcpClientDeps {
credentialProvider: CredentialProvider;
/**

View File

@ -7,6 +7,7 @@ export { InstanceAiCheckpoint } from './instance-ai-checkpoint.entity';
export { InstanceAiObservation } from './instance-ai-observation.entity';
export { InstanceAiObservationCursor } from './instance-ai-observation-cursor.entity';
export { InstanceAiObservationLock } from './instance-ai-observation-lock.entity';
export { InstanceAiMcpRegistryConnection } from './instance-ai-mcp-registry-connection.entity';
export type {
InstanceAiObservationMarker,
InstanceAiObservationStatus,

View File

@ -0,0 +1,26 @@
import { JsonColumn, WithTimestamps } from '@n8n/db';
import { Column, Entity, Index, PrimaryColumn } from '@n8n/typeorm';
export type InstanceAiMcpToolFilter = {
mode: 'allow' | 'exclude';
tools: string[];
};
@Entity({ name: 'instance_ai_mcp_registry_connections' })
@Index(['userId', 'serverSlug', 'credentialId'], { unique: true })
export class InstanceAiMcpRegistryConnection extends WithTimestamps {
@PrimaryColumn('uuid')
id: string;
@Column({ type: 'uuid' })
userId: string;
@Column({ type: 'varchar', length: 255 })
serverSlug: string;
@Column({ type: 'varchar', length: 36 })
credentialId: string;
@JsonColumn({ nullable: true })
toolFilter: InstanceAiMcpToolFilter | null;
}

View File

@ -67,6 +67,9 @@ export class InstanceAiModule implements ModuleInterface {
const { InstanceAiObservationLock } = await import(
'./entities/instance-ai-observation-lock.entity'
);
const { InstanceAiMcpRegistryConnection } = await import(
'./entities/instance-ai-mcp-registry-connection.entity'
);
return [
InstanceAiThread,
@ -79,6 +82,7 @@ export class InstanceAiModule implements ModuleInterface {
InstanceAiObservation,
InstanceAiObservationCursor,
InstanceAiObservationLock,
InstanceAiMcpRegistryConnection,
];
}

View File

@ -114,6 +114,7 @@ import { InstanceAiPendingConfirmationRepository } from './repositories/instance
import { InstanceAiThreadRepository } from './repositories/instance-ai-thread.repository';
import { TraceReplayState } from './trace-replay-state';
import { INSTANCE_AI_RUN_TIMEOUT_REASON, InstanceAiLivenessService } from './liveness';
import { InstanceAiMcpRegistryService } from './mcp';
import {
buildInstanceAiRunTraceMetadata,
type InstanceAiRunTraceMetadataOptions,
@ -639,6 +640,7 @@ export class InstanceAiService {
private readonly dbIterationLogStorage: DbIterationLogStorage,
private readonly sourceControlPreferencesService: SourceControlPreferencesService,
private readonly telemetry: Telemetry,
private readonly mcpRegistryService: InstanceAiMcpRegistryService,
private readonly userRepository: UserRepository,
private readonly aiBuilderTemporaryWorkflowRepository: AiBuilderTemporaryWorkflowRepository,
private readonly errorReporter: ErrorReporter,
@ -3484,7 +3486,9 @@ export class InstanceAiService {
return;
}
const mcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers);
const staticMcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers);
const registryMcpServers = await this.mcpRegistryService.getRegistryMcpServers(user);
const mcpServers = [...staticMcpServers, ...registryMcpServers];
const executionPushRef = this.threadPushRef.get(threadId);
const environment = await this.createExecutionEnvironment(

View File

@ -0,0 +1,304 @@
import type { Logger } from '@n8n/backend-common';
import type { CredentialsEntity, User } from '@n8n/db';
import { mock } from 'jest-mock-extended';
import type { CredentialsFinderService } from '@/credentials/credentials-finder.service';
import type { CredentialsService } from '@/credentials/credentials.service';
import type { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service';
import type { McpRegistryServer } from '@/modules/mcp-registry/registry/mcp-registry.types';
import type { OauthService } from '@/oauth/oauth.service';
import type { InstanceAiMcpRegistryConnectionRepository } from '../../repositories/instance-ai-mcp-registry-connection.repository';
import { InstanceAiMcpRegistryService } from '../instance-ai-mcp-registry.service';
import type { InstanceAiMcpRegistryConnection } from '../../entities/instance-ai-mcp-registry-connection.entity';
const proxyFetchMock = jest.fn();
jest.mock('@n8n/ai-utilities', () => ({
proxyFetch: (...args: unknown[]) => proxyFetchMock(...args),
}));
function makeRegistryServer(
slug: string,
overrides: Partial<McpRegistryServer> = {},
): McpRegistryServer {
return {
name: `com.test/${slug}`,
slug,
title: slug,
description: `${slug} description`,
tagline: `${slug} tagline`,
version: '1.0.0',
updatedAt: '2026-05-01T00:00:00.000Z',
icons: [],
authType: 'oauth2',
remotes: [{ type: 'streamable-http', url: `https://${slug}.example.com/mcp` }],
tools: [],
isOfficial: true,
origin: 'registry',
status: 'active',
...overrides,
};
}
describe('InstanceAiMcpRegistryService', () => {
const user = { id: 'user-1' } as User;
const credential = {
id: 'cred-1',
name: 'MCP OAuth2',
type: 'mcpOAuth2Api',
shared: [{ role: 'credential:owner', projectId: 'project-1' }],
} as CredentialsEntity;
const oauthCredentialData = {
clientId: 'client-id',
clientSecret: 'client-secret',
accessTokenUrl: 'https://auth.example.com/token',
oauthTokenData: {
access_token: 'stale-token',
refresh_token: 'refresh-token',
},
};
function createService() {
const logger = mock<Logger>({ scoped: jest.fn().mockReturnThis() });
const connectionRepository = mock<InstanceAiMcpRegistryConnectionRepository>();
const mcpRegistryService = mock<McpRegistryService>();
const credentialsFinderService = mock<CredentialsFinderService>();
const credentialsService = mock<CredentialsService>();
const oauthService = mock<OauthService>();
const service = new InstanceAiMcpRegistryService(
logger,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
oauthService,
);
return {
service,
logger,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
oauthService,
};
}
beforeEach(() => {
jest.clearAllMocks();
proxyFetchMock.mockReset();
});
it('returns empty list when the user has no registry connections', async () => {
const { service, connectionRepository, mcpRegistryService } = createService();
connectionRepository.findBy.mockResolvedValue([]);
const result = await service.getRegistryMcpServers(user);
expect(result).toEqual([]);
expect(mcpRegistryService.getBySlugs).not.toHaveBeenCalled();
});
it('resolves servers with deterministic names and preferred transport', async () => {
const {
service,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
} = createService();
const credentialsById: Record<string, CredentialsEntity> = {
'cred-1': { id: 'cred-1', name: 'MCP OAuth2 #1', type: 'mcpOAuth2Api' } as CredentialsEntity,
'cred-2': { id: 'cred-2', name: 'MCP OAuth2 #2', type: 'mcpOAuth2Api' } as CredentialsEntity,
'cred-3': { id: 'cred-3', name: 'MCP OAuth2 #3', type: 'mcpOAuth2Api' } as CredentialsEntity,
};
connectionRepository.findBy.mockResolvedValue([
{ id: '2', userId: user.id, serverSlug: 'linear', credentialId: 'cred-2' },
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: 'cred-1' },
{ id: '3', userId: user.id, serverSlug: 'notion', credentialId: 'cred-3' },
] as InstanceAiMcpRegistryConnection[]);
mcpRegistryService.getBySlugs.mockResolvedValue([
makeRegistryServer('linear', {
remotes: [
{ type: 'sse', url: 'https://linear.example.com/sse' },
{ type: 'streamable-http', url: 'https://linear.example.com/mcp' },
],
}),
makeRegistryServer('notion', {
remotes: [{ type: 'sse', url: 'https://notion.example.com/sse' }],
}),
]);
credentialsFinderService.findCredentialForUser.mockImplementation(async (credentialId) => {
return credentialsById[credentialId] ?? null;
});
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
const result = await service.getRegistryMcpServers(user);
expect(result).toHaveLength(3);
expect(result[0]).toEqual(
expect.objectContaining({
name: 'mcp_linear',
url: 'https://linear.example.com/mcp',
transport: 'streamableHttp',
cacheKey: 'registry-connection:1',
fetch: expect.any(Function),
}),
);
expect(result[1]).toEqual(
expect.objectContaining({
name: 'mcp_linear_2',
url: 'https://linear.example.com/mcp',
transport: 'streamableHttp',
cacheKey: 'registry-connection:2',
fetch: expect.any(Function),
}),
);
expect(result[2]).toEqual(
expect.objectContaining({
name: 'mcp_notion',
url: 'https://notion.example.com/sse',
transport: 'sse',
cacheKey: 'registry-connection:3',
fetch: expect.any(Function),
}),
);
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-1', user, [
'credential:read',
]);
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-2', user, [
'credential:read',
]);
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-3', user, [
'credential:read',
]);
});
it('skips connections with missing server slugs or unsupported remotes', async () => {
const { service, connectionRepository, mcpRegistryService, logger } = createService();
connectionRepository.findBy.mockResolvedValue([
{ id: '1', userId: user.id, serverSlug: 'missing', credentialId: credential.id },
{ id: '2', userId: user.id, serverSlug: 'bad-remote', credentialId: credential.id },
] as InstanceAiMcpRegistryConnection[]);
mcpRegistryService.getBySlugs.mockResolvedValue([
makeRegistryServer('bad-remote', { remotes: [] }),
]);
const result = await service.getRegistryMcpServers(user);
expect(result).toEqual([]);
expect(logger.warn).toHaveBeenCalledWith(
'Skipping MCP registry connection with missing server slug',
expect.objectContaining({ connectionId: '1', serverSlug: 'missing', userId: user.id }),
);
expect(logger.warn).toHaveBeenCalledWith(
'Skipping MCP registry connection without supported remote transport',
expect.objectContaining({ connectionId: '2', serverSlug: 'bad-remote' }),
);
});
it('does not attach custom fetch for non-oauth servers', async () => {
const {
service,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
} = createService();
connectionRepository.findBy.mockResolvedValue([
{ id: '1', userId: user.id, serverSlug: 'public-server', credentialId: credential.id },
] as InstanceAiMcpRegistryConnection[]);
mcpRegistryService.getBySlugs.mockResolvedValue([
makeRegistryServer('public-server', {
// currently only oauth2 is supported
// so we need to cast it to test this behavior
authType: 'none' as unknown as 'oauth2',
}),
]);
const [server] = await service.getRegistryMcpServers(user);
expect(server).toEqual(
expect.objectContaining({
name: 'mcp_public-server',
url: 'https://public-server.example.com/mcp',
transport: 'streamableHttp',
cacheKey: 'registry-connection:1',
}),
);
expect(server.fetch).toBeUndefined();
expect(credentialsFinderService.findCredentialForUser).not.toHaveBeenCalled();
expect(credentialsService.decrypt).not.toHaveBeenCalled();
});
it('adds auth header and retries once with refreshed OAuth token after 401', async () => {
const {
service,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
oauthService,
} = createService();
connectionRepository.findBy.mockResolvedValue([
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id },
] as InstanceAiMcpRegistryConnection[]);
mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]);
credentialsFinderService.findCredentialForUser.mockResolvedValue(credential);
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
proxyFetchMock
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
oauthService.refreshOAuth2CredentialById.mockResolvedValue({
Authorization: 'Bearer fresh-token',
});
const [server] = await service.getRegistryMcpServers(user);
const response = await server.fetch?.('https://linear.example.com/mcp');
expect(response?.status).toBe(200);
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
const [, firstInit] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
const [, secondInit] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
expect(new Headers(firstInit.headers).get('Authorization')).toBe('Bearer stale-token');
expect(new Headers(secondInit.headers).get('Authorization')).toBe('Bearer fresh-token');
expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith(
credential.id,
'project-1',
);
});
it('returns original 401 response when token refresh fails', async () => {
const {
service,
connectionRepository,
mcpRegistryService,
credentialsFinderService,
credentialsService,
oauthService,
} = createService();
connectionRepository.findBy.mockResolvedValue([
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id },
] as InstanceAiMcpRegistryConnection[]);
mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]);
credentialsFinderService.findCredentialForUser.mockResolvedValue(credential);
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
oauthService.refreshOAuth2CredentialById.mockResolvedValue(null);
proxyFetchMock.mockResolvedValue(new Response('unauthorized', { status: 401 }));
const [server] = await service.getRegistryMcpServers(user);
const response = await server.fetch?.('https://linear.example.com/mcp');
expect(response?.status).toBe(401);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith(
credential.id,
'project-1',
);
});
});

View File

@ -0,0 +1 @@
export * from './instance-ai-mcp-registry.service';

View File

@ -0,0 +1,266 @@
import { isObjectLiteral, Logger } from '@n8n/backend-common';
import type { CredentialsEntity, User } from '@n8n/db';
import { Service } from '@n8n/di';
import type { McpServerConfig } from '@n8n/instance-ai';
import type { ICredentialDataDecryptedObject } from 'n8n-workflow';
import { CredentialsFinderService } from '@/credentials/credentials-finder.service';
import { CredentialsService } from '@/credentials/credentials.service';
import { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service';
import type { McpRegistryRemote } from '@/modules/mcp-registry/registry/mcp-registry.types';
import { OauthService } from '@/oauth/oauth.service';
import { createAuthFetch } from '@/utils/auth-fetch';
import { InstanceAiMcpRegistryConnectionRepository } from '../repositories/instance-ai-mcp-registry-connection.repository';
type Transport = 'sse' | 'streamableHttp';
interface ResolvedRegistryServer {
serverSlug: string;
credentialId: string;
authType: string;
endpointUrl: string;
transport: Transport;
}
interface OAuth2FetchContext {
credentialId: string;
accessToken: string;
projectId: string;
}
function readString(data: Record<string, unknown>, key: string): string | undefined {
const value = data[key];
return typeof value === 'string' && value.length > 0 ? value : undefined;
}
function readAccessToken(tokenData: Record<string, unknown>): string | undefined {
return readString(tokenData, 'accessToken') ?? readString(tokenData, 'access_token');
}
function readOAuthTokenData(data: ICredentialDataDecryptedObject): Record<string, unknown> | null {
const tokenData = data.oauthTokenData;
return isObjectLiteral(tokenData) ? tokenData : null;
}
function getPreferredRemote(remotes: McpRegistryRemote[]): {
transport: Transport;
endpointUrl: string;
} | null {
const streamable = remotes.find((remote) => remote.type === 'streamable-http');
if (streamable?.url) {
return { transport: 'streamableHttp', endpointUrl: streamable.url };
}
const sse = remotes.find((remote) => remote.type === 'sse');
if (sse?.url) {
return { transport: 'sse', endpointUrl: sse.url };
}
return null;
}
const MCP_REGISTRY_SERVER_PREFIX = 'mcp_';
const MAX_MCP_SERVER_NAME_LENGTH = 24;
function buildServerName(serverSlug: string, sequence: number): string {
const safeSlug = serverSlug.replace(/[^A-Za-z0-9_-]/g, '_');
const baseName = `${MCP_REGISTRY_SERVER_PREFIX}${safeSlug}`;
if (sequence <= 1) {
return baseName.slice(0, MAX_MCP_SERVER_NAME_LENGTH);
}
const suffix = `_${sequence}`;
const maxBaseLength = Math.max(0, MAX_MCP_SERVER_NAME_LENGTH - suffix.length);
return `${baseName.slice(0, maxBaseLength)}${suffix}`;
}
@Service()
export class InstanceAiMcpRegistryService {
private readonly logger: Logger;
constructor(
logger: Logger,
private readonly connectionRepository: InstanceAiMcpRegistryConnectionRepository,
private readonly mcpRegistryService: McpRegistryService,
private readonly credentialsFinderService: CredentialsFinderService,
private readonly credentialsService: CredentialsService,
private readonly oauthService: OauthService,
) {
this.logger = logger.scoped('instance-ai');
}
async getRegistryMcpServers(user: User): Promise<McpServerConfig[]> {
const connections = await this.connectionRepository.findBy({ userId: user.id });
if (connections.length === 0) {
return [];
}
const sortedConnections = connections.sort((left, right) => left.id.localeCompare(right.id));
const slugs = [...new Set(sortedConnections.map((connection) => connection.serverSlug))];
const servers = await this.mcpRegistryService.getBySlugs(slugs);
const serverBySlug = new Map(servers.map((server) => [server.slug, server]));
const slugCounts = new Map<string, number>();
const resolved: McpServerConfig[] = [];
for (const connection of sortedConnections) {
const server = serverBySlug.get(connection.serverSlug);
if (!server) {
this.logger.warn('Skipping MCP registry connection with missing server slug', {
connectionId: connection.id,
serverSlug: connection.serverSlug,
userId: user.id,
});
continue;
}
const resolvedServer = this.resolveRegistryServer(
connection.id,
connection.serverSlug,
connection.credentialId,
server.authType,
server.remotes,
);
if (!resolvedServer) {
continue;
}
const nextCount = (slugCounts.get(resolvedServer.serverSlug) ?? 0) + 1;
slugCounts.set(resolvedServer.serverSlug, nextCount);
const serverConfig: McpServerConfig = {
name: buildServerName(resolvedServer.serverSlug, nextCount),
url: resolvedServer.endpointUrl,
transport: resolvedServer.transport,
cacheKey: `registry-connection:${connection.id}`,
};
if (resolvedServer.authType === 'oauth2') {
const oauth2FetchContext = await this.buildOAuth2FetchContext(
resolvedServer,
user,
connection.id,
);
if (!oauth2FetchContext) {
continue;
}
serverConfig.fetch = createAuthFetch({
initialHeaders: { Authorization: `Bearer ${oauth2FetchContext.accessToken}` },
onUnauthorized: async () => {
if (!oauth2FetchContext.projectId) {
return null;
}
return await this.oauthService.refreshOAuth2CredentialById(
oauth2FetchContext.credentialId,
oauth2FetchContext.projectId,
);
},
});
}
resolved.push(serverConfig);
}
return resolved;
}
private resolveRegistryServer(
connectionId: string,
serverSlug: string,
credentialId: string,
authType: string,
remotes: McpRegistryRemote[],
): ResolvedRegistryServer | null {
const remote = getPreferredRemote(remotes);
if (!remote) {
this.logger.warn('Skipping MCP registry connection without supported remote transport', {
connectionId,
serverSlug,
credentialId,
});
return null;
}
return {
serverSlug,
credentialId,
authType,
endpointUrl: remote.endpointUrl,
transport: remote.transport,
};
}
private async buildOAuth2FetchContext(
config: ResolvedRegistryServer,
user: User,
connectionId: string,
): Promise<OAuth2FetchContext | null> {
const credentialWithData = await this.getCredentialWithData(config.credentialId, user);
if (!credentialWithData) {
this.logger.warn('Skipping MCP registry connection with inaccessible credential', {
connectionId,
serverSlug: config.serverSlug,
credentialId: config.credentialId,
userId: user.id,
});
return null;
}
const tokenData = readOAuthTokenData(credentialWithData.data);
if (!tokenData) {
this.logger.warn('Skipping MCP registry connection without OAuth2 token data', {
connectionId,
serverSlug: config.serverSlug,
credentialId: config.credentialId,
});
return null;
}
const accessToken = readAccessToken(tokenData);
if (!accessToken) {
this.logger.warn('Skipping MCP registry connection without access token', {
connectionId,
serverSlug: config.serverSlug,
credentialId: config.credentialId,
});
return null;
}
const projectId = credentialWithData.credential.shared?.[0]?.projectId ?? null;
if (!projectId) {
this.logger.warn('Skipping OAuth2 token refresh for credential without project sharing', {
connectionId,
serverSlug: config.serverSlug,
credentialId: config.credentialId,
});
}
return {
credentialId: config.credentialId,
accessToken,
projectId,
};
}
private async getCredentialWithData(
credentialId: string,
user: User,
): Promise<{ credential: CredentialsEntity; data: ICredentialDataDecryptedObject } | null> {
const credential = await this.credentialsFinderService.findCredentialForUser(
credentialId,
user,
['credential:read'],
);
if (!credential) {
return null;
}
const data = await this.credentialsService.decrypt(credential, true);
if (!isObjectLiteral(data) || Object.keys(data).length === 0) {
return null;
}
return { credential, data };
}
}

View File

@ -7,3 +7,4 @@ export { InstanceAiCheckpointRepository } from './instance-ai-checkpoint.reposit
export { InstanceAiObservationRepository } from './instance-ai-observation.repository';
export { InstanceAiObservationCursorRepository } from './instance-ai-observation-cursor.repository';
export { InstanceAiObservationLockRepository } from './instance-ai-observation-lock.repository';
export { InstanceAiMcpRegistryConnectionRepository } from './instance-ai-mcp-registry-connection.repository';

View File

@ -0,0 +1,11 @@
import { Service } from '@n8n/di';
import { DataSource, Repository } from '@n8n/typeorm';
import { InstanceAiMcpRegistryConnection } from '../entities/instance-ai-mcp-registry-connection.entity';
@Service()
export class InstanceAiMcpRegistryConnectionRepository extends Repository<InstanceAiMcpRegistryConnection> {
constructor(dataSource: DataSource) {
super(InstanceAiMcpRegistryConnection, dataSource.manager);
}
}

View File

@ -131,6 +131,24 @@ describe('McpRegistryService', () => {
expect(notion).toEqual(notionMockServer);
expect(missing).toBeUndefined();
});
it('returns empty array for getBySlugs when input is empty', async () => {
const { service, repository } = createService();
const servers = await service.getBySlugs([]);
expect(servers).toEqual([]);
expect(repository.findBy).not.toHaveBeenCalled();
});
it('returns mapped servers for getBySlugs', async () => {
const { service, repository } = createService();
const servers = await service.getBySlugs(['notion', 'linear']);
expect(repository.findBy).toHaveBeenCalledWith([{ slug: 'notion' }, { slug: 'linear' }]);
expect(servers).toEqual([notionMockServer, linearMockServer]);
});
});
describe('refresh flow', () => {

View File

@ -90,6 +90,15 @@ export class McpRegistryService {
return entity ? fromEntity(entity) : undefined;
}
async getBySlugs(slugs: string[]): Promise<McpRegistryServer[]> {
if (slugs.length === 0) {
return [];
}
const entities = await this.repository.findBy(slugs.map((slug) => ({ slug })));
return entities.map(fromEntity);
}
private startPeriodicRefresh(): void {
if (this.isShuttingDown || this.refreshInterval) {
return;

View File

@ -0,0 +1,119 @@
import { createAuthFetch } from '@/utils/auth-fetch';
const proxyFetchMock = jest.fn();
jest.mock('@n8n/ai-utilities', () => ({
proxyFetch: (...args: unknown[]) => proxyFetchMock(...args),
}));
function makeOk(): Response {
return new Response('ok', { status: 200 });
}
function make401(): Response {
return new Response('unauthorized', { status: 401 });
}
describe('createAuthFetch', () => {
beforeEach(() => {
proxyFetchMock.mockReset();
});
it('routes through proxyFetch and injects the initial headers', async () => {
proxyFetchMock.mockResolvedValueOnce(makeOk());
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(200);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
expect(init.headers).toMatchObject({ Authorization: 'Bearer A' });
});
it('returns 401 unchanged when no onUnauthorized handler is configured', async () => {
proxyFetchMock.mockResolvedValueOnce(make401());
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(401);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
});
it('returns the original 401 when onUnauthorized returns null', async () => {
proxyFetchMock.mockResolvedValueOnce(make401());
const onUnauthorized = jest.fn().mockResolvedValue(null);
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer A' },
onUnauthorized,
});
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(401);
expect(onUnauthorized).toHaveBeenCalledTimes(1);
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
});
it('retries once with refreshed headers when onUnauthorized returns new headers', async () => {
proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk());
const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' });
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer A' },
onUnauthorized,
});
const res = await fetchFn('https://example.test/mcp');
expect(res.status).toBe(200);
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' });
});
});
describe('createAuthFetch — header merging', () => {
beforeEach(() => {
proxyFetchMock.mockReset();
proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 }));
});
it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => {
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } });
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
expect(init.headers).toMatchObject({
'X-Custom': 'value',
Authorization: 'Bearer A',
});
});
it('uses the refreshed headers on the second call after a successful 401 refresh', async () => {
proxyFetchMock
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
.mockResolvedValueOnce(new Response('ok', { status: 200 }))
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
let callCount = 0;
const onUnauthorized = jest.fn().mockImplementation(async () => {
callCount++;
return { Authorization: `Bearer refreshed-${callCount}` };
});
const fetchFn = createAuthFetch({
initialHeaders: { Authorization: 'Bearer stale' },
onUnauthorized,
});
// First call triggers a 401 → refresh → retry
await fetchFn('https://example.test/mcp');
// Second call should use the refreshed headers without triggering another refresh
await fetchFn('https://example.test/mcp');
expect(onUnauthorized).toHaveBeenCalledTimes(1);
const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit];
expect((thirdInit.headers as Record<string, string>).Authorization).toBe('Bearer refreshed-1');
});
});

View File

@ -0,0 +1,54 @@
import { proxyFetch } from '@n8n/ai-utilities';
interface CreateAuthFetchOptions {
initialHeaders: Record<string, string>;
/**
* Called on a 401 response. Should return a fresh set of auth headers, or
* `null` if the refresh failed. The returned headers replace the cached
* set used by subsequent requests.
*/
onUnauthorized?: () => Promise<Record<string, string> | null>;
}
function headersToRecord(headers: HeadersInit | undefined): Record<string, string> {
if (!headers) return {};
if (headers instanceof Headers) return Object.fromEntries(headers.entries());
if (Array.isArray(headers)) return Object.fromEntries(headers);
return headers;
}
/**
* Build a fetch wrapper that:
* 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings
* apply uniformly),
* 2. injects the latest auth headers on every request,
* 3. on a single 401, calls `onUnauthorized` to refresh the token and
* retries the request once with the new headers.
*
* This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP
* connection behaves identically to one configured via the workflow editor.
*/
export function createAuthFetch({
initialHeaders,
onUnauthorized,
}: CreateAuthFetchOptions): typeof fetch {
let headers = initialHeaders;
return async (input: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
const response = await proxyFetch(input, {
...init,
headers: { ...headersToRecord(init?.headers), ...headers },
});
if (response.status !== 401 || !onUnauthorized) return response;
const refreshed = await onUnauthorized();
if (!refreshed) return response;
headers = refreshed;
return await proxyFetch(input, {
...init,
headers: { ...headersToRecord(init?.headers), ...headers },
});
};
}