mirror of
https://github.com/n8n-io/n8n.git
synced 2026-05-29 15:57:00 +02:00
feat(AI Agent Node): Support reading only maxTokensFromMemory (no-changelog) (#20915)
This commit is contained in:
parent
be27e94cb5
commit
b9f66aee4d
|
|
@ -12,11 +12,25 @@ const enableStreaminOption: INodeProperties = {
|
|||
description: 'Whether this agent will stream the response in real-time as it generates text',
|
||||
};
|
||||
|
||||
/**
 * Option controlling how many tokens of chat memory are read back for the agent.
 *
 * Declared with type `hidden`; the default of 0 means the whole history is read
 * (see the description string below).
 */
const maxTokensFromMemoryOption: INodeProperties = {
	displayName: 'Max Tokens To Read From Memory',
	name: 'maxTokensFromMemory',
	type: 'hidden',
	default: 0,
	description:
		'The maximum number of tokens to read from the chat memory history. Set to 0 to read all history.',
};
|
||||
|
||||
/**
 * The "Options" collection exposed by the tools agent.
 *
 * Combines the shared agent options with the streaming toggle, the batching
 * fields and the memory token limit. The `options` key is declared exactly once:
 * a second declaration of the same key in one object literal is rejected by the
 * TypeScript compiler and would shadow the earlier list at runtime anyway.
 */
export const toolsAgentProperties: INodeProperties = {
	displayName: 'Options',
	name: 'options',
	type: 'collection',
	default: {},
	placeholder: 'Add Option',
	options: [
		...commonOptions,
		enableStreaminOption,
		getBatchingOptionFields(undefined, 1),
		maxTokensFromMemoryOption,
	],
};
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
import type { StreamEvent } from '@langchain/core/dist/tracers/event_stream';
|
||||
import type { IterableReadableStream } from '@langchain/core/dist/utils/stream';
|
||||
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import type { AIMessageChunk, MessageContentText } from '@langchain/core/messages';
|
||||
import { AIMessage } from '@langchain/core/messages';
|
||||
import type { AIMessageChunk, BaseMessage, MessageContentText } from '@langchain/core/messages';
|
||||
import { AIMessage, trimMessages } from '@langchain/core/messages';
|
||||
import type { ToolCall } from '@langchain/core/messages/tool';
|
||||
import type { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { RunnableSequence } from '@langchain/core/runnables';
|
||||
import { getPromptInputByType } from '@utils/helpers';
|
||||
import {
|
||||
getOptionalOutputParser,
|
||||
type N8nOutputParser,
|
||||
} from '@utils/output_parsers/N8nOutputParser';
|
||||
import { type AgentRunnableSequence, createToolCallingAgent } from 'langchain/agents';
|
||||
import type { BaseChatMemory } from 'langchain/memory';
|
||||
import type { DynamicStructuredTool, Tool } from 'langchain/tools';
|
||||
|
|
@ -32,6 +28,12 @@ import type {
|
|||
} from 'n8n-workflow';
|
||||
import assert from 'node:assert';
|
||||
|
||||
import { getPromptInputByType } from '@utils/helpers';
|
||||
import {
|
||||
getOptionalOutputParser,
|
||||
type N8nOutputParser,
|
||||
} from '@utils/output_parsers/N8nOutputParser';
|
||||
|
||||
import {
|
||||
fixEmptyContentMessage,
|
||||
getAgentStepsParser,
|
||||
|
|
@ -42,7 +44,6 @@ import {
|
|||
preparePrompt,
|
||||
} from '../common';
|
||||
import { SYSTEM_MESSAGE } from '../prompt';
|
||||
import type { ToolCall } from '@langchain/core/messages/tool';
|
||||
|
||||
type ToolCallRequest = {
|
||||
tool: string;
|
||||
|
|
@ -405,6 +406,7 @@ export async function toolsAgentExecute(
|
|||
returnIntermediateSteps?: boolean;
|
||||
passthroughBinaryImages?: boolean;
|
||||
enableStreaming?: boolean;
|
||||
maxTokensFromMemory?: number;
|
||||
};
|
||||
|
||||
if (options.enableStreaming === undefined) {
|
||||
|
|
@ -448,11 +450,10 @@ export async function toolsAgentExecute(
|
|||
isStreamingAvailable &&
|
||||
this.getNode().typeVersion >= 2.1
|
||||
) {
|
||||
let chatHistory = undefined;
|
||||
let chatHistory: BaseMessage[] | undefined = undefined;
|
||||
if (memory) {
|
||||
// Load memory variables to respect context window length
|
||||
const memoryVariables = await memory.loadMemoryVariables({});
|
||||
chatHistory = memoryVariables['chat_history'];
|
||||
chatHistory = await loadChatHistory(memory, model, options.maxTokensFromMemory);
|
||||
}
|
||||
const eventStream = executor.streamEvents(
|
||||
{
|
||||
|
|
@ -487,11 +488,10 @@ export async function toolsAgentExecute(
|
|||
return result;
|
||||
} else {
|
||||
// Handle regular execution
|
||||
let chatHistory = undefined;
|
||||
let chatHistory: BaseMessage[] | undefined = undefined;
|
||||
if (memory) {
|
||||
// Load memory variables to respect context window length
|
||||
const memoryVariables = await memory.loadMemoryVariables({});
|
||||
chatHistory = memoryVariables['chat_history'];
|
||||
chatHistory = await loadChatHistory(memory, model, options.maxTokensFromMemory);
|
||||
}
|
||||
const modelResponse = await executor.invoke({
|
||||
...invokeParams,
|
||||
|
|
@ -603,3 +603,24 @@ export async function toolsAgentExecute(
|
|||
// Otherwise return execution data
|
||||
return [returnData];
|
||||
}
|
||||
/**
 * Loads the chat history stored under the `chat_history` memory variable and,
 * when a token budget is given, trims it so that only the most recent messages
 * fitting into that budget are kept.
 *
 * @param memory - Memory instance whose `loadMemoryVariables` provides `chat_history`.
 * @param model - Chat model handed to `trimMessages` as the token counter.
 * @param maxTokensFromMemory - Maximum number of tokens to keep. `undefined`, `0`,
 *   negative and non-finite values disable trimming, so the history is returned as stored.
 * @returns The (possibly trimmed) chat history. When no trimming takes place the
 *   stored value is passed through untouched.
 */
async function loadChatHistory(
	memory: BaseChatMemory,
	model: BaseChatModel,
	maxTokensFromMemory?: number,
): Promise<BaseMessage[]> {
	const memoryVariables = await memory.loadMemoryVariables({});
	const chatHistory = memoryVariables['chat_history'] as BaseMessage[];

	// Only a positive, finite budget is a real limit; 0 is documented as "read all history".
	const hasTokenBudget =
		typeof maxTokensFromMemory === 'number' &&
		Number.isFinite(maxTokensFromMemory) &&
		maxTokensFromMemory > 0;

	// `trimMessages` expects a message array. A missing, non-array or empty history
	// is passed through unchanged instead of being handed to the trimmer.
	if (!hasTokenBudget || !Array.isArray(chatHistory) || chatHistory.length === 0) {
		return chatHistory;
	}

	return await trimMessages(chatHistory, {
		// Keep the newest messages and drop the oldest ones first
		strategy: 'last',
		maxTokens: maxTokensFromMemory,
		tokenCounter: model,
		includeSystem: true,
		// The trimmed history must begin with a human message
		startOn: 'human',
		allowPartial: true,
	});
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import type { AIMessageChunk } from '@langchain/core/messages';
|
||||
import {
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
type BaseMessage,
|
||||
type AIMessageChunk,
|
||||
} from '@langchain/core/messages';
|
||||
import type { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { RunnableSequence } from '@langchain/core/runnables';
|
||||
import { mock } from 'jest-mock-extended';
|
||||
|
|
@ -631,6 +636,82 @@ describe('toolsAgentExecute V3', () => {
|
|||
);
|
||||
});
|
||||
|
||||
// Verifies that only the newest messages fitting into `maxTokensFromMemory` are passed to the executor.
it('should trim chat history to fit within `maxTokensFromMemory` limits', async () => {
	const mockNode = mock<INode>();
	mockNode.typeVersion = 3;
	mockContext.getNode.mockReturnValue(mockNode);
	mockContext.getInputData.mockReturnValue([{ json: { text: 'test input' } }]);

	// Stand-in token counter: for simplicity every character counts as one "token".
	// Normally a BaseLanguageModel is passed and its `BaseLanguageModel.getNumTokens()` is used,
	// but that could not be mocked here, so a plain counting function is cast to the model type.
	const mockModel = (async (messages: BaseMessage[]) =>
		messages.reduce(
			(total: number, message: BaseMessage) => total + message.content.length,
			0,
		)) as unknown as BaseChatModel;

	const mockHistory: BaseMessage[] = [
		new HumanMessage({ content: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.' }), // 56 "tokens"
		new AIMessage({ content: 'Vivamus volutpat felis a sapien viverra pretium.' }), // 48 "tokens"
		new HumanMessage({
			content: 'Nam id felis condimentum, venenatis ligula non, pulvinar nunc.',
		}), // 62 "tokens"
		new AIMessage({ content: 'Praesent eget ante magna.' }), // 25 "tokens"
		new HumanMessage({
			content: 'Curabitur euismod sem at dui efficitur, at convallis erat facilisis.', // 68 "tokens"
		}),
		new AIMessage({ content: 'Sed nec eros euismod, tincidunt nunc at, fermentum massa.' }), // 57 "tokens"
	];
	const mockMemory = mock<BaseChatMemory>();
	(mockMemory.saveContext as jest.Mock) = jest.fn();
	(mockMemory.loadMemoryVariables as jest.Mock) = jest.fn().mockResolvedValue({
		chat_history: mockHistory,
	});
	mockMemory.chatHistory = { getMessages: jest.fn().mockResolvedValue(mockHistory) } as any;

	const mockAgent = mock<any>();
	const mockRunnableSequence = mock<any>();
	mockRunnableSequence.singleAction = true;
	mockRunnableSequence.streamRunnable = false;
	mockRunnableSequence.invoke = jest
		.fn()
		.mockResolvedValue({ returnValues: { output: 'success' } });

	(createToolCallingAgent as jest.Mock).mockReturnValue(mockAgent);
	(RunnableSequence.from as jest.Mock).mockReturnValue(mockRunnableSequence);

	jest.spyOn(commonHelpers, 'getChatModel').mockResolvedValue(mockModel);
	jest.spyOn(commonHelpers, 'getOptionalMemory').mockResolvedValue(mockMemory);
	jest.spyOn(commonHelpers, 'getTools').mockResolvedValue([mock<Tool>()]);
	jest.spyOn(commonHelpers, 'prepareMessages').mockResolvedValue([]);
	jest.spyOn(commonHelpers, 'preparePrompt').mockReturnValue(mock<ChatPromptTemplate>());
	jest.spyOn(helpers, 'getPromptInputByType').mockReturnValue('test input');

	mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
		if (param === 'needsFallback') return false;
		if (param === 'options.enableStreaming') return false;
		if (param === 'options')
			return {
				systemMessage: 'You are a helpful assistant',
				maxIterations: 10,
				returnIntermediateSteps: false,
				passthroughBinaryImages: true,
				maxTokensFromMemory: 250, // Last four messages fit (25+68+57+62=212), first two (56+48=104) get removed
			};
		return defaultValue;
	});

	mockContext.getExecutionCancelSignal.mockReturnValue(new AbortController().signal);

	await toolsAgentExecute.call(mockContext);

	// Only the four newest messages may reach the executor
	expect(mockRunnableSequence.invoke).toHaveBeenCalledWith(
		expect.objectContaining({
			chat_history: [mockHistory[2], mockHistory[3], mockHistory[4], mockHistory[5]],
		}),
	);
});
|
||||
|
||||
it('should handle errors in batch processing when continueOnFail is true', async () => {
|
||||
const mockNode = mock<INode>();
|
||||
mockNode.typeVersion = 3;
|
||||
|
|
|
|||
|
|
@ -65,6 +65,9 @@ function getParameterOptionLabel(
|
|||
}
|
||||
|
||||
function displayNodeParameter(parameter: INodeProperties) {
|
||||
if (parameter.type === 'hidden') {
|
||||
return false;
|
||||
}
|
||||
if (parameter.displayOptions === undefined) {
|
||||
// If it is not defined no need to do a proper check
|
||||
return true;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user