feat(AI Agent Node): Support reading only maxTokensFromMemory (no-changelog) (#20915)

This commit is contained in:
Jaakko Husso 2025-10-20 09:17:58 +03:00 committed by GitHub
parent be27e94cb5
commit b9f66aee4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 135 additions and 16 deletions

View File

@ -12,11 +12,25 @@ const enableStreaminOption: INodeProperties = {
description: 'Whether this agent will stream the response in real-time as it generates text',
};
/**
 * Hidden option that caps how many tokens of chat memory history are read.
 * The default of 0 means no cap: the whole history is read.
 */
const maxTokensFromMemoryOption: INodeProperties = {
	displayName: 'Max Tokens To Read From Memory',
	name: 'maxTokensFromMemory',
	// Declared as 'hidden' so the parameter exists on the node without a visible input.
	type: 'hidden',
	default: 0,
	description:
		'The maximum number of tokens to read from the chat memory history. Set to 0 to read all history.',
};
/**
 * 'Options' collection of the tools agent: the shared common options plus the
 * streaming toggle, the batching fields and the hidden memory token limit.
 */
export const toolsAgentProperties: INodeProperties = {
displayName: 'Options',
name: 'options',
type: 'collection',
default: {},
placeholder: 'Add Option',
// NOTE(review): the single-line `options` entry below appears to be the pre-change line
// retained by the diff view; the multi-line entry after it is the replacement that also
// registers `maxTokensFromMemoryOption`. Confirm only one `options` key exists in the real file.
options: [...commonOptions, enableStreaminOption, getBatchingOptionFields(undefined, 1)],
options: [
...commonOptions,
enableStreaminOption,
getBatchingOptionFields(undefined, 1),
maxTokensFromMemoryOption,
],
};

View File

@ -1,15 +1,11 @@
import type { StreamEvent } from '@langchain/core/dist/tracers/event_stream';
import type { IterableReadableStream } from '@langchain/core/dist/utils/stream';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { AIMessageChunk, MessageContentText } from '@langchain/core/messages';
import { AIMessage } from '@langchain/core/messages';
import type { AIMessageChunk, BaseMessage, MessageContentText } from '@langchain/core/messages';
import { AIMessage, trimMessages } from '@langchain/core/messages';
import type { ToolCall } from '@langchain/core/messages/tool';
import type { ChatPromptTemplate } from '@langchain/core/prompts';
import { RunnableSequence } from '@langchain/core/runnables';
import { getPromptInputByType } from '@utils/helpers';
import {
getOptionalOutputParser,
type N8nOutputParser,
} from '@utils/output_parsers/N8nOutputParser';
import { type AgentRunnableSequence, createToolCallingAgent } from 'langchain/agents';
import type { BaseChatMemory } from 'langchain/memory';
import type { DynamicStructuredTool, Tool } from 'langchain/tools';
@ -32,6 +28,12 @@ import type {
} from 'n8n-workflow';
import assert from 'node:assert';
import { getPromptInputByType } from '@utils/helpers';
import {
getOptionalOutputParser,
type N8nOutputParser,
} from '@utils/output_parsers/N8nOutputParser';
import {
fixEmptyContentMessage,
getAgentStepsParser,
@ -42,7 +44,6 @@ import {
preparePrompt,
} from '../common';
import { SYSTEM_MESSAGE } from '../prompt';
import type { ToolCall } from '@langchain/core/messages/tool';
type ToolCallRequest = {
tool: string;
@ -405,6 +406,7 @@ export async function toolsAgentExecute(
returnIntermediateSteps?: boolean;
passthroughBinaryImages?: boolean;
enableStreaming?: boolean;
maxTokensFromMemory?: number;
};
if (options.enableStreaming === undefined) {
@ -448,11 +450,10 @@ export async function toolsAgentExecute(
isStreamingAvailable &&
this.getNode().typeVersion >= 2.1
) {
let chatHistory = undefined;
let chatHistory: BaseMessage[] | undefined = undefined;
if (memory) {
// Load memory variables to respect context window length
const memoryVariables = await memory.loadMemoryVariables({});
chatHistory = memoryVariables['chat_history'];
chatHistory = await loadChatHistory(memory, model, options.maxTokensFromMemory);
}
const eventStream = executor.streamEvents(
{
@ -487,11 +488,10 @@ export async function toolsAgentExecute(
return result;
} else {
// Handle regular execution
let chatHistory = undefined;
let chatHistory: BaseMessage[] | undefined = undefined;
if (memory) {
// Load memory variables to respect context window length
const memoryVariables = await memory.loadMemoryVariables({});
chatHistory = memoryVariables['chat_history'];
chatHistory = await loadChatHistory(memory, model, options.maxTokensFromMemory);
}
const modelResponse = await executor.invoke({
...invokeParams,
@ -603,3 +603,24 @@ export async function toolsAgentExecute(
// Otherwise return execution data
return [returnData];
}
/**
 * Loads the chat history from memory, optionally trimmed to a token budget.
 *
 * @param memory - Memory whose `loadMemoryVariables` result is read under the `chat_history` key.
 * @param model - Chat model handed to `trimMessages` as the token counter.
 * @param maxTokensFromMemory - Token budget for the returned history. When omitted or 0 the
 *   history is returned untrimmed.
 * @returns The stored history; when a budget is set and the history is a non-empty message
 *   list, only the most recent messages that fit the budget.
 */
async function loadChatHistory(
	memory: BaseChatMemory,
	model: BaseChatModel,
	maxTokensFromMemory?: number,
): Promise<BaseMessage[]> {
	const memoryVariables = await memory.loadMemoryVariables({});
	const chatHistory = memoryVariables['chat_history'] as BaseMessage[];

	// No budget configured (undefined or 0): hand back the history untouched.
	if (!maxTokensFromMemory) {
		return chatHistory;
	}

	// The cast above is unchecked: the memory may yield nothing or a non-list value.
	// Only a non-empty message list can be trimmed; anything else is passed through
	// unchanged rather than being fed to `trimMessages`, whose non-array first argument
	// is (per the LangChain API) interpreted as its options-only overload.
	if (!Array.isArray(chatHistory) || chatHistory.length === 0) {
		return chatHistory;
	}

	return await trimMessages(chatHistory, {
		// Keep the newest messages and drop from the oldest end.
		strategy: 'last',
		maxTokens: maxTokensFromMemory,
		tokenCounter: model,
		includeSystem: true,
		// The trimmed history must begin with a human message.
		startOn: 'human',
		allowPartial: true,
	});
}

View File

@ -1,5 +1,10 @@
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { AIMessageChunk } from '@langchain/core/messages';
import {
AIMessage,
HumanMessage,
type BaseMessage,
type AIMessageChunk,
} from '@langchain/core/messages';
import type { ChatPromptTemplate } from '@langchain/core/prompts';
import { RunnableSequence } from '@langchain/core/runnables';
import { mock } from 'jest-mock-extended';
@ -631,6 +636,82 @@ describe('toolsAgentExecute V3', () => {
);
});
it('should trim chat history to fit within `maxTokensFromMemory` limits', async () => {
	const node = mock<INode>();
	node.typeVersion = 3;
	mockContext.getNode.mockReturnValue(node);
	mockContext.getInputData.mockReturnValue([{ json: { text: 'test input' } }]);

	// Token counter stand-in: a plain function that treats every character as one token.
	// In a real run a BaseLanguageModel is passed and its `getNumTokens()` does the counting,
	// but that could not be mocked here.
	const characterCountingModel = (async (messages: BaseMessage[]) => {
		let characterCount = 0;
		for (const message of messages) {
			characterCount += message.content.length;
		}
		return characterCount;
	}) as unknown as BaseChatModel;

	// Six alternating human/AI messages; the trailing comments give each message's length.
	const storedHistory: BaseMessage[] = [
		new HumanMessage({ content: 'Lorem ipsum dolor sit amet, consectetur adipiscing elit.' }), // 56 "tokens"
		new AIMessage({ content: 'Vivamus volutpat felis a sapien viverra pretium.' }), // 48 "tokens"
		new HumanMessage({
			content: 'Nam id felis condimentum, venenatis ligula non, pulvinar nunc.', // 62 "tokens"
		}),
		new AIMessage({ content: 'Praesent eget ante magna.' }), // 25 "tokens"
		new HumanMessage({
			content: 'Curabitur euismod sem at dui efficitur, at convallis erat facilisis.', // 68 "tokens"
		}),
		new AIMessage({ content: 'Sed nec eros euismod, tincidunt nunc at, fermentum massa.' }), // 57 "tokens"
	];

	// Memory that serves the same history through both access paths.
	const memory = mock<BaseChatMemory>();
	(memory.saveContext as jest.Mock) = jest.fn();
	(memory.loadMemoryVariables as jest.Mock) = jest.fn().mockResolvedValue({
		chat_history: storedHistory,
	});
	memory.chatHistory = { getMessages: jest.fn().mockResolvedValue(storedHistory) } as any;

	// Executor stub that finishes immediately with a final answer.
	const agent = mock<any>();
	const executor = mock<any>();
	executor.singleAction = true;
	executor.streamRunnable = false;
	executor.invoke = jest.fn().mockResolvedValue({ returnValues: { output: 'success' } });

	(createToolCallingAgent as jest.Mock).mockReturnValue(agent);
	(RunnableSequence.from as jest.Mock).mockReturnValue(executor);

	jest.spyOn(commonHelpers, 'getChatModel').mockResolvedValue(characterCountingModel);
	jest.spyOn(commonHelpers, 'getOptionalMemory').mockResolvedValue(memory);
	jest.spyOn(commonHelpers, 'getTools').mockResolvedValue([mock<Tool>()]);
	jest.spyOn(commonHelpers, 'prepareMessages').mockResolvedValue([]);
	jest.spyOn(commonHelpers, 'preparePrompt').mockReturnValue(mock<ChatPromptTemplate>());
	jest.spyOn(helpers, 'getPromptInputByType').mockReturnValue('test input');

	mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
		switch (param) {
			case 'needsFallback':
				return false;
			case 'options.enableStreaming':
				return false;
			case 'options':
				return {
					systemMessage: 'You are a helpful assistant',
					maxIterations: 10,
					returnIntermediateSteps: false,
					passthroughBinaryImages: true,
					// Budget of 250: the last four messages fit (62+25+68+57=212),
					// the first two (56+48=104) are dropped.
					maxTokensFromMemory: 250,
				};
			default:
				return defaultValue;
		}
	});
	mockContext.getExecutionCancelSignal.mockReturnValue(new AbortController().signal);

	await toolsAgentExecute.call(mockContext);

	// Only the four most recent messages should reach the executor.
	expect(executor.invoke).toHaveBeenCalledWith(
		expect.objectContaining({
			chat_history: storedHistory.slice(2),
		}),
	);
});
it('should handle errors in batch processing when continueOnFail is true', async () => {
const mockNode = mock<INode>();
mockNode.typeVersion = 3;

View File

@ -65,6 +65,9 @@ function getParameterOptionLabel(
}
function displayNodeParameter(parameter: INodeProperties) {
if (parameter.type === 'hidden') {
return false;
}
if (parameter.displayOptions === undefined) {
// If it is not defined no need to do a proper check
return true;