mirror of
https://github.com/n8n-io/n8n.git
synced 2026-06-03 18:27:09 +02:00
feat(core): Add support for per-user connections to MCP servers from the registry in instance AI (#31325)
Some checks are pending
Build: Benchmark Image / build (push) Waiting to run
CI: Master (Build, Test, Lint) / Build for Github Cache (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (22.22.3) (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (24.15.0) (push) Waiting to run
CI: Master (Build, Test, Lint) / Lint (push) Waiting to run
CI: Master (Build, Test, Lint) / Performance (push) Waiting to run
CI: Master (Build, Test, Lint) / Notify Slack on failure (push) Blocked by required conditions
Util: Sync API Docs / sync-public-api (push) Waiting to run
Some checks are pending
Build: Benchmark Image / build (push) Waiting to run
CI: Master (Build, Test, Lint) / Build for Github Cache (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (22.22.3) (push) Waiting to run
CI: Master (Build, Test, Lint) / Unit tests (24.15.0) (push) Waiting to run
CI: Master (Build, Test, Lint) / Lint (push) Waiting to run
CI: Master (Build, Test, Lint) / Performance (push) Waiting to run
CI: Master (Build, Test, Lint) / Notify Slack on failure (push) Blocked by required conditions
Util: Sync API Docs / sync-public-api (push) Waiting to run
This commit is contained in:
parent
29b1220a90
commit
ee3b277ff0
|
|
@ -0,0 +1,40 @@
|
|||
import type { MigrationContext, ReversibleMigration } from '../migration-types';
|
||||
|
||||
const table = 'instance_ai_mcp_registry_connections';
|
||||
|
||||
export class CreateInstanceAiMcpRegistryConnectionTable1784000000023
|
||||
implements ReversibleMigration
|
||||
{
|
||||
async up({ schemaBuilder: { createTable, column } }: MigrationContext) {
|
||||
await createTable(table)
|
||||
.withColumns(
|
||||
column('id').uuid.primary,
|
||||
column('credentialId').varchar(36).notNull,
|
||||
column('serverSlug').varchar(255).notNull,
|
||||
column('toolFilter').json.comment(
|
||||
'Optional MCP tool filter per registry connection: { mode: "allow" | "exclude", tools: string[] }',
|
||||
),
|
||||
column('userId').uuid.notNull,
|
||||
)
|
||||
.withIndexOn(['userId', 'serverSlug', 'credentialId'], true)
|
||||
.withForeignKey('credentialId', {
|
||||
tableName: 'credentials_entity',
|
||||
columnName: 'id',
|
||||
onDelete: 'CASCADE',
|
||||
})
|
||||
.withForeignKey('serverSlug', {
|
||||
tableName: 'mcp_registry_server',
|
||||
columnName: 'slug',
|
||||
onDelete: 'CASCADE',
|
||||
})
|
||||
.withForeignKey('userId', {
|
||||
tableName: 'user',
|
||||
columnName: 'id',
|
||||
onDelete: 'CASCADE',
|
||||
}).withTimestamps;
|
||||
}
|
||||
|
||||
async down({ schemaBuilder: { dropTable } }: MigrationContext) {
|
||||
await dropTable(table);
|
||||
}
|
||||
}
|
||||
|
|
@ -198,6 +198,7 @@ import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/17840000
|
|||
import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable';
|
||||
import { CreateAgentTaskDefinitionTable1784000000021 } from '../common/1784000000021-CreateAgentTaskDefinitionTable';
|
||||
import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads';
|
||||
import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable';
|
||||
import type { Migration } from '../migration-types';
|
||||
|
||||
export const postgresMigrations: Migration[] = [
|
||||
|
|
@ -401,4 +402,5 @@ export const postgresMigrations: Migration[] = [
|
|||
CreateWorkflowPublicationOutboxTable1784000000020,
|
||||
CreateAgentTaskDefinitionTable1784000000021,
|
||||
AddSubAgentLinkageToAgentExecutionThreads1784000000022,
|
||||
CreateInstanceAiMcpRegistryConnectionTable1784000000023,
|
||||
];
|
||||
|
|
|
|||
|
|
@ -190,6 +190,7 @@ import { CreateAgentFilesTable1784000000018 } from '../common/1784000000018-Crea
|
|||
import { AddCustomTelemetryTagsToProject1784000000019 } from '../common/1784000000019-AddCustomTelemetryTagsToProject';
|
||||
import { CreateWorkflowPublicationOutboxTable1784000000020 } from '../common/1784000000020-CreateWorkflowPublicationOutboxTable';
|
||||
import { AddSubAgentLinkageToAgentExecutionThreads1784000000022 } from '../common/1784000000022-AddSubAgentLinkageToAgentExecutionThreads';
|
||||
import { CreateInstanceAiMcpRegistryConnectionTable1784000000023 } from '../common/1784000000023-CreateInstanceAiMcpRegistryConnectionTable';
|
||||
import type { Migration } from '../migration-types';
|
||||
import { CreateAgentTaskDefinitionTable1784000000021 } from './1784000000021-CreateAgentTaskDefinitionTable';
|
||||
|
||||
|
|
@ -387,6 +388,7 @@ const sqliteMigrations: Migration[] = [
|
|||
CreateWorkflowPublicationOutboxTable1784000000020,
|
||||
CreateAgentTaskDefinitionTable1784000000021,
|
||||
AddSubAgentLinkageToAgentExecutionThreads1784000000022,
|
||||
CreateInstanceAiMcpRegistryConnectionTable1784000000023,
|
||||
];
|
||||
|
||||
export { sqliteMigrations };
|
||||
|
|
|
|||
|
|
@ -246,6 +246,46 @@ describe('McpClientManager', () => {
|
|||
await manager.getRegularTools(configs);
|
||||
expect(mockedMcpClient).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('does not share cached clients across different scoped fetch cache keys', async () => {
|
||||
const manager = new McpClientManager();
|
||||
await manager.getRegularTools([
|
||||
{
|
||||
name: 'shared',
|
||||
url: 'https://shared.example.com/',
|
||||
cacheKey: 'registry-connection:1',
|
||||
},
|
||||
]);
|
||||
await manager.getRegularTools([
|
||||
{
|
||||
name: 'shared',
|
||||
url: 'https://shared.example.com/',
|
||||
cacheKey: 'registry-connection:2',
|
||||
},
|
||||
]);
|
||||
|
||||
expect(mockedMcpClient).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('reuses cached clients when scoped fetch cache key matches', async () => {
|
||||
const manager = new McpClientManager();
|
||||
await manager.getRegularTools([
|
||||
{
|
||||
name: 'shared',
|
||||
url: 'https://shared.example.com/',
|
||||
cacheKey: 'registry-connection:1',
|
||||
},
|
||||
]);
|
||||
await manager.getRegularTools([
|
||||
{
|
||||
name: 'shared',
|
||||
url: 'https://shared.example.com/',
|
||||
cacheKey: 'registry-connection:1',
|
||||
},
|
||||
]);
|
||||
|
||||
expect(mockedMcpClient).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('concurrent dedup', () => {
|
||||
|
|
|
|||
|
|
@ -33,7 +33,12 @@ function buildNativeMcpConfigs(configs: McpServerConfig[]): NativeMcpServerConfi
|
|||
const servers: NativeMcpServerConfig[] = [];
|
||||
for (const server of configs) {
|
||||
if (server.url) {
|
||||
servers.push({ name: server.name, url: server.url });
|
||||
servers.push({
|
||||
name: server.name,
|
||||
url: server.url,
|
||||
transport: server.transport,
|
||||
fetch: server.fetch,
|
||||
});
|
||||
} else if (server.command) {
|
||||
servers.push({
|
||||
name: server.name,
|
||||
|
|
@ -191,7 +196,7 @@ export class McpClientManager {
|
|||
logger: Logger | undefined,
|
||||
source: string,
|
||||
): Promise<McpToolRegistry> {
|
||||
const client = new McpClient(buildNativeMcpConfigs(configs));
|
||||
const client = new McpClient(buildNativeMcpConfigs(configs), true);
|
||||
this.clientsByKey.set(clientKey, client);
|
||||
|
||||
const registry = toolsToRegistry(await client.listTools());
|
||||
|
|
|
|||
|
|
@ -909,9 +909,17 @@ export type CheckpointSettleResult =
|
|||
export interface McpServerConfig {
|
||||
name: string;
|
||||
url?: string;
|
||||
transport?: 'sse' | 'streamableHttp';
|
||||
command?: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
fetch?: typeof fetch;
|
||||
/**
|
||||
* Optional cache discriminator used by `McpClientManager` when a server's
|
||||
* connection behavior depends on runtime context (for example, per-user auth
|
||||
* in a custom `fetch` implementation).
|
||||
*/
|
||||
cacheKey?: string;
|
||||
}
|
||||
|
||||
// ── Memory ───────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import { mock } from 'jest-mock-extended';
|
|||
|
||||
import type { OauthService } from '@/oauth/oauth.service';
|
||||
|
||||
import { buildMcpClientForServer, createAuthFetch, mapApprovalToSdk } from '../mcp-client-factory';
|
||||
import { buildMcpClientForServer, mapApprovalToSdk } from '../mcp-client-factory';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Module mocks
|
||||
|
|
@ -63,69 +63,6 @@ describe('mapApprovalToSdk', () => {
|
|||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// createAuthFetch
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('createAuthFetch', () => {
|
||||
beforeEach(() => {
|
||||
proxyFetchMock.mockReset();
|
||||
});
|
||||
|
||||
it('routes through proxyFetch and injects the initial headers', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(makeOk());
|
||||
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
|
||||
expect(init.headers).toMatchObject({ Authorization: 'Bearer A' });
|
||||
});
|
||||
|
||||
it('returns 401 unchanged when no onUnauthorized handler is configured', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401());
|
||||
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('returns the original 401 when onUnauthorized returns null', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401());
|
||||
|
||||
const onUnauthorized = jest.fn().mockResolvedValue(null);
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer A' },
|
||||
onUnauthorized,
|
||||
});
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
expect(onUnauthorized).toHaveBeenCalledTimes(1);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('retries once with refreshed headers when onUnauthorized returns new headers', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk());
|
||||
|
||||
const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' });
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer A' },
|
||||
onUnauthorized,
|
||||
});
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
|
||||
const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
|
||||
expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' });
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildMcpClientForServer — header derivation per auth type
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
@ -314,56 +251,6 @@ describe('buildMcpClientForServer — SDK config mapping', () => {
|
|||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// createAuthFetch — header merging and stateful refresh
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('createAuthFetch — header merging', () => {
|
||||
beforeEach(() => {
|
||||
proxyFetchMock.mockReset();
|
||||
proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 }));
|
||||
});
|
||||
|
||||
it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => {
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } });
|
||||
|
||||
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
|
||||
expect(init.headers).toMatchObject({
|
||||
'X-Custom': 'value',
|
||||
Authorization: 'Bearer A',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses the refreshed headers on the second call after a successful 401 refresh', async () => {
|
||||
proxyFetchMock
|
||||
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
|
||||
.mockResolvedValueOnce(new Response('ok', { status: 200 }))
|
||||
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
|
||||
|
||||
let callCount = 0;
|
||||
const onUnauthorized = jest.fn().mockImplementation(async () => {
|
||||
callCount++;
|
||||
return { Authorization: `Bearer refreshed-${callCount}` };
|
||||
});
|
||||
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer stale' },
|
||||
onUnauthorized,
|
||||
});
|
||||
|
||||
// First call triggers a 401 → refresh → retry
|
||||
await fetchFn('https://example.test/mcp');
|
||||
|
||||
// Second call should use the refreshed headers without triggering another refresh
|
||||
await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(onUnauthorized).toHaveBeenCalledTimes(1);
|
||||
const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit];
|
||||
expect((thirdInit.headers as Record<string, string>).Authorization).toBe('Bearer refreshed-1');
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildMcpClientForServer — auth header edge cases
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import { proxyFetch } from '@n8n/ai-utilities';
|
||||
import type { CredentialProvider, McpClient, McpServerConfig } from '@n8n/agents';
|
||||
import type { AgentJsonMcpServerConfig } from '@n8n/api-types';
|
||||
import { isMcpOAuth2Authentication } from 'n8n-workflow';
|
||||
|
||||
import type { OauthService } from '@/oauth/oauth.service';
|
||||
import { createAuthFetch } from '@/utils/auth-fetch';
|
||||
|
||||
/**
|
||||
* Convert the JSON-config `approval` shape into the SDK's `requireApproval`
|
||||
|
|
@ -91,59 +91,6 @@ async function deriveAuthHeaders(
|
|||
}
|
||||
}
|
||||
|
||||
interface CreateAuthFetchOptions {
|
||||
initialHeaders: Record<string, string>;
|
||||
/**
|
||||
* Called on a 401 response. Should return a fresh set of auth headers, or
|
||||
* `null` if the refresh failed. The returned headers replace the cached
|
||||
* set used by subsequent requests.
|
||||
*/
|
||||
onUnauthorized?: () => Promise<Record<string, string> | null>;
|
||||
}
|
||||
|
||||
function headersToRecord(headers: HeadersInit | undefined): Record<string, string> {
|
||||
if (!headers) return {};
|
||||
if (headers instanceof Headers) return Object.fromEntries(headers.entries());
|
||||
if (Array.isArray(headers)) return Object.fromEntries(headers);
|
||||
return headers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a fetch wrapper that:
|
||||
* 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings
|
||||
* apply uniformly),
|
||||
* 2. injects the latest auth headers on every request,
|
||||
* 3. on a single 401, calls `onUnauthorized` to refresh the token and
|
||||
* retries the request once with the new headers.
|
||||
*
|
||||
* This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP
|
||||
* connection behaves identically to one configured via the workflow editor.
|
||||
*/
|
||||
export function createAuthFetch({
|
||||
initialHeaders,
|
||||
onUnauthorized,
|
||||
}: CreateAuthFetchOptions): typeof fetch {
|
||||
let headers = initialHeaders;
|
||||
|
||||
return async (input: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
|
||||
const response = await proxyFetch(input, {
|
||||
...init,
|
||||
headers: { ...headersToRecord(init?.headers), ...headers },
|
||||
});
|
||||
|
||||
if (response.status !== 401 || !onUnauthorized) return response;
|
||||
|
||||
const refreshed = await onUnauthorized();
|
||||
if (!refreshed) return response;
|
||||
|
||||
headers = refreshed;
|
||||
return await proxyFetch(input, {
|
||||
...init,
|
||||
headers: { ...headersToRecord(init?.headers), ...headers },
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export interface BuildMcpClientDeps {
|
||||
credentialProvider: CredentialProvider;
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ export { InstanceAiCheckpoint } from './instance-ai-checkpoint.entity';
|
|||
export { InstanceAiObservation } from './instance-ai-observation.entity';
|
||||
export { InstanceAiObservationCursor } from './instance-ai-observation-cursor.entity';
|
||||
export { InstanceAiObservationLock } from './instance-ai-observation-lock.entity';
|
||||
export { InstanceAiMcpRegistryConnection } from './instance-ai-mcp-registry-connection.entity';
|
||||
export type {
|
||||
InstanceAiObservationMarker,
|
||||
InstanceAiObservationStatus,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
import { JsonColumn, WithTimestamps } from '@n8n/db';
|
||||
import { Column, Entity, Index, PrimaryColumn } from '@n8n/typeorm';
|
||||
|
||||
export type InstanceAiMcpToolFilter = {
|
||||
mode: 'allow' | 'exclude';
|
||||
tools: string[];
|
||||
};
|
||||
|
||||
@Entity({ name: 'instance_ai_mcp_registry_connections' })
|
||||
@Index(['userId', 'serverSlug', 'credentialId'], { unique: true })
|
||||
export class InstanceAiMcpRegistryConnection extends WithTimestamps {
|
||||
@PrimaryColumn('uuid')
|
||||
id: string;
|
||||
|
||||
@Column({ type: 'uuid' })
|
||||
userId: string;
|
||||
|
||||
@Column({ type: 'varchar', length: 255 })
|
||||
serverSlug: string;
|
||||
|
||||
@Column({ type: 'varchar', length: 36 })
|
||||
credentialId: string;
|
||||
|
||||
@JsonColumn({ nullable: true })
|
||||
toolFilter: InstanceAiMcpToolFilter | null;
|
||||
}
|
||||
|
|
@ -67,6 +67,9 @@ export class InstanceAiModule implements ModuleInterface {
|
|||
const { InstanceAiObservationLock } = await import(
|
||||
'./entities/instance-ai-observation-lock.entity'
|
||||
);
|
||||
const { InstanceAiMcpRegistryConnection } = await import(
|
||||
'./entities/instance-ai-mcp-registry-connection.entity'
|
||||
);
|
||||
|
||||
return [
|
||||
InstanceAiThread,
|
||||
|
|
@ -79,6 +82,7 @@ export class InstanceAiModule implements ModuleInterface {
|
|||
InstanceAiObservation,
|
||||
InstanceAiObservationCursor,
|
||||
InstanceAiObservationLock,
|
||||
InstanceAiMcpRegistryConnection,
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ import { InstanceAiPendingConfirmationRepository } from './repositories/instance
|
|||
import { InstanceAiThreadRepository } from './repositories/instance-ai-thread.repository';
|
||||
import { TraceReplayState } from './trace-replay-state';
|
||||
import { INSTANCE_AI_RUN_TIMEOUT_REASON, InstanceAiLivenessService } from './liveness';
|
||||
import { InstanceAiMcpRegistryService } from './mcp';
|
||||
import {
|
||||
buildInstanceAiRunTraceMetadata,
|
||||
type InstanceAiRunTraceMetadataOptions,
|
||||
|
|
@ -639,6 +640,7 @@ export class InstanceAiService {
|
|||
private readonly dbIterationLogStorage: DbIterationLogStorage,
|
||||
private readonly sourceControlPreferencesService: SourceControlPreferencesService,
|
||||
private readonly telemetry: Telemetry,
|
||||
private readonly mcpRegistryService: InstanceAiMcpRegistryService,
|
||||
private readonly userRepository: UserRepository,
|
||||
private readonly aiBuilderTemporaryWorkflowRepository: AiBuilderTemporaryWorkflowRepository,
|
||||
private readonly errorReporter: ErrorReporter,
|
||||
|
|
@ -3484,7 +3486,9 @@ export class InstanceAiService {
|
|||
return;
|
||||
}
|
||||
|
||||
const mcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers);
|
||||
const staticMcpServers = this.parseMcpServers(this.instanceAiConfig.mcpServers);
|
||||
const registryMcpServers = await this.mcpRegistryService.getRegistryMcpServers(user);
|
||||
const mcpServers = [...staticMcpServers, ...registryMcpServers];
|
||||
|
||||
const executionPushRef = this.threadPushRef.get(threadId);
|
||||
const environment = await this.createExecutionEnvironment(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,304 @@
|
|||
import type { Logger } from '@n8n/backend-common';
|
||||
import type { CredentialsEntity, User } from '@n8n/db';
|
||||
import { mock } from 'jest-mock-extended';
|
||||
|
||||
import type { CredentialsFinderService } from '@/credentials/credentials-finder.service';
|
||||
import type { CredentialsService } from '@/credentials/credentials.service';
|
||||
import type { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service';
|
||||
import type { McpRegistryServer } from '@/modules/mcp-registry/registry/mcp-registry.types';
|
||||
import type { OauthService } from '@/oauth/oauth.service';
|
||||
|
||||
import type { InstanceAiMcpRegistryConnectionRepository } from '../../repositories/instance-ai-mcp-registry-connection.repository';
|
||||
import { InstanceAiMcpRegistryService } from '../instance-ai-mcp-registry.service';
|
||||
import type { InstanceAiMcpRegistryConnection } from '../../entities/instance-ai-mcp-registry-connection.entity';
|
||||
|
||||
const proxyFetchMock = jest.fn();
|
||||
|
||||
jest.mock('@n8n/ai-utilities', () => ({
|
||||
proxyFetch: (...args: unknown[]) => proxyFetchMock(...args),
|
||||
}));
|
||||
|
||||
function makeRegistryServer(
|
||||
slug: string,
|
||||
overrides: Partial<McpRegistryServer> = {},
|
||||
): McpRegistryServer {
|
||||
return {
|
||||
name: `com.test/${slug}`,
|
||||
slug,
|
||||
title: slug,
|
||||
description: `${slug} description`,
|
||||
tagline: `${slug} tagline`,
|
||||
version: '1.0.0',
|
||||
updatedAt: '2026-05-01T00:00:00.000Z',
|
||||
icons: [],
|
||||
authType: 'oauth2',
|
||||
remotes: [{ type: 'streamable-http', url: `https://${slug}.example.com/mcp` }],
|
||||
tools: [],
|
||||
isOfficial: true,
|
||||
origin: 'registry',
|
||||
status: 'active',
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe('InstanceAiMcpRegistryService', () => {
|
||||
const user = { id: 'user-1' } as User;
|
||||
const credential = {
|
||||
id: 'cred-1',
|
||||
name: 'MCP OAuth2',
|
||||
type: 'mcpOAuth2Api',
|
||||
shared: [{ role: 'credential:owner', projectId: 'project-1' }],
|
||||
} as CredentialsEntity;
|
||||
|
||||
const oauthCredentialData = {
|
||||
clientId: 'client-id',
|
||||
clientSecret: 'client-secret',
|
||||
accessTokenUrl: 'https://auth.example.com/token',
|
||||
oauthTokenData: {
|
||||
access_token: 'stale-token',
|
||||
refresh_token: 'refresh-token',
|
||||
},
|
||||
};
|
||||
|
||||
function createService() {
|
||||
const logger = mock<Logger>({ scoped: jest.fn().mockReturnThis() });
|
||||
const connectionRepository = mock<InstanceAiMcpRegistryConnectionRepository>();
|
||||
const mcpRegistryService = mock<McpRegistryService>();
|
||||
const credentialsFinderService = mock<CredentialsFinderService>();
|
||||
const credentialsService = mock<CredentialsService>();
|
||||
const oauthService = mock<OauthService>();
|
||||
|
||||
const service = new InstanceAiMcpRegistryService(
|
||||
logger,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
oauthService,
|
||||
);
|
||||
|
||||
return {
|
||||
service,
|
||||
logger,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
oauthService,
|
||||
};
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
proxyFetchMock.mockReset();
|
||||
});
|
||||
|
||||
it('returns empty list when the user has no registry connections', async () => {
|
||||
const { service, connectionRepository, mcpRegistryService } = createService();
|
||||
connectionRepository.findBy.mockResolvedValue([]);
|
||||
|
||||
const result = await service.getRegistryMcpServers(user);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(mcpRegistryService.getBySlugs).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('resolves servers with deterministic names and preferred transport', async () => {
|
||||
const {
|
||||
service,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
} = createService();
|
||||
const credentialsById: Record<string, CredentialsEntity> = {
|
||||
'cred-1': { id: 'cred-1', name: 'MCP OAuth2 #1', type: 'mcpOAuth2Api' } as CredentialsEntity,
|
||||
'cred-2': { id: 'cred-2', name: 'MCP OAuth2 #2', type: 'mcpOAuth2Api' } as CredentialsEntity,
|
||||
'cred-3': { id: 'cred-3', name: 'MCP OAuth2 #3', type: 'mcpOAuth2Api' } as CredentialsEntity,
|
||||
};
|
||||
connectionRepository.findBy.mockResolvedValue([
|
||||
{ id: '2', userId: user.id, serverSlug: 'linear', credentialId: 'cred-2' },
|
||||
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: 'cred-1' },
|
||||
{ id: '3', userId: user.id, serverSlug: 'notion', credentialId: 'cred-3' },
|
||||
] as InstanceAiMcpRegistryConnection[]);
|
||||
mcpRegistryService.getBySlugs.mockResolvedValue([
|
||||
makeRegistryServer('linear', {
|
||||
remotes: [
|
||||
{ type: 'sse', url: 'https://linear.example.com/sse' },
|
||||
{ type: 'streamable-http', url: 'https://linear.example.com/mcp' },
|
||||
],
|
||||
}),
|
||||
makeRegistryServer('notion', {
|
||||
remotes: [{ type: 'sse', url: 'https://notion.example.com/sse' }],
|
||||
}),
|
||||
]);
|
||||
credentialsFinderService.findCredentialForUser.mockImplementation(async (credentialId) => {
|
||||
return credentialsById[credentialId] ?? null;
|
||||
});
|
||||
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
|
||||
|
||||
const result = await service.getRegistryMcpServers(user);
|
||||
|
||||
expect(result).toHaveLength(3);
|
||||
expect(result[0]).toEqual(
|
||||
expect.objectContaining({
|
||||
name: 'mcp_linear',
|
||||
url: 'https://linear.example.com/mcp',
|
||||
transport: 'streamableHttp',
|
||||
cacheKey: 'registry-connection:1',
|
||||
fetch: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(result[1]).toEqual(
|
||||
expect.objectContaining({
|
||||
name: 'mcp_linear_2',
|
||||
url: 'https://linear.example.com/mcp',
|
||||
transport: 'streamableHttp',
|
||||
cacheKey: 'registry-connection:2',
|
||||
fetch: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(result[2]).toEqual(
|
||||
expect.objectContaining({
|
||||
name: 'mcp_notion',
|
||||
url: 'https://notion.example.com/sse',
|
||||
transport: 'sse',
|
||||
cacheKey: 'registry-connection:3',
|
||||
fetch: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-1', user, [
|
||||
'credential:read',
|
||||
]);
|
||||
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-2', user, [
|
||||
'credential:read',
|
||||
]);
|
||||
expect(credentialsFinderService.findCredentialForUser).toHaveBeenCalledWith('cred-3', user, [
|
||||
'credential:read',
|
||||
]);
|
||||
});
|
||||
|
||||
it('skips connections with missing server slugs or unsupported remotes', async () => {
|
||||
const { service, connectionRepository, mcpRegistryService, logger } = createService();
|
||||
connectionRepository.findBy.mockResolvedValue([
|
||||
{ id: '1', userId: user.id, serverSlug: 'missing', credentialId: credential.id },
|
||||
{ id: '2', userId: user.id, serverSlug: 'bad-remote', credentialId: credential.id },
|
||||
] as InstanceAiMcpRegistryConnection[]);
|
||||
mcpRegistryService.getBySlugs.mockResolvedValue([
|
||||
makeRegistryServer('bad-remote', { remotes: [] }),
|
||||
]);
|
||||
|
||||
const result = await service.getRegistryMcpServers(user);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
'Skipping MCP registry connection with missing server slug',
|
||||
expect.objectContaining({ connectionId: '1', serverSlug: 'missing', userId: user.id }),
|
||||
);
|
||||
expect(logger.warn).toHaveBeenCalledWith(
|
||||
'Skipping MCP registry connection without supported remote transport',
|
||||
expect.objectContaining({ connectionId: '2', serverSlug: 'bad-remote' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('does not attach custom fetch for non-oauth servers', async () => {
|
||||
const {
|
||||
service,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
} = createService();
|
||||
connectionRepository.findBy.mockResolvedValue([
|
||||
{ id: '1', userId: user.id, serverSlug: 'public-server', credentialId: credential.id },
|
||||
] as InstanceAiMcpRegistryConnection[]);
|
||||
mcpRegistryService.getBySlugs.mockResolvedValue([
|
||||
makeRegistryServer('public-server', {
|
||||
// currently only oauth2 is supported
|
||||
// so we need to cast it to test this behavior
|
||||
authType: 'none' as unknown as 'oauth2',
|
||||
}),
|
||||
]);
|
||||
|
||||
const [server] = await service.getRegistryMcpServers(user);
|
||||
|
||||
expect(server).toEqual(
|
||||
expect.objectContaining({
|
||||
name: 'mcp_public-server',
|
||||
url: 'https://public-server.example.com/mcp',
|
||||
transport: 'streamableHttp',
|
||||
cacheKey: 'registry-connection:1',
|
||||
}),
|
||||
);
|
||||
expect(server.fetch).toBeUndefined();
|
||||
expect(credentialsFinderService.findCredentialForUser).not.toHaveBeenCalled();
|
||||
expect(credentialsService.decrypt).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('adds auth header and retries once with refreshed OAuth token after 401', async () => {
|
||||
const {
|
||||
service,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
oauthService,
|
||||
} = createService();
|
||||
connectionRepository.findBy.mockResolvedValue([
|
||||
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id },
|
||||
] as InstanceAiMcpRegistryConnection[]);
|
||||
mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]);
|
||||
credentialsFinderService.findCredentialForUser.mockResolvedValue(credential);
|
||||
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
|
||||
proxyFetchMock
|
||||
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
|
||||
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
|
||||
oauthService.refreshOAuth2CredentialById.mockResolvedValue({
|
||||
Authorization: 'Bearer fresh-token',
|
||||
});
|
||||
|
||||
const [server] = await service.getRegistryMcpServers(user);
|
||||
const response = await server.fetch?.('https://linear.example.com/mcp');
|
||||
|
||||
expect(response?.status).toBe(200);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
|
||||
const [, firstInit] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
|
||||
const [, secondInit] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
|
||||
expect(new Headers(firstInit.headers).get('Authorization')).toBe('Bearer stale-token');
|
||||
expect(new Headers(secondInit.headers).get('Authorization')).toBe('Bearer fresh-token');
|
||||
expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith(
|
||||
credential.id,
|
||||
'project-1',
|
||||
);
|
||||
});
|
||||
|
||||
it('returns original 401 response when token refresh fails', async () => {
|
||||
const {
|
||||
service,
|
||||
connectionRepository,
|
||||
mcpRegistryService,
|
||||
credentialsFinderService,
|
||||
credentialsService,
|
||||
oauthService,
|
||||
} = createService();
|
||||
connectionRepository.findBy.mockResolvedValue([
|
||||
{ id: '1', userId: user.id, serverSlug: 'linear', credentialId: credential.id },
|
||||
] as InstanceAiMcpRegistryConnection[]);
|
||||
mcpRegistryService.getBySlugs.mockResolvedValue([makeRegistryServer('linear')]);
|
||||
credentialsFinderService.findCredentialForUser.mockResolvedValue(credential);
|
||||
credentialsService.decrypt.mockResolvedValue(oauthCredentialData);
|
||||
oauthService.refreshOAuth2CredentialById.mockResolvedValue(null);
|
||||
|
||||
proxyFetchMock.mockResolvedValue(new Response('unauthorized', { status: 401 }));
|
||||
|
||||
const [server] = await service.getRegistryMcpServers(user);
|
||||
const response = await server.fetch?.('https://linear.example.com/mcp');
|
||||
|
||||
expect(response?.status).toBe(401);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
expect(oauthService.refreshOAuth2CredentialById).toHaveBeenCalledWith(
|
||||
credential.id,
|
||||
'project-1',
|
||||
);
|
||||
});
|
||||
});
|
||||
1
packages/cli/src/modules/instance-ai/mcp/index.ts
Normal file
1
packages/cli/src/modules/instance-ai/mcp/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from './instance-ai-mcp-registry.service';
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
import { isObjectLiteral, Logger } from '@n8n/backend-common';
|
||||
import type { CredentialsEntity, User } from '@n8n/db';
|
||||
import { Service } from '@n8n/di';
|
||||
import type { McpServerConfig } from '@n8n/instance-ai';
|
||||
import type { ICredentialDataDecryptedObject } from 'n8n-workflow';
|
||||
|
||||
import { CredentialsFinderService } from '@/credentials/credentials-finder.service';
|
||||
import { CredentialsService } from '@/credentials/credentials.service';
|
||||
import { McpRegistryService } from '@/modules/mcp-registry/registry/mcp-registry.service';
|
||||
import type { McpRegistryRemote } from '@/modules/mcp-registry/registry/mcp-registry.types';
|
||||
import { OauthService } from '@/oauth/oauth.service';
|
||||
import { createAuthFetch } from '@/utils/auth-fetch';
|
||||
|
||||
import { InstanceAiMcpRegistryConnectionRepository } from '../repositories/instance-ai-mcp-registry-connection.repository';
|
||||
|
||||
type Transport = 'sse' | 'streamableHttp';
|
||||
|
||||
interface ResolvedRegistryServer {
|
||||
serverSlug: string;
|
||||
credentialId: string;
|
||||
authType: string;
|
||||
endpointUrl: string;
|
||||
transport: Transport;
|
||||
}
|
||||
|
||||
interface OAuth2FetchContext {
|
||||
credentialId: string;
|
||||
accessToken: string;
|
||||
projectId: string;
|
||||
}
|
||||
|
||||
function readString(data: Record<string, unknown>, key: string): string | undefined {
|
||||
const value = data[key];
|
||||
return typeof value === 'string' && value.length > 0 ? value : undefined;
|
||||
}
|
||||
|
||||
function readAccessToken(tokenData: Record<string, unknown>): string | undefined {
|
||||
return readString(tokenData, 'accessToken') ?? readString(tokenData, 'access_token');
|
||||
}
|
||||
|
||||
function readOAuthTokenData(data: ICredentialDataDecryptedObject): Record<string, unknown> | null {
|
||||
const tokenData = data.oauthTokenData;
|
||||
return isObjectLiteral(tokenData) ? tokenData : null;
|
||||
}
|
||||
|
||||
function getPreferredRemote(remotes: McpRegistryRemote[]): {
|
||||
transport: Transport;
|
||||
endpointUrl: string;
|
||||
} | null {
|
||||
const streamable = remotes.find((remote) => remote.type === 'streamable-http');
|
||||
if (streamable?.url) {
|
||||
return { transport: 'streamableHttp', endpointUrl: streamable.url };
|
||||
}
|
||||
|
||||
const sse = remotes.find((remote) => remote.type === 'sse');
|
||||
if (sse?.url) {
|
||||
return { transport: 'sse', endpointUrl: sse.url };
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
const MCP_REGISTRY_SERVER_PREFIX = 'mcp_';
|
||||
const MAX_MCP_SERVER_NAME_LENGTH = 24;
|
||||
|
||||
function buildServerName(serverSlug: string, sequence: number): string {
|
||||
const safeSlug = serverSlug.replace(/[^A-Za-z0-9_-]/g, '_');
|
||||
const baseName = `${MCP_REGISTRY_SERVER_PREFIX}${safeSlug}`;
|
||||
if (sequence <= 1) {
|
||||
return baseName.slice(0, MAX_MCP_SERVER_NAME_LENGTH);
|
||||
}
|
||||
|
||||
const suffix = `_${sequence}`;
|
||||
const maxBaseLength = Math.max(0, MAX_MCP_SERVER_NAME_LENGTH - suffix.length);
|
||||
return `${baseName.slice(0, maxBaseLength)}${suffix}`;
|
||||
}
|
||||
|
||||
@Service()
|
||||
export class InstanceAiMcpRegistryService {
|
||||
private readonly logger: Logger;
|
||||
|
||||
constructor(
|
||||
logger: Logger,
|
||||
private readonly connectionRepository: InstanceAiMcpRegistryConnectionRepository,
|
||||
private readonly mcpRegistryService: McpRegistryService,
|
||||
private readonly credentialsFinderService: CredentialsFinderService,
|
||||
private readonly credentialsService: CredentialsService,
|
||||
private readonly oauthService: OauthService,
|
||||
) {
|
||||
this.logger = logger.scoped('instance-ai');
|
||||
}
|
||||
|
||||
async getRegistryMcpServers(user: User): Promise<McpServerConfig[]> {
|
||||
const connections = await this.connectionRepository.findBy({ userId: user.id });
|
||||
if (connections.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sortedConnections = connections.sort((left, right) => left.id.localeCompare(right.id));
|
||||
const slugs = [...new Set(sortedConnections.map((connection) => connection.serverSlug))];
|
||||
const servers = await this.mcpRegistryService.getBySlugs(slugs);
|
||||
const serverBySlug = new Map(servers.map((server) => [server.slug, server]));
|
||||
const slugCounts = new Map<string, number>();
|
||||
|
||||
const resolved: McpServerConfig[] = [];
|
||||
for (const connection of sortedConnections) {
|
||||
const server = serverBySlug.get(connection.serverSlug);
|
||||
if (!server) {
|
||||
this.logger.warn('Skipping MCP registry connection with missing server slug', {
|
||||
connectionId: connection.id,
|
||||
serverSlug: connection.serverSlug,
|
||||
userId: user.id,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
const resolvedServer = this.resolveRegistryServer(
|
||||
connection.id,
|
||||
connection.serverSlug,
|
||||
connection.credentialId,
|
||||
server.authType,
|
||||
server.remotes,
|
||||
);
|
||||
if (!resolvedServer) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const nextCount = (slugCounts.get(resolvedServer.serverSlug) ?? 0) + 1;
|
||||
slugCounts.set(resolvedServer.serverSlug, nextCount);
|
||||
const serverConfig: McpServerConfig = {
|
||||
name: buildServerName(resolvedServer.serverSlug, nextCount),
|
||||
url: resolvedServer.endpointUrl,
|
||||
transport: resolvedServer.transport,
|
||||
cacheKey: `registry-connection:${connection.id}`,
|
||||
};
|
||||
|
||||
if (resolvedServer.authType === 'oauth2') {
|
||||
const oauth2FetchContext = await this.buildOAuth2FetchContext(
|
||||
resolvedServer,
|
||||
user,
|
||||
connection.id,
|
||||
);
|
||||
if (!oauth2FetchContext) {
|
||||
continue;
|
||||
}
|
||||
|
||||
serverConfig.fetch = createAuthFetch({
|
||||
initialHeaders: { Authorization: `Bearer ${oauth2FetchContext.accessToken}` },
|
||||
onUnauthorized: async () => {
|
||||
if (!oauth2FetchContext.projectId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await this.oauthService.refreshOAuth2CredentialById(
|
||||
oauth2FetchContext.credentialId,
|
||||
oauth2FetchContext.projectId,
|
||||
);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
resolved.push(serverConfig);
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
private resolveRegistryServer(
|
||||
connectionId: string,
|
||||
serverSlug: string,
|
||||
credentialId: string,
|
||||
authType: string,
|
||||
remotes: McpRegistryRemote[],
|
||||
): ResolvedRegistryServer | null {
|
||||
const remote = getPreferredRemote(remotes);
|
||||
if (!remote) {
|
||||
this.logger.warn('Skipping MCP registry connection without supported remote transport', {
|
||||
connectionId,
|
||||
serverSlug,
|
||||
credentialId,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
serverSlug,
|
||||
credentialId,
|
||||
authType,
|
||||
endpointUrl: remote.endpointUrl,
|
||||
transport: remote.transport,
|
||||
};
|
||||
}
|
||||
|
||||
private async buildOAuth2FetchContext(
|
||||
config: ResolvedRegistryServer,
|
||||
user: User,
|
||||
connectionId: string,
|
||||
): Promise<OAuth2FetchContext | null> {
|
||||
const credentialWithData = await this.getCredentialWithData(config.credentialId, user);
|
||||
if (!credentialWithData) {
|
||||
this.logger.warn('Skipping MCP registry connection with inaccessible credential', {
|
||||
connectionId,
|
||||
serverSlug: config.serverSlug,
|
||||
credentialId: config.credentialId,
|
||||
userId: user.id,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
const tokenData = readOAuthTokenData(credentialWithData.data);
|
||||
if (!tokenData) {
|
||||
this.logger.warn('Skipping MCP registry connection without OAuth2 token data', {
|
||||
connectionId,
|
||||
serverSlug: config.serverSlug,
|
||||
credentialId: config.credentialId,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
const accessToken = readAccessToken(tokenData);
|
||||
if (!accessToken) {
|
||||
this.logger.warn('Skipping MCP registry connection without access token', {
|
||||
connectionId,
|
||||
serverSlug: config.serverSlug,
|
||||
credentialId: config.credentialId,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
const projectId = credentialWithData.credential.shared?.[0]?.projectId ?? null;
|
||||
if (!projectId) {
|
||||
this.logger.warn('Skipping OAuth2 token refresh for credential without project sharing', {
|
||||
connectionId,
|
||||
serverSlug: config.serverSlug,
|
||||
credentialId: config.credentialId,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
credentialId: config.credentialId,
|
||||
accessToken,
|
||||
projectId,
|
||||
};
|
||||
}
|
||||
|
||||
private async getCredentialWithData(
|
||||
credentialId: string,
|
||||
user: User,
|
||||
): Promise<{ credential: CredentialsEntity; data: ICredentialDataDecryptedObject } | null> {
|
||||
const credential = await this.credentialsFinderService.findCredentialForUser(
|
||||
credentialId,
|
||||
user,
|
||||
['credential:read'],
|
||||
);
|
||||
if (!credential) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const data = await this.credentialsService.decrypt(credential, true);
|
||||
if (!isObjectLiteral(data) || Object.keys(data).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return { credential, data };
|
||||
}
|
||||
}
|
||||
|
|
@ -7,3 +7,4 @@ export { InstanceAiCheckpointRepository } from './instance-ai-checkpoint.reposit
|
|||
export { InstanceAiObservationRepository } from './instance-ai-observation.repository';
|
||||
export { InstanceAiObservationCursorRepository } from './instance-ai-observation-cursor.repository';
|
||||
export { InstanceAiObservationLockRepository } from './instance-ai-observation-lock.repository';
|
||||
export { InstanceAiMcpRegistryConnectionRepository } from './instance-ai-mcp-registry-connection.repository';
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
import { Service } from '@n8n/di';
|
||||
import { DataSource, Repository } from '@n8n/typeorm';
|
||||
|
||||
import { InstanceAiMcpRegistryConnection } from '../entities/instance-ai-mcp-registry-connection.entity';
|
||||
|
||||
@Service()
|
||||
export class InstanceAiMcpRegistryConnectionRepository extends Repository<InstanceAiMcpRegistryConnection> {
|
||||
constructor(dataSource: DataSource) {
|
||||
super(InstanceAiMcpRegistryConnection, dataSource.manager);
|
||||
}
|
||||
}
|
||||
|
|
@ -131,6 +131,24 @@ describe('McpRegistryService', () => {
|
|||
expect(notion).toEqual(notionMockServer);
|
||||
expect(missing).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns empty array for getBySlugs when input is empty', async () => {
|
||||
const { service, repository } = createService();
|
||||
|
||||
const servers = await service.getBySlugs([]);
|
||||
|
||||
expect(servers).toEqual([]);
|
||||
expect(repository.findBy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns mapped servers for getBySlugs', async () => {
|
||||
const { service, repository } = createService();
|
||||
|
||||
const servers = await service.getBySlugs(['notion', 'linear']);
|
||||
|
||||
expect(repository.findBy).toHaveBeenCalledWith([{ slug: 'notion' }, { slug: 'linear' }]);
|
||||
expect(servers).toEqual([notionMockServer, linearMockServer]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('refresh flow', () => {
|
||||
|
|
|
|||
|
|
@ -90,6 +90,15 @@ export class McpRegistryService {
|
|||
return entity ? fromEntity(entity) : undefined;
|
||||
}
|
||||
|
||||
async getBySlugs(slugs: string[]): Promise<McpRegistryServer[]> {
|
||||
if (slugs.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const entities = await this.repository.findBy(slugs.map((slug) => ({ slug })));
|
||||
return entities.map(fromEntity);
|
||||
}
|
||||
|
||||
private startPeriodicRefresh(): void {
|
||||
if (this.isShuttingDown || this.refreshInterval) {
|
||||
return;
|
||||
|
|
|
|||
119
packages/cli/src/utils/__tests__/auth-fetch.test.ts
Normal file
119
packages/cli/src/utils/__tests__/auth-fetch.test.ts
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import { createAuthFetch } from '@/utils/auth-fetch';
|
||||
|
||||
const proxyFetchMock = jest.fn();
|
||||
jest.mock('@n8n/ai-utilities', () => ({
|
||||
proxyFetch: (...args: unknown[]) => proxyFetchMock(...args),
|
||||
}));
|
||||
|
||||
function makeOk(): Response {
|
||||
return new Response('ok', { status: 200 });
|
||||
}
|
||||
|
||||
function make401(): Response {
|
||||
return new Response('unauthorized', { status: 401 });
|
||||
}
|
||||
|
||||
describe('createAuthFetch', () => {
|
||||
beforeEach(() => {
|
||||
proxyFetchMock.mockReset();
|
||||
});
|
||||
|
||||
it('routes through proxyFetch and injects the initial headers', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(makeOk());
|
||||
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
|
||||
expect(init.headers).toMatchObject({ Authorization: 'Bearer A' });
|
||||
});
|
||||
|
||||
it('returns 401 unchanged when no onUnauthorized handler is configured', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401());
|
||||
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('returns the original 401 when onUnauthorized returns null', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401());
|
||||
|
||||
const onUnauthorized = jest.fn().mockResolvedValue(null);
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer A' },
|
||||
onUnauthorized,
|
||||
});
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(401);
|
||||
expect(onUnauthorized).toHaveBeenCalledTimes(1);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('retries once with refreshed headers when onUnauthorized returns new headers', async () => {
|
||||
proxyFetchMock.mockResolvedValueOnce(make401()).mockResolvedValueOnce(makeOk());
|
||||
|
||||
const onUnauthorized = jest.fn().mockResolvedValue({ Authorization: 'Bearer B' });
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer A' },
|
||||
onUnauthorized,
|
||||
});
|
||||
const res = await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
expect(proxyFetchMock).toHaveBeenCalledTimes(2);
|
||||
const [, init2] = proxyFetchMock.mock.calls[1] as [unknown, RequestInit];
|
||||
expect(init2.headers).toMatchObject({ Authorization: 'Bearer B' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('createAuthFetch — header merging', () => {
|
||||
beforeEach(() => {
|
||||
proxyFetchMock.mockReset();
|
||||
proxyFetchMock.mockResolvedValue(new Response('ok', { status: 200 }));
|
||||
});
|
||||
|
||||
it('merges caller-supplied init.headers with auth headers (auth takes precedence)', async () => {
|
||||
const fetchFn = createAuthFetch({ initialHeaders: { Authorization: 'Bearer A' } });
|
||||
await fetchFn('https://example.test/mcp', { headers: { 'X-Custom': 'value' } });
|
||||
|
||||
const [, init] = proxyFetchMock.mock.calls[0] as [unknown, RequestInit];
|
||||
expect(init.headers).toMatchObject({
|
||||
'X-Custom': 'value',
|
||||
Authorization: 'Bearer A',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses the refreshed headers on the second call after a successful 401 refresh', async () => {
|
||||
proxyFetchMock
|
||||
.mockResolvedValueOnce(new Response('unauthorized', { status: 401 }))
|
||||
.mockResolvedValueOnce(new Response('ok', { status: 200 }))
|
||||
.mockResolvedValueOnce(new Response('ok', { status: 200 }));
|
||||
|
||||
let callCount = 0;
|
||||
const onUnauthorized = jest.fn().mockImplementation(async () => {
|
||||
callCount++;
|
||||
return { Authorization: `Bearer refreshed-${callCount}` };
|
||||
});
|
||||
|
||||
const fetchFn = createAuthFetch({
|
||||
initialHeaders: { Authorization: 'Bearer stale' },
|
||||
onUnauthorized,
|
||||
});
|
||||
|
||||
// First call triggers a 401 → refresh → retry
|
||||
await fetchFn('https://example.test/mcp');
|
||||
|
||||
// Second call should use the refreshed headers without triggering another refresh
|
||||
await fetchFn('https://example.test/mcp');
|
||||
|
||||
expect(onUnauthorized).toHaveBeenCalledTimes(1);
|
||||
const [, thirdInit] = proxyFetchMock.mock.calls[2] as [unknown, RequestInit];
|
||||
expect((thirdInit.headers as Record<string, string>).Authorization).toBe('Bearer refreshed-1');
|
||||
});
|
||||
});
|
||||
54
packages/cli/src/utils/auth-fetch.ts
Normal file
54
packages/cli/src/utils/auth-fetch.ts
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import { proxyFetch } from '@n8n/ai-utilities';
|
||||
|
||||
interface CreateAuthFetchOptions {
|
||||
initialHeaders: Record<string, string>;
|
||||
/**
|
||||
* Called on a 401 response. Should return a fresh set of auth headers, or
|
||||
* `null` if the refresh failed. The returned headers replace the cached
|
||||
* set used by subsequent requests.
|
||||
*/
|
||||
onUnauthorized?: () => Promise<Record<string, string> | null>;
|
||||
}
|
||||
|
||||
function headersToRecord(headers: HeadersInit | undefined): Record<string, string> {
|
||||
if (!headers) return {};
|
||||
if (headers instanceof Headers) return Object.fromEntries(headers.entries());
|
||||
if (Array.isArray(headers)) return Object.fromEntries(headers);
|
||||
return headers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a fetch wrapper that:
|
||||
* 1. routes through n8n's `proxyFetch` (so corporate HTTP_PROXY settings
|
||||
* apply uniformly),
|
||||
* 2. injects the latest auth headers on every request,
|
||||
* 3. on a single 401, calls `onUnauthorized` to refresh the token and
|
||||
* retries the request once with the new headers.
|
||||
*
|
||||
* This mirrors the langchain MCP node's `createAuthFetch` so an agent's MCP
|
||||
* connection behaves identically to one configured via the workflow editor.
|
||||
*/
|
||||
export function createAuthFetch({
|
||||
initialHeaders,
|
||||
onUnauthorized,
|
||||
}: CreateAuthFetchOptions): typeof fetch {
|
||||
let headers = initialHeaders;
|
||||
|
||||
return async (input: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
|
||||
const response = await proxyFetch(input, {
|
||||
...init,
|
||||
headers: { ...headersToRecord(init?.headers), ...headers },
|
||||
});
|
||||
|
||||
if (response.status !== 401 || !onUnauthorized) return response;
|
||||
|
||||
const refreshed = await onUnauthorized();
|
||||
if (!refreshed) return response;
|
||||
|
||||
headers = refreshed;
|
||||
return await proxyFetch(input, {
|
||||
...init,
|
||||
headers: { ...headersToRecord(init?.headers), ...headers },
|
||||
});
|
||||
};
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user