From 96e5027055a61ddb590cdf97f5b650437e712df5 Mon Sep 17 00:00:00 2001 From: Jake Turner Date: Wed, 11 Mar 2026 05:52:46 +0000 Subject: [PATCH] feat(AI Assistant): performance improvements and smarter RAG context usage --- admin/app/controllers/ollama_controller.ts | 48 ++++- admin/app/services/rag_service.ts | 238 ++++++++++++++------- admin/constants/ollama.ts | 12 +- admin/docs/release-notes.md | 3 + admin/types/rag.ts | 29 +++ 5 files changed, 242 insertions(+), 88 deletions(-) diff --git a/admin/app/controllers/ollama_controller.ts b/admin/app/controllers/ollama_controller.ts index 1557062..d4af8aa 100644 --- a/admin/app/controllers/ollama_controller.ts +++ b/admin/app/controllers/ollama_controller.ts @@ -5,7 +5,7 @@ import { modelNameSchema } from '#validators/download' import { chatSchema, getAvailableModelsSchema } from '#validators/ollama' import { inject } from '@adonisjs/core' import type { HttpContext } from '@adonisjs/core/http' -import { DEFAULT_QUERY_REWRITE_MODEL, SYSTEM_PROMPTS } from '../../constants/ollama.js' +import { DEFAULT_QUERY_REWRITE_MODEL, RAG_CONTEXT_LIMITS, SYSTEM_PROMPTS } from '../../constants/ollama.js' import logger from '@adonisjs/core/services/logger' import type { Message } from 'ollama' @@ -66,9 +66,28 @@ export default class OllamaController { logger.debug(`[RAG] Retrieved ${relevantDocs.length} relevant documents for query: "${rewrittenQuery}"`) - // If relevant context is found, inject as a system message + // If relevant context is found, inject as a system message with adaptive limits if (relevantDocs.length > 0) { - const contextText = relevantDocs + // Determine context budget based on model size + const { maxResults, maxTokens } = this.getContextLimitsForModel(reqData.model) + let trimmedDocs = relevantDocs.slice(0, maxResults) + + // Apply token cap if set (estimate ~4 chars per token) + // Always include the first (most relevant) result — the cap only gates subsequent results + if (maxTokens > 0) { + const charCap = maxTokens * 4 + let totalChars = 0 + trimmedDocs = trimmedDocs.filter((doc, idx) => { + totalChars += doc.text.length + return idx === 0 || totalChars <= charCap + }) + } + + logger.debug( + `[RAG] Injecting ${trimmedDocs.length}/${relevantDocs.length} results (model: ${reqData.model}, maxResults: ${maxResults}, maxTokens: ${maxTokens || 'unlimited'})` + ) + + const contextText = trimmedDocs .map((doc, idx) => `[Context ${idx + 1}] (Relevance: ${(doc.score * 100).toFixed(1)}%)\n${doc.text}`) .join('\n\n') @@ -174,6 +193,25 @@ export default class OllamaController { return await this.ollamaService.getModels() } + /** + * Determines RAG context limits based on model size extracted from the model name. + * Parses size indicators like "1b", "3b", "8b", "70b" from model names/tags. + */ + private getContextLimitsForModel(modelName: string): { maxResults: number; maxTokens: number } { + // Extract parameter count from model name (e.g., "llama3.2:3b", "qwen2.5:1.5b", "gemma:7b") + const sizeMatch = modelName.match(/(\d+\.?\d*)[bB]/) + const paramBillions = sizeMatch ? parseFloat(sizeMatch[1]) : 8 // default to 8B if unknown + + for (const tier of RAG_CONTEXT_LIMITS) { + if (paramBillions <= tier.maxParams) { + return { maxResults: tier.maxResults, maxTokens: tier.maxTokens } + } + } + + // Fallback: no limits + return { maxResults: 5, maxTokens: 0 } + } + private async rewriteQueryWithContext( messages: Message[] ): Promise { @@ -199,8 +237,8 @@ export default class OllamaController { }) .join('\n') - const availableModels = await this.ollamaService.getAvailableModels({ query: null, limit: 500 }) - const rewriteModelAvailable = availableModels?.models.some(model => model.name === DEFAULT_QUERY_REWRITE_MODEL) + const installedModels = await this.ollamaService.getModels(true) + const rewriteModelAvailable = installedModels?.some(model => model.name === DEFAULT_QUERY_REWRITE_MODEL) if (!rewriteModelAvailable) { logger.warn(`[RAG] Query rewrite model "${DEFAULT_QUERY_REWRITE_MODEL}" not available. Skipping query rewriting.`) const lastUserMessage = [...messages].reverse().find(msg => msg.role === 'user') diff --git a/admin/app/services/rag_service.ts b/admin/app/services/rag_service.ts index 05981e3..e6ac043 100644 --- a/admin/app/services/rag_service.ts +++ b/admin/app/services/rag_service.ts @@ -16,11 +16,13 @@ import { join, resolve, sep } from 'node:path' import KVStore from '#models/kv_store' import { ZIMExtractionService } from './zim_extraction_service.js' import { ZIM_BATCH_SIZE } from '../../constants/zim_extraction.js' +import { ProcessAndEmbedFileResponse, ProcessZIMFileResponse, RAGResult, RerankedRAGResult } from '../../types/rag.js' @inject() export class RagService { private qdrant: QdrantClient | null = null private qdrantInitPromise: Promise | null = null + private embeddingModelVerified = false public static UPLOADS_STORAGE_PATH = 'storage/kb_uploads' public static CONTENT_COLLECTION_NAME = 'nomad_knowledge_base' public static EMBEDDING_MODEL = 'nomic-embed-text:v1.5' @@ -33,6 +35,7 @@ export class RagService { // Nomic Embed Text v1.5 uses task-specific prefixes for optimal performance public static SEARCH_DOCUMENT_PREFIX = 'search_document: ' public static SEARCH_QUERY_PREFIX = 'search_query: ' + public static EMBEDDING_BATCH_SIZE = 8 // Conservative batch size for low-end hardware constructor( private dockerService: DockerService, @@ -75,6 +78,16 @@ export class RagService { }, }) } + + // Create payload indexes for faster filtering (idempotent — Qdrant ignores duplicates) + await this.qdrant!.createPayloadIndex(collectionName, { + field_name: 'source', + field_schema: 'keyword', + }) + await this.qdrant!.createPayloadIndex(collectionName, { + field_name: 'content_type', + field_schema: 'keyword', + }) } catch (error) { logger.error('Error ensuring Qdrant collection:', error) throw error @@ -148,14 +161,57 @@ export class RagService { /** * Preprocesses a query to improve retrieval by expanding it with context. * This helps match documents even when using different terminology. + * TODO: We could probably move this to a separate QueryPreprocessor class if it grows more complex, but for now it's manageable here. */ + private static QUERY_EXPANSION_DICTIONARY: Record = { + 'bob': 'bug out bag', + 'bov': 'bug out vehicle', + 'bol': 'bug out location', + 'edc': 'every day carry', + 'mre': 'meal ready to eat', + 'shtf': 'shit hits the fan', + 'teotwawki': 'the end of the world as we know it', + 'opsec': 'operational security', + 'ifak': 'individual first aid kit', + 'ghb': 'get home bag', + 'ghi': 'get home in', + 'wrol': 'without rule of law', + 'emp': 'electromagnetic pulse', + 'ham': 'ham amateur radio', + 'nbr': 'nuclear biological radiological', + 'cbrn': 'chemical biological radiological nuclear', + 'sar': 'search and rescue', + 'comms': 'communications radio', + 'fifo': 'first in first out', + 'mylar': 'mylar bag food storage', + 'paracord': 'paracord 550 cord', + 'ferro': 'ferro rod fire starter', + 'bivvy': 'bivvy bivy emergency shelter', + 'bdu': 'battle dress uniform', + 'gmrs': 'general mobile radio service', + 'frs': 'family radio service', + 'nbc': 'nuclear biological chemical', + } + private preprocessQuery(query: string): string { - // Future: this is a placeholder for more advanced query expansion techniques. - // For now, we simply trim whitespace. Improvements could include: - // - Synonym expansion using a thesaurus - // - Adding related terms based on domain knowledge - // - Using a language model to rephrase or elaborate the query - const expanded = query.trim() + let expanded = query.trim() + + // Expand known domain abbreviations/acronyms + const words = expanded.toLowerCase().split(/\s+/) + const expansions: string[] = [] + + for (const word of words) { + const cleaned = word.replace(/[^\w]/g, '') + if (RagService.QUERY_EXPANSION_DICTIONARY[cleaned]) { + expansions.push(RagService.QUERY_EXPANSION_DICTIONARY[cleaned]) + } + } + + if (expansions.length > 0) { + expanded = `${expanded} ${expansions.join(' ')}` + logger.debug(`[RAG] Query expanded with domain terms: "${expanded}"`) + } + logger.debug(`[RAG] Original query: "${query}"`) logger.debug(`[RAG] Preprocessed query: "${expanded}"`) return expanded @@ -187,22 +243,26 @@ export class RagService { RagService.EMBEDDING_DIMENSION ) - const allModels = await this.ollamaService.getModels(true) - const embeddingModel = allModels.find((model) => model.name === RagService.EMBEDDING_MODEL) + if (!this.embeddingModelVerified) { + const allModels = await this.ollamaService.getModels(true) + const embeddingModel = allModels.find((model) => model.name === RagService.EMBEDDING_MODEL) - if (!embeddingModel) { - try { - const downloadResult = await this.ollamaService.downloadModel(RagService.EMBEDDING_MODEL) - if (!downloadResult.success) { - throw new Error(downloadResult.message || 'Unknown error during model download') + if (!embeddingModel) { + try { + const downloadResult = await this.ollamaService.downloadModel(RagService.EMBEDDING_MODEL) + if (!downloadResult.success) { + throw new Error(downloadResult.message || 'Unknown error during model download') + } + } catch (modelError) { + logger.error( + `[RAG] Embedding model ${RagService.EMBEDDING_MODEL} not found locally and failed to download:`, + modelError + ) + this.embeddingModelVerified = false + return null } - } catch (modelError) { - logger.error( - `[RAG] Embedding model ${RagService.EMBEDDING_MODEL} not found locally and failed to download:`, - modelError - ) - return null } + this.embeddingModelVerified = true } // TokenChunker uses character-based tokenization (1 char = 1 token) @@ -227,7 +287,8 @@ export class RagService { const ollamaClient = await this.ollamaService.getClient() - const embeddings: number[][] = [] + // Prepare all chunk texts with prefix and truncation + const prefixedChunks: string[] = [] for (let i = 0; i < chunks.length; i++) { let chunkText = chunks[i] @@ -237,7 +298,6 @@ export class RagService { const estimatedTokens = this.estimateTokenCount(withPrefix) if (estimatedTokens > RagService.MAX_SAFE_TOKENS) { - // This should be rare - log for debugging if it's occurring frequently const prefixTokens = this.estimateTokenCount(prefixText) const maxTokensForText = RagService.MAX_SAFE_TOKENS - prefixTokens logger.warn( @@ -246,17 +306,30 @@ export class RagService { chunkText = this.truncateToTokenLimit(chunkText, maxTokensForText) } - logger.debug(`[RAG] Generating embedding for chunk ${i + 1}/${chunks.length}`) + prefixedChunks.push(RagService.SEARCH_DOCUMENT_PREFIX + chunkText) + } - const response = await ollamaClient.embeddings({ + // Batch embed chunks for performance + const embeddings: number[][] = [] + const batchSize = RagService.EMBEDDING_BATCH_SIZE + const totalBatches = Math.ceil(prefixedChunks.length / batchSize) + + for (let batchIdx = 0; batchIdx < totalBatches; batchIdx++) { + const batchStart = batchIdx * batchSize + const batch = prefixedChunks.slice(batchStart, batchStart + batchSize) + + logger.debug(`[RAG] Embedding batch ${batchIdx + 1}/${totalBatches} (${batch.length} chunks)`) + + const response = await ollamaClient.embed({ model: RagService.EMBEDDING_MODEL, - prompt: RagService.SEARCH_DOCUMENT_PREFIX + chunkText, + input: batch, }) - embeddings.push(response.embedding) + embeddings.push(...response.embeddings) if (onProgress) { - await onProgress(((i + 1) / chunks.length) * 100) + const progress = ((batchStart + batch.length) / prefixedChunks.length) * 100 + await onProgress(progress) } } @@ -395,14 +468,7 @@ export class RagService { deleteAfterEmbedding: boolean, batchOffset?: number, onProgress?: (percent: number) => Promise - ): Promise<{ - success: boolean - message: string - chunks?: number - hasMoreBatches?: boolean - articlesProcessed?: number - totalArticles?: number - }> { + ): Promise { const zimExtractionService = new ZIMExtractionService() // Process in batches to avoid lock timeout @@ -540,14 +606,7 @@ export class RagService { deleteAfterEmbedding: boolean = false, batchOffset?: number, onProgress?: (percent: number) => Promise - ): Promise<{ - success: boolean - message: string - chunks?: number - hasMoreBatches?: boolean - articlesProcessed?: number - totalArticles?: number - }> { + ): Promise { try { const fileType = determineFileType(filepath) logger.debug(`[RAG] Processing file: ${filepath} (detected type: ${fileType})`) @@ -631,14 +690,18 @@ export class RagService { return [] } - const allModels = await this.ollamaService.getModels(true) - const embeddingModel = allModels.find((model) => model.name === RagService.EMBEDDING_MODEL) + if (!this.embeddingModelVerified) { + const allModels = await this.ollamaService.getModels(true) + const embeddingModel = allModels.find((model) => model.name === RagService.EMBEDDING_MODEL) - if (!embeddingModel) { - logger.warn( - `[RAG] ${RagService.EMBEDDING_MODEL} not found. Cannot perform similarity search.` - ) - return [] + if (!embeddingModel) { + logger.warn( + `[RAG] ${RagService.EMBEDDING_MODEL} not found. Cannot perform similarity search.` + ) + this.embeddingModelVerified = false + return [] + } + this.embeddingModelVerified = true } // Preprocess query for better matching @@ -666,9 +729,9 @@ export class RagService { return [] } - const response = await ollamaClient.embeddings({ + const response = await ollamaClient.embed({ model: RagService.EMBEDDING_MODEL, - prompt: prefixedQuery, + input: [prefixedQuery], }) // Perform semantic search with a higher limit to enable reranking @@ -678,7 +741,7 @@ export class RagService { ) const searchResults = await this.qdrant!.search(RagService.CONTENT_COLLECTION_NAME, { - vector: response.embedding, + vector: response.embeddings[0], limit: searchLimit, score_threshold: scoreThreshold, with_payload: true, @@ -687,7 +750,7 @@ export class RagService { logger.debug(`[RAG] Found ${searchResults.length} results above threshold ${scoreThreshold}`) // Map results with metadata for reranking - const resultsWithMetadata = searchResults.map((result) => ({ + const resultsWithMetadata: RAGResult[] = searchResults.map((result) => ({ text: (result.payload?.text as string) || '', score: result.score, keywords: (result.payload?.keywords as string) || '', @@ -700,6 +763,7 @@ export class RagService { hierarchy: result.payload?.hierarchy as string | undefined, document_id: result.payload?.document_id as string | undefined, content_type: result.payload?.content_type as string | undefined, + source: result.payload?.source as string | undefined, })) const rerankedResults = this.rerankResults(resultsWithMetadata, keywords, query) @@ -711,8 +775,11 @@ export class RagService { ) }) + // Apply source diversity penalty to avoid all results from the same document + const diverseResults = this.applySourceDiversity(rerankedResults) + // Return top N results with enhanced metadata - return rerankedResults.slice(0, limit).map((result) => ({ + return diverseResults.slice(0, limit).map((result) => ({ text: result.text, score: result.finalScore, metadata: { @@ -748,34 +815,10 @@ export class RagService { * outweigh the overhead. */ private rerankResults( - results: Array<{ - text: string - score: number - keywords: string - chunk_index: number - created_at: number - article_title?: string - section_title?: string - full_title?: string - hierarchy?: string - document_id?: string - content_type?: string - }>, + results: Array, queryKeywords: string[], originalQuery: string - ): Array<{ - text: string - score: number - finalScore: number - chunk_index: number - created_at: number - article_title?: string - section_title?: string - full_title?: string - hierarchy?: string - document_id?: string - content_type?: string - }> { + ): Array { return results .map((result) => { let finalScore = result.score @@ -851,6 +894,37 @@ export class RagService { .sort((a, b) => b.finalScore - a.finalScore) } + /** + * Applies a diversity penalty so results from the same source are down-weighted. + * Uses greedy selection: for each result, apply 0.85^n penalty where n is the + * number of results already selected from the same source. + */ + private applySourceDiversity( + results: Array + ) { + const sourceCounts = new Map() + const DIVERSITY_PENALTY = 0.85 + + return results + .map((result) => { + const sourceKey = result.document_id || result.source || 'unknown' + const count = sourceCounts.get(sourceKey) || 0 + const penalty = Math.pow(DIVERSITY_PENALTY, count) + const diverseScore = result.finalScore * penalty + + sourceCounts.set(sourceKey, count + 1) + + if (count > 0) { + logger.debug( + `[RAG] Source diversity penalty for "${sourceKey}": ${result.finalScore.toFixed(4)} → ${diverseScore.toFixed(4)} (seen ${count}x)` + ) + } + + return { ...result, finalScore: diverseScore } + }) + .sort((a, b) => b.finalScore - a.finalScore) + } + /** * Retrieve all unique source files that have been stored in the knowledge base. * @returns Array of unique full source paths @@ -866,12 +940,12 @@ export class RagService { let offset: string | number | null | Record = null const batchSize = 100 - // Scroll through all points in the collection + // Scroll through all points in the collection (only fetch source field) do { const scrollResult = await this.qdrant!.scroll(RagService.CONTENT_COLLECTION_NAME, { limit: batchSize, offset: offset, - with_payload: true, + with_payload: ['source'], with_vector: false, }) diff --git a/admin/constants/ollama.ts b/admin/constants/ollama.ts index dd0a1a6..5581832 100644 --- a/admin/constants/ollama.ts +++ b/admin/constants/ollama.ts @@ -64,6 +64,16 @@ export const FALLBACK_RECOMMENDED_OLLAMA_MODELS: NomadOllamaModel[] = [ export const DEFAULT_QUERY_REWRITE_MODEL = 'qwen2.5:3b' // default to qwen2.5 for query rewriting with good balance of text task performance and resource usage +/** + * Adaptive RAG context limits based on model size. + * Smaller models get overwhelmed with too much context, so we cap it. + */ +export const RAG_CONTEXT_LIMITS: { maxParams: number; maxResults: number; maxTokens: number }[] = [ + { maxParams: 3, maxResults: 2, maxTokens: 1000 }, // 1-3B models + { maxParams: 8, maxResults: 4, maxTokens: 2500 }, // 4-8B models + { maxParams: Infinity, maxResults: 5, maxTokens: 0 }, // 13B+ (no cap) +] + export const SYSTEM_PROMPTS = { default: ` Format all responses using markdown for better readability. Vanilla markdown or GitHub-flavored markdown is preferred. @@ -113,7 +123,7 @@ Ensure that your suggestions are comma-seperated with no conjunctions like "and" Do not use line breaks, new lines, or extra spacing to separate the suggestions. Format: suggestion1, suggestion2, suggestion3 `, - title_generation: `You are a title generator. Given the start of a conversation, generate a concise, descriptive title under 60 characters. Return ONLY the title text with no quotes, punctuation wrapping, or extra formatting.`, + title_generation: `You are a title generator. Given the start of a conversation, generate a concise, descriptive title under 50 characters. Return ONLY the title text with no quotes, punctuation wrapping, or extra formatting.`, query_rewrite: ` You are a query rewriting assistant. Your task is to reformulate the user's latest question to include relevant context from the conversation history. diff --git a/admin/docs/release-notes.md b/admin/docs/release-notes.md index 9dc4db6..c2012e4 100644 --- a/admin/docs/release-notes.md +++ b/admin/docs/release-notes.md @@ -4,13 +4,16 @@ ### Features - **AI Assistant**: Added improved user guidance for troubleshooting GPU pass-through issues +- **AI Assistant**: The last used model is now automatically selected when a new chat is started - **Settings**: Nomad now automatically performs nightly checks for available app updates, and users can select and apply updates from the Apps page in Settings ### Bug Fixes - **Settings**: Fixed an issue where the AI Assistant settings page would be shown in navigation even if the AI Assistant was not installed, thus causing 404 errors when clicked - **Security**: Path traversal and SSRF mitigations +- **AI Assistant**: Fixed an issue that was causing intermittent failures saving chat session titles ### Improvements +- **AI Assistant**: Extensive performance improvements and improved RAG intelligence/context usage ## Version 1.28.0 - March 5, 2026 diff --git a/admin/types/rag.ts b/admin/types/rag.ts index f44dca9..1d429ea 100644 --- a/admin/types/rag.ts +++ b/admin/types/rag.ts @@ -5,3 +5,32 @@ export type EmbedJobWithProgress = { progress: number status: string } + +export type ProcessAndEmbedFileResponse = { + success: boolean + message: string + chunks?: number + hasMoreBatches?: boolean + articlesProcessed?: number + totalArticles?: number +} +export type ProcessZIMFileResponse = ProcessAndEmbedFileResponse + +export type RAGResult = { + text: string + score: number + keywords: string + chunk_index: number + created_at: number + article_title?: string + section_title?: string + full_title?: string + hierarchy?: string + document_id?: string + content_type?: string + source?: string +} + +export type RerankedRAGResult = Omit & { + finalScore: number +} \ No newline at end of file