mirror of
https://github.com/n8n-io/n8n.git
synced 2026-05-31 00:37:10 +02:00
refactor(instance-ai): support native stream resume
This commit is contained in:
parent
7f37612d2f
commit
69c43fdfaa
|
|
@ -218,6 +218,17 @@ async function* fromChunks(chunks: unknown[]) {
|
|||
}
|
||||
}
|
||||
|
||||
function readableFromChunks(chunks: unknown[]) {
|
||||
return new ReadableStream<unknown>({
|
||||
start(controller) {
|
||||
for (const chunk of chunks) {
|
||||
controller.enqueue(chunk);
|
||||
}
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function createDeferred<T>() {
|
||||
let resolve!: (value: T | PromiseLike<T>) => void;
|
||||
let reject!: (reason?: unknown) => void;
|
||||
|
|
@ -461,6 +472,63 @@ describe('executeResumableStream', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('auto-resumes native agent streams', async () => {
|
||||
const eventBus = createEventBus();
|
||||
const resume = jest.fn().mockResolvedValue({
|
||||
runId: 'agent-run-2',
|
||||
stream: readableFromChunks([{ type: 'text-delta', delta: 'Done.' }]),
|
||||
getState: jest.fn(),
|
||||
});
|
||||
const waitForConfirmation = jest.fn().mockResolvedValue({ approved: true });
|
||||
|
||||
const result = await executeResumableStream({
|
||||
agent: { resume },
|
||||
stream: {
|
||||
runId: 'agent-run-1',
|
||||
streamFormat: 'agent',
|
||||
fullStream: fromChunks([
|
||||
{
|
||||
type: 'tool-call-suspended',
|
||||
toolCallId: 'tool-call-1',
|
||||
toolName: 'pause-for-user',
|
||||
suspendPayload: {
|
||||
requestId: 'request-1',
|
||||
message: 'Please confirm',
|
||||
},
|
||||
},
|
||||
]),
|
||||
},
|
||||
context: {
|
||||
threadId: 'thread-1',
|
||||
runId: 'run-1',
|
||||
agentId: 'agent-1',
|
||||
eventBus,
|
||||
signal: new AbortController().signal,
|
||||
logger: { info: jest.fn(), warn: jest.fn(), error: jest.fn(), debug: jest.fn() },
|
||||
},
|
||||
control: {
|
||||
mode: 'auto',
|
||||
waitForConfirmation,
|
||||
},
|
||||
});
|
||||
|
||||
expect(waitForConfirmation).toHaveBeenCalledWith('request-1');
|
||||
expect(resume).toHaveBeenCalledWith(
|
||||
'stream',
|
||||
{ approved: true },
|
||||
{ runId: 'agent-run-1', toolCallId: 'tool-call-1' },
|
||||
);
|
||||
expect(result.status).toBe('completed');
|
||||
expect(result.agentRunId).toBe('agent-run-2');
|
||||
expect(eventBus.publish).toHaveBeenCalledWith(
|
||||
'thread-1',
|
||||
expect.objectContaining({
|
||||
type: 'text-delta',
|
||||
payload: { text: 'Done.' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('registers auto confirmations before the stream finishes draining', async () => {
|
||||
const eventBus = createEventBus();
|
||||
const finishGate = createDeferred<undefined>();
|
||||
|
|
|
|||
|
|
@ -2,10 +2,19 @@ import type { WorkSummary } from '../../stream/work-summary-accumulator';
|
|||
import { executeResumableStream } from '../resumable-stream-executor';
|
||||
import { streamAgentRun } from '../stream-runner';
|
||||
|
||||
jest.mock('../resumable-stream-executor', () => ({
|
||||
executeResumableStream: jest.fn(),
|
||||
createLlmStepTraceHooks: jest.fn(),
|
||||
}));
|
||||
jest.mock('../resumable-stream-executor', () => {
|
||||
const actual =
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
jest.requireActual<typeof import('../resumable-stream-executor')>(
|
||||
'../resumable-stream-executor',
|
||||
);
|
||||
|
||||
return {
|
||||
...actual,
|
||||
executeResumableStream: jest.fn(),
|
||||
createLlmStepTraceHooks: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
const emptyWorkSummary: WorkSummary = { toolCalls: [], totalToolCalls: 0, totalToolErrors: 0 };
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import type { InstanceAiEvent } from '@n8n/api-types';
|
||||
import type { StreamResult } from '@n8n/agents';
|
||||
import type { RunTree } from 'langsmith';
|
||||
|
||||
import type { InstanceAiEventBus } from '../event-bus';
|
||||
|
|
@ -6,7 +7,7 @@ import type { Logger } from '../logger';
|
|||
import { mapAgentChunkToEvent, mapMastraChunkToEvent } from '../stream/map-chunk';
|
||||
import { WorkSummaryAccumulator, type WorkSummary } from '../stream/work-summary-accumulator';
|
||||
import { getTraceParentRun, setTraceParentOverride } from '../tracing/langsmith-tracing';
|
||||
import { asResumable, parseSuspension } from '../utils/stream-helpers';
|
||||
import { parseSuspension, resumeStream } from '../utils/stream-helpers';
|
||||
import type { SuspensionInfo } from '../utils/stream-helpers';
|
||||
|
||||
type ConfirmationRequestEvent = Extract<InstanceAiEvent, { type: 'confirmation-request' }>;
|
||||
|
|
@ -167,6 +168,64 @@ function isRecord(value: unknown): value is Record<string, unknown> {
|
|||
return value !== null && typeof value === 'object' && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
typeof value === 'object' &&
|
||||
typeof Reflect.get(value, Symbol.asyncIterator) === 'function'
|
||||
);
|
||||
}
|
||||
|
||||
function isReadableStream(value: unknown): value is ReadableStream<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
typeof value === 'object' &&
|
||||
typeof Reflect.get(value, 'getReader') === 'function'
|
||||
);
|
||||
}
|
||||
|
||||
function isResumableStreamSource(value: unknown): value is ResumableStreamSource {
|
||||
return isRecord(value) && isAsyncIterable(value.fullStream);
|
||||
}
|
||||
|
||||
function isNativeStreamResult(value: unknown): value is StreamResult {
|
||||
return isRecord(value) && isReadableStream(value.stream);
|
||||
}
|
||||
|
||||
async function* readableStreamToAsyncIterable(stream: ReadableStream<unknown>) {
|
||||
const reader = stream.getReader();
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) return;
|
||||
yield value;
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
export function normalizeStreamSource(
|
||||
result: unknown,
|
||||
options?: { streamFormat?: ResumableStreamFormat },
|
||||
): ResumableStreamSource {
|
||||
if (isResumableStreamSource(result)) {
|
||||
return options?.streamFormat && !result.streamFormat
|
||||
? { ...result, streamFormat: options.streamFormat }
|
||||
: result;
|
||||
}
|
||||
|
||||
if (isNativeStreamResult(result)) {
|
||||
return {
|
||||
runId: result.runId,
|
||||
streamFormat: 'agent',
|
||||
fullStream: readableStreamToAsyncIterable(result.stream),
|
||||
};
|
||||
}
|
||||
|
||||
throw new Error('Unsupported agent stream result');
|
||||
}
|
||||
|
||||
function getFiniteNumber(value: unknown): number | undefined {
|
||||
return typeof value === 'number' && Number.isFinite(value) ? value : undefined;
|
||||
}
|
||||
|
|
@ -2024,15 +2083,19 @@ export async function executeResumableStream(
|
|||
runId: activeAgentRunId,
|
||||
toolCallId: suspension.toolCallId,
|
||||
};
|
||||
const resumed = await asResumable(options.agent).resumeStream(resumeData, {
|
||||
const resumed = await resumeStream(options.agent, resumeData, {
|
||||
...resumeOptions,
|
||||
...(options.llmStepTraceHooks?.executionOptions ?? {}),
|
||||
});
|
||||
const resumedSource = normalizeStreamSource(resumed, {
|
||||
streamFormat: activeSource.streamFormat,
|
||||
});
|
||||
|
||||
activeAgentRunId = (typeof resumed.runId === 'string' ? resumed.runId : '') || activeAgentRunId;
|
||||
activeSource = { ...resumed, streamFormat: activeSource.streamFormat };
|
||||
activeStream = resumed.fullStream;
|
||||
text = resumed.text;
|
||||
activeAgentRunId =
|
||||
(typeof resumedSource.runId === 'string' ? resumedSource.runId : '') || activeAgentRunId;
|
||||
activeSource = resumedSource;
|
||||
activeStream = resumedSource.fullStream;
|
||||
text = resumedSource.text;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import type { InstanceAiEvent } from '@n8n/api-types';
|
||||
import type { StreamResult } from '@n8n/agents';
|
||||
|
||||
import type { InstanceAiEventBus } from '../event-bus';
|
||||
import type { Logger } from '../logger';
|
||||
|
|
@ -7,20 +6,16 @@ import {
|
|||
createLlmStepTraceHooks,
|
||||
executeResumableStream,
|
||||
type LlmStepTraceHooks,
|
||||
normalizeStreamSource,
|
||||
type ResumableStreamSource,
|
||||
type TraceStatus,
|
||||
} from './resumable-stream-executor';
|
||||
import { getTraceParentRun, withTraceParentContext } from '../tracing/langsmith-tracing';
|
||||
import { asResumable, isRecord } from '../utils/stream-helpers';
|
||||
import { resumeStream } from '../utils/stream-helpers';
|
||||
import type { SuspensionInfo } from '../utils/stream-helpers';
|
||||
|
||||
type StreamableAgentStreamResult = ResumableStreamSource | StreamResult;
|
||||
|
||||
export interface StreamableAgent {
|
||||
stream: (
|
||||
input: unknown,
|
||||
options: Record<string, unknown>,
|
||||
) => Promise<StreamableAgentStreamResult>;
|
||||
stream: (input: unknown, options: Record<string, unknown>) => Promise<unknown>;
|
||||
}
|
||||
|
||||
export interface StreamRunOptions {
|
||||
|
|
@ -40,59 +35,6 @@ export interface StreamRunResult {
|
|||
confirmationEvent?: Extract<InstanceAiEvent, { type: 'confirmation-request' }>;
|
||||
}
|
||||
|
||||
function isAsyncIterable(value: unknown): value is AsyncIterable<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
typeof value === 'object' &&
|
||||
typeof Reflect.get(value, Symbol.asyncIterator) === 'function'
|
||||
);
|
||||
}
|
||||
|
||||
function isReadableStream(value: unknown): value is ReadableStream<unknown> {
|
||||
return (
|
||||
value !== null &&
|
||||
typeof value === 'object' &&
|
||||
typeof Reflect.get(value, 'getReader') === 'function'
|
||||
);
|
||||
}
|
||||
|
||||
function isResumableStreamSource(value: unknown): value is ResumableStreamSource {
|
||||
return isRecord(value) && isAsyncIterable(value.fullStream);
|
||||
}
|
||||
|
||||
function isNativeStreamResult(value: unknown): value is StreamResult {
|
||||
return isRecord(value) && isReadableStream(value.stream);
|
||||
}
|
||||
|
||||
async function* readableStreamToAsyncIterable(stream: ReadableStream<unknown>) {
|
||||
const reader = stream.getReader();
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) return;
|
||||
yield value;
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeStreamSource(result: StreamableAgentStreamResult): ResumableStreamSource {
|
||||
if (isResumableStreamSource(result)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
if (isNativeStreamResult(result)) {
|
||||
return {
|
||||
runId: result.runId,
|
||||
streamFormat: 'agent',
|
||||
fullStream: readableStreamToAsyncIterable(result.stream),
|
||||
};
|
||||
}
|
||||
|
||||
throw new Error('Unsupported agent stream result');
|
||||
}
|
||||
|
||||
export async function streamAgentRun(
|
||||
agent: StreamableAgent,
|
||||
input: unknown,
|
||||
|
|
@ -121,7 +63,7 @@ export async function resumeAgentRun(
|
|||
const resumeTraceParent = getTraceParentRun();
|
||||
return await withTraceParentContext(resumeTraceParent, async () => {
|
||||
const llmStepTraceHooks = createLlmStepTraceHooks(resumeTraceParent);
|
||||
const resumed = await asResumable(agent).resumeStream(resumeData, {
|
||||
const resumed = await resumeStream(agent, resumeData, {
|
||||
...resumeOptions,
|
||||
...(llmStepTraceHooks?.executionOptions ?? {}),
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { isRecord, parseSuspension, asResumable } from '../stream-helpers';
|
||||
import { isRecord, parseSuspension, asResumable, resumeStream } from '../stream-helpers';
|
||||
|
||||
describe('isRecord', () => {
|
||||
it('returns true for plain objects', () => {
|
||||
|
|
@ -127,3 +127,31 @@ describe('asResumable', () => {
|
|||
expect(resumable.resumeStream).toBe(agent.resumeStream);
|
||||
});
|
||||
});
|
||||
|
||||
describe('resumeStream', () => {
|
||||
it('uses Mastra-style resumeStream when available', async () => {
|
||||
const resumed = { runId: 'run-2' };
|
||||
const agent = { resumeStream: jest.fn().mockResolvedValue(resumed) };
|
||||
|
||||
await expect(resumeStream(agent, { approved: true }, { runId: 'run-1' })).resolves.toBe(
|
||||
resumed,
|
||||
);
|
||||
expect(agent.resumeStream).toHaveBeenCalledWith({ approved: true }, { runId: 'run-1' });
|
||||
});
|
||||
|
||||
it('uses native agent resume in stream mode when resumeStream is absent', async () => {
|
||||
const resumed = { runId: 'run-2' };
|
||||
const agent = { resume: jest.fn().mockResolvedValue(resumed) };
|
||||
|
||||
await expect(resumeStream(agent, { approved: true }, { runId: 'run-1' })).resolves.toBe(
|
||||
resumed,
|
||||
);
|
||||
expect(agent.resume).toHaveBeenCalledWith('stream', { approved: true }, { runId: 'run-1' });
|
||||
});
|
||||
|
||||
it('throws when the agent cannot resume streams', async () => {
|
||||
await expect(resumeStream({}, { approved: true }, { runId: 'run-1' })).rejects.toThrow(
|
||||
'Agent does not support stream resume',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -34,13 +34,38 @@ export function parseSuspension(chunk: unknown): SuspensionInfo | null {
|
|||
|
||||
/** Type for Mastra's resumeStream method (not exported by the framework). */
|
||||
export interface Resumable {
|
||||
resumeStream: (
|
||||
resumeStream?: (
|
||||
data: Record<string, unknown>,
|
||||
options: Record<string, unknown>,
|
||||
) => Promise<{ runId?: string; fullStream: AsyncIterable<unknown>; text: Promise<string> }>;
|
||||
) => Promise<unknown>;
|
||||
resume?: (
|
||||
method: 'stream',
|
||||
data: Record<string, unknown>,
|
||||
options: Record<string, unknown>,
|
||||
) => Promise<unknown>;
|
||||
}
|
||||
|
||||
/** Cast an agent to Resumable for suspend/resume operations. */
|
||||
export function asResumable(agent: unknown): Resumable {
|
||||
return agent as Resumable;
|
||||
}
|
||||
|
||||
export async function resumeStream(
|
||||
agent: unknown,
|
||||
data: Record<string, unknown>,
|
||||
options: Record<string, unknown>,
|
||||
): Promise<unknown> {
|
||||
if (!isRecord(agent)) {
|
||||
throw new Error('Agent does not support stream resume');
|
||||
}
|
||||
|
||||
if (typeof agent.resumeStream === 'function') {
|
||||
return await agent.resumeStream(data, options);
|
||||
}
|
||||
|
||||
if (typeof agent.resume === 'function') {
|
||||
return await agent.resume('stream', data, options);
|
||||
}
|
||||
|
||||
throw new Error('Agent does not support stream resume');
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user