Advanced RAG Architecture: Enterprise Patterns for Production Systems
By Gary Huynh (@gary_atruedev)
Introduction: Beyond Basic RAG
The journey from a proof-of-concept RAG system to a production-grade enterprise solution is fraught with challenges that most tutorials don't address. When you're dealing with millions of documents, strict security requirements, and sub-second response times, the simple "embed-store-retrieve" pattern breaks down quickly.
In this comprehensive guide, we'll explore advanced RAG architecture patterns that have been battle-tested in enterprise environments. We'll cover everything from hybrid search strategies to multi-tenant data isolation, performance optimization techniques, and cost-effective scaling strategies.
The Enterprise RAG Challenge
Enterprise RAG systems face unique challenges:
- Scale: Handling 10M+ documents across multiple languages and formats
- Security: Document-level access control and data isolation
- Performance: Sub-200ms retrieval latency at 99th percentile
- Accuracy: 95%+ relevance scores for critical queries
- Cost: Optimizing embedding and inference costs at scale
- Compliance: GDPR, HIPAA, and other regulatory requirements
Let's dive into the architectural patterns that address these challenges.
Hybrid Search Architecture: Vector + Keyword + Knowledge Graph
The most significant limitation of pure vector search is its inability to handle exact matches, acronyms, and domain-specific terminology effectively. Enterprise RAG systems require a hybrid approach that combines multiple search strategies.
The Three-Pillar Search Architecture
@Component
public class HybridSearchOrchestrator {
private final VectorSearchService vectorSearch;
private final ElasticsearchService keywordSearch;
private final Neo4jGraphService graphSearch;
private final SearchRankingService rankingService;
@Value("${search.vector.weight:0.5}")
private double vectorWeight;
@Value("${search.keyword.weight:0.3}")
private double keywordWeight;
@Value("${search.graph.weight:0.2}")
private double graphWeight;
public SearchResults hybridSearch(SearchQuery query, SearchContext context) {
long startTime = System.currentTimeMillis();
// Execute searches in parallel
CompletableFuture<List<VectorMatch>> vectorFuture =
CompletableFuture.supplyAsync(() ->
vectorSearch.search(query, context));
CompletableFuture<List<KeywordMatch>> keywordFuture =
CompletableFuture.supplyAsync(() ->
keywordSearch.search(query, context));
CompletableFuture<List<GraphMatch>> graphFuture =
CompletableFuture.supplyAsync(() ->
graphSearch.searchRelatedConcepts(query, context));
// Wait for all searches to complete
CompletableFuture.allOf(vectorFuture, keywordFuture, graphFuture).join();
// Merge and rank results
List<SearchResult> mergedResults = rankingService.mergeAndRank(
vectorFuture.join(),
keywordFuture.join(),
graphFuture.join(),
new RankingWeights(vectorWeight, keywordWeight, graphWeight)
);
return SearchResults.builder()
.results(mergedResults)
.searchLatency(System.currentTimeMillis() - startTime)
.searchStrategy("hybrid")
.build();
}
}
Vector Search Component
The vector search component handles semantic similarity matching:
@Service
public class VectorSearchService {
private final EmbeddingService embeddingService;
private final PineconeClient pineconeClient;
private final SearchOptimizer searchOptimizer;
private final SnippetExtractor snippetExtractor;
private static final int MAX_SNIPPET_LENGTH = 512; // illustrative snippet length
public List<VectorMatch> search(SearchQuery query, SearchContext context) {
// Generate query embedding with caching
float[] queryEmbedding = embeddingService.embedWithCache(
query.getText(),
query.getLanguage()
);
// Optimize search parameters based on query characteristics
SearchParameters params = searchOptimizer.optimize(query, context);
// Build metadata filter for multi-tenant isolation
Map<String, Object> metadataFilter = buildMetadataFilter(context);
// Execute vector search
QueryRequest request = QueryRequest.builder()
.vector(queryEmbedding)
.topK(params.getTopK())
.filter(metadataFilter)
.includeMetadata(true)
.namespace(context.getTenantId())
.build();
QueryResponse response = pineconeClient.query(request);
// Map matches to results and extract relevant snippets
return response.getMatches().stream()
.map(match -> VectorMatch.builder()
.documentId(match.getId())
.score(match.getScore())
.metadata(match.getMetadata())
.snippet(extractRelevantSnippet(match, query))
.build())
.collect(Collectors.toList());
}
private String extractRelevantSnippet(Match match, SearchQuery query) {
// Use attention-based snippet extraction
String fullText = match.getMetadata().get("content").toString();
return snippetExtractor.extractMostRelevant(
fullText,
query.getText(),
MAX_SNIPPET_LENGTH
);
}
}
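The buildMetadataFilter helper referenced above is not shown. A minimal sketch, assuming access groups and a document status flag live in the vector metadata (the operators follow Pinecone's filter syntax, but the exact schema is an assumption):
private Map<String, Object> buildMetadataFilter(SearchContext context) {
// Tenant isolation is already handled by the namespace; this filter narrows
// results to documents the caller's groups are allowed to see.
Map<String, Object> filter = new HashMap<>();
filter.put("allowedGroups", Map.of("$in", context.getUserGroups()));
filter.put("status", "ACTIVE"); // illustrative field: skip soft-deleted documents
return filter;
}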
Keyword Search Component
The keyword search component handles exact matches and boolean queries:
@Service
public class ElasticsearchService {
private final RestHighLevelClient elasticsearchClient;
private final QueryAnalyzer queryAnalyzer;
public List<KeywordMatch> search(SearchQuery query, SearchContext context) {
// Analyze query for special operators and phrases
AnalyzedQuery analyzed = queryAnalyzer.analyze(query.getText());
// Build Elasticsearch query
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
// Add must clauses for exact phrases
analyzed.getPhrases().forEach(phrase ->
boolQuery.must(QueryBuilders.matchPhraseQuery("content", phrase))
);
// Add should clauses for individual terms
analyzed.getTerms().forEach(term ->
boolQuery.should(QueryBuilders.matchQuery("content", term.getValue())
.boost(term.getImportance()))
);
// Add filter for tenant isolation
boolQuery.filter(QueryBuilders.termQuery("tenantId", context.getTenantId()));
// Add filter for access control
boolQuery.filter(QueryBuilders.termsQuery("allowedGroups",
context.getUserGroups().toArray(new String[0])));
// Execute search
SearchRequest searchRequest = new SearchRequest("documents")
.source(SearchSourceBuilder.searchSource()
.query(boolQuery)
.size(query.getLimit())
.highlight(HighlightBuilder.highlight()
.field("content")
.preTags("<mark>")
.postTags("</mark>")));
SearchResponse response = elasticsearchClient.search(searchRequest,
RequestOptions.DEFAULT);
return Arrays.stream(response.getHits().getHits())
.map(this::convertToKeywordMatch)
.collect(Collectors.toList());
}
}
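The convertToKeywordMatch mapper is omitted above. One plausible implementation, assuming KeywordMatch exposes a builder with the fields used later by the ranking service:
private KeywordMatch convertToKeywordMatch(SearchHit hit) {
// Prefer the highlighted fragment as the snippet; fall back to the raw content field.
String snippet = hit.getHighlightFields().containsKey("content")
? hit.getHighlightFields().get("content").fragments()[0].string()
: (String) hit.getSourceAsMap().get("content");
return KeywordMatch.builder()
.documentId(hit.getId())
.score(hit.getScore())
.snippet(snippet)
.build();
}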
Knowledge Graph Component
The knowledge graph component provides context-aware search through entity relationships:
@Service
public class Neo4jGraphService {
private final Driver neo4jDriver;
private final EntityExtractor entityExtractor;
public List<GraphMatch> searchRelatedConcepts(SearchQuery query,
SearchContext context) {
// Extract entities from query
List<Entity> entities = entityExtractor.extract(query.getText());
if (entities.isEmpty()) {
return Collections.emptyList();
}
try (Session session = neo4jDriver.session()) {
// Build Cypher query for related concepts
String cypher = """
MATCH (e:Entity)-[r:RELATED_TO*1..2]-(related:Entity)
WHERE e.name IN $entityNames
AND e.tenantId = $tenantId
AND related.tenantId = $tenantId
WITH related,
min(length(r)) as distance,
count(distinct e) as connectionCount
MATCH (related)-[:APPEARS_IN]->(doc:Document)
WHERE doc.tenantId = $tenantId
AND any(group IN $userGroups WHERE group IN doc.allowedGroups)
RETURN distinct doc.id as documentId,
doc.title as title,
doc.content as content,
collect(distinct related.name) as relatedEntities,
min(distance) as minDistance,
sum(connectionCount) as totalConnections
ORDER BY totalConnections DESC, minDistance ASC
LIMIT $limit
""";
Map<String, Object> params = Map.of(
"entityNames", entities.stream()
.map(Entity::getName)
.collect(Collectors.toList()),
"tenantId", context.getTenantId(),
"userGroups", context.getUserGroups(),
"limit", query.getLimit()
);
return session.run(cypher, params)
.list(record -> GraphMatch.builder()
.documentId(record.get("documentId").asString())
.title(record.get("title").asString())
.relatedEntities(record.get("relatedEntities").asList(Value::asString))
.relationshipScore(calculateScore(
record.get("minDistance").asInt(),
record.get("totalConnections").asInt()
))
.build());
}
}
}
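calculateScore is left undefined above. A simple sketch that rewards closer and better-connected entities (the exact weighting is an assumption) could be:
private double calculateScore(int minDistance, int totalConnections) {
// Directly related entities (distance 1) score higher than second-hop ones,
// and more connecting entities increase confidence. Constants are illustrative.
double distanceFactor = 1.0 / minDistance;
double connectionFactor = Math.log1p(totalConnections);
return distanceFactor * connectionFactor;
}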
Advanced Ranking and Fusion
The ranking service implements sophisticated result fusion algorithms:
@Service
public class SearchRankingService {
private final RelevanceScorer relevanceScorer;
private final DiversityOptimizer diversityOptimizer;
public List<SearchResult> mergeAndRank(
List<VectorMatch> vectorMatches,
List<KeywordMatch> keywordMatches,
List<GraphMatch> graphMatches,
RankingWeights weights) {
// Create a map to aggregate scores by document ID
Map<String, AggregatedScore> scoreMap = new HashMap<>();
// Process vector matches (rank = 1-based position in the vector result list)
for (int i = 0; i < vectorMatches.size(); i++) {
VectorMatch match = vectorMatches.get(i);
scoreMap.computeIfAbsent(match.getDocumentId(),
k -> new AggregatedScore())
.addVectorScore(match.getScore(), i + 1);
}
// Process keyword matches with BM25 normalization
for (int i = 0; i < keywordMatches.size(); i++) {
KeywordMatch match = keywordMatches.get(i);
scoreMap.computeIfAbsent(match.getDocumentId(),
k -> new AggregatedScore())
.addKeywordScore(normalizeBM25Score(match.getScore()), i + 1);
}
// Process graph matches
for (int i = 0; i < graphMatches.size(); i++) {
GraphMatch match = graphMatches.get(i);
scoreMap.computeIfAbsent(match.getDocumentId(),
k -> new AggregatedScore())
.addGraphScore(match.getRelationshipScore(), i + 1);
}
// Calculate final scores with reciprocal rank fusion
List<SearchResult> results = scoreMap.entrySet().stream()
.map(entry -> {
AggregatedScore aggScore = entry.getValue();
double finalScore = calculateRRFScore(aggScore, weights);
return SearchResult.builder()
.documentId(entry.getKey())
.score(finalScore)
.vectorScore(aggScore.getVectorScore())
.keywordScore(aggScore.getKeywordScore())
.graphScore(aggScore.getGraphScore())
.build();
})
.sorted((a, b) -> Double.compare(b.getScore(), a.getScore()))
.collect(Collectors.toList());
// Apply diversity optimization to prevent result clustering
return diversityOptimizer.optimize(results);
}
private double calculateRRFScore(AggregatedScore aggScore, RankingWeights weights) {
// Reciprocal Rank Fusion with weighted components
double k = 60.0; // RRF constant
double rrfVector = aggScore.hasVectorScore() ?
1.0 / (k + aggScore.getVectorRank()) : 0;
double rrfKeyword = aggScore.hasKeywordScore() ?
1.0 / (k + aggScore.getKeywordRank()) : 0;
double rrfGraph = aggScore.hasGraphScore() ?
1.0 / (k + aggScore.getGraphRank()) : 0;
return weights.getVectorWeight() * rrfVector +
weights.getKeywordWeight() * rrfKeyword +
weights.getGraphWeight() * rrfGraph;
}
}
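The AggregatedScore accumulator is not shown. A minimal sketch that records both the raw score and the rank consumed by calculateRRFScore (field and method names are assumptions) might look like this:
public class AggregatedScore {
private Double vectorScore;
private Double keywordScore;
private Double graphScore;
private int vectorRank;
private int keywordRank;
private int graphRank;
public void addVectorScore(double score, int rank) { this.vectorScore = score; this.vectorRank = rank; }
public void addKeywordScore(double score, int rank) { this.keywordScore = score; this.keywordRank = rank; }
public void addGraphScore(double score, int rank) { this.graphScore = score; this.graphRank = rank; }
public boolean hasVectorScore() { return vectorScore != null; }
public boolean hasKeywordScore() { return keywordScore != null; }
public boolean hasGraphScore() { return graphScore != null; }
public double getVectorScore() { return vectorScore != null ? vectorScore : 0.0; }
public double getKeywordScore() { return keywordScore != null ? keywordScore : 0.0; }
public double getGraphScore() { return graphScore != null ? graphScore : 0.0; }
public int getVectorRank() { return vectorRank; }
public int getKeywordRank() { return keywordRank; }
public int getGraphRank() { return graphRank; }
}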
Multi-Tenant Data Isolation and Security
Enterprise RAG systems must provide strict data isolation between tenants while maintaining performance. This requires careful architectural decisions at every layer.
Namespace-Based Isolation Strategy
@Configuration
public class MultiTenantConfiguration {
@Bean
public TenantContextHolder tenantContextHolder() {
return new ThreadLocalTenantContextHolder();
}
@Bean
public TenantIsolationInterceptor tenantIsolationInterceptor() {
return new TenantIsolationInterceptor();
}
@Component
public static class TenantIsolationInterceptor implements HandlerInterceptor {
@Autowired
private TenantContextHolder tenantContextHolder;
@Autowired
private JwtTokenProvider jwtTokenProvider;
@Autowired
private TenantService tenantService;
@Override
public boolean preHandle(HttpServletRequest request,
HttpServletResponse response,
Object handler) {
// Extract tenant ID from JWT token
String token = extractToken(request);
Claims claims = jwtTokenProvider.validateToken(token);
String tenantId = claims.get("tenantId", String.class);
// Set tenant context for this request
tenantContextHolder.setTenantId(tenantId);
// Validate tenant is active
if (!tenantService.isActive(tenantId)) {
response.setStatus(HttpServletResponse.SC_FORBIDDEN);
return false;
}
return true;
}
@Override
public void afterCompletion(HttpServletRequest request,
HttpServletResponse response,
Object handler,
Exception ex) {
// Clear tenant context after request
tenantContextHolder.clear();
}
}
}
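The interceptor still has to be registered with Spring MVC before it runs. A minimal sketch using a WebMvcConfigurer (the path pattern is an assumption):
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {
private final TenantIsolationInterceptor tenantIsolationInterceptor;
public WebMvcConfig(TenantIsolationInterceptor tenantIsolationInterceptor) {
this.tenantIsolationInterceptor = tenantIsolationInterceptor;
}
@Override
public void addInterceptors(InterceptorRegistry registry) {
// Apply tenant isolation to every API endpoint.
registry.addInterceptor(tenantIsolationInterceptor)
.addPathPatterns("/api/**");
}
}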
Document-Level Access Control
Implementing fine-grained access control at the document level:
@Service
public class DocumentAccessControlService {
private final AccessPolicyRepository policyRepository;
private final UserGroupService userGroupService;
private final AuditLogger auditLogger;
public boolean canAccess(String documentId, String userId, String tenantId) {
// Get user's groups and permissions
Set<String> userGroups = userGroupService.getUserGroups(userId, tenantId);
Set<Permission> userPermissions = getUserPermissions(userId, tenantId);
// Get document access policy
DocumentAccessPolicy policy = policyRepository
.findByDocumentIdAndTenantId(documentId, tenantId)
.orElse(DocumentAccessPolicy.DEFAULT_DENY);
// Check access rules
boolean hasAccess = evaluateAccessRules(policy, userGroups, userPermissions);
// Audit access attempt
auditLogger.logAccessAttempt(
userId,
documentId,
tenantId,
hasAccess
);
return hasAccess;
}
private boolean evaluateAccessRules(DocumentAccessPolicy policy,
Set<String> userGroups,
Set<Permission> userPermissions) {
// Check explicit denials first
if (policy.getDeniedGroups().stream()
.anyMatch(userGroups::contains)) {
return false;
}
// Check required permissions
if (!userPermissions.containsAll(policy.getRequiredPermissions())) {
return false;
}
// Check allowed groups
if (policy.getAllowedGroups().isEmpty()) {
// No specific groups means allow all (after passing other checks)
return true;
}
return policy.getAllowedGroups().stream()
.anyMatch(userGroups::contains);
}
}
Encryption at Rest and in Transit
Implementing field-level encryption for sensitive data:
@Component
public class FieldLevelEncryptionService {
private final AWSKMS kmsClient;
private final Map<String, DataKey> dataKeyCache = new ConcurrentHashMap<>();
@Value("${encryption.master.key.id}")
private String masterKeyId;
public EncryptedDocument encryptDocument(Document document, String tenantId) {
// Get or create data encryption key for tenant
DataKey dataKey = getOrCreateDataKey(tenantId);
// Encrypt sensitive fields
EncryptedDocument encrypted = EncryptedDocument.builder()
.id(document.getId())
.tenantId(tenantId)
.encryptedContent(encryptField(document.getContent(), dataKey))
.encryptedMetadata(encryptField(
jsonSerializer.serialize(document.getMetadata()),
dataKey
))
.publicMetadata(document.getPublicMetadata()) // Not encrypted
.encryptionKeyId(dataKey.getKeyId())
.build();
return encrypted;
}
private DataKey getOrCreateDataKey(String tenantId) {
return dataKeyCache.computeIfAbsent(tenantId, tid -> {
// Generate new data key using KMS
GenerateDataKeyRequest request = new GenerateDataKeyRequest()
.withKeyId(masterKeyId)
.withKeySpec(DataKeySpec.AES_256)
.withEncryptionContext(Map.of("tenantId", tid));
GenerateDataKeyResult result = kmsClient.generateDataKey(request);
return DataKey.builder()
.keyId(UUID.randomUUID().toString())
.plaintextKey(result.getPlaintext())
.encryptedKey(result.getCiphertextBlob())
.tenantId(tid)
.build();
});
}
private byte[] encryptField(String plaintext, DataKey dataKey) {
try {
Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
// Generate random IV
byte[] iv = new byte[12];
SecureRandom.getInstanceStrong().nextBytes(iv);
// Initialize cipher
SecretKeySpec keySpec = new SecretKeySpec(
dataKey.getPlaintextKey().array(),
"AES"
);
GCMParameterSpec gcmSpec = new GCMParameterSpec(128, iv);
cipher.init(Cipher.ENCRYPT_MODE, keySpec, gcmSpec);
// Encrypt data
byte[] ciphertext = cipher.doFinal(plaintext.getBytes(StandardCharsets.UTF_8));
// Combine IV and ciphertext
byte[] result = new byte[iv.length + ciphertext.length];
System.arraycopy(iv, 0, result, 0, iv.length);
System.arraycopy(ciphertext, 0, result, iv.length, ciphertext.length);
return result;
} catch (Exception e) {
throw new EncryptionException("Failed to encrypt field", e);
}
}
}
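Decryption is the mirror image of encryptField. A sketch of the corresponding decryptField helper, assuming the same IV-prefix layout used above:
private String decryptField(byte[] encrypted, DataKey dataKey) {
try {
Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
// The first 12 bytes are the IV written by encryptField.
byte[] iv = Arrays.copyOfRange(encrypted, 0, 12);
byte[] ciphertext = Arrays.copyOfRange(encrypted, 12, encrypted.length);
SecretKeySpec keySpec = new SecretKeySpec(dataKey.getPlaintextKey().array(), "AES");
cipher.init(Cipher.DECRYPT_MODE, keySpec, new GCMParameterSpec(128, iv));
return new String(cipher.doFinal(ciphertext), StandardCharsets.UTF_8);
} catch (Exception e) {
throw new EncryptionException("Failed to decrypt field", e);
}
}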
Performance Optimization for Millions of Documents
Scaling RAG systems to handle millions of documents requires sophisticated optimization strategies at every layer.
Intelligent Caching Architecture
@Configuration
public class CachingConfiguration {
@Bean
public CacheManager cacheManager(RedisConnectionFactory connectionFactory) {
RedisCacheConfiguration config = RedisCacheConfiguration
.defaultCacheConfig()
.entryTtl(Duration.ofMinutes(15))
.serializeKeysWith(RedisSerializationContext.SerializationPair
.fromSerializer(new StringRedisSerializer()))
.serializeValuesWith(RedisSerializationContext.SerializationPair
.fromSerializer(new GenericJackson2JsonRedisSerializer()));
return RedisCacheManager.builder(connectionFactory)
.cacheDefaults(config)
.withCacheConfiguration("embeddings",
config.entryTtl(Duration.ofHours(24)))
.withCacheConfiguration("search-results",
config.entryTtl(Duration.ofMinutes(5)))
.build();
}
@Component
public class AdaptiveCacheService {
private final RedisTemplate<String, Object> redisTemplate;
private final CacheMetricsCollector metricsCollector;
public <T> T getWithAdaptiveTTL(String key,
Supplier<T> loader,
Class<T> type) {
// Check cache
T cached = (T) redisTemplate.opsForValue().get(key);
if (cached != null) {
metricsCollector.recordHit(key);
return cached;
}
metricsCollector.recordMiss(key);
// Load data
T value = loader.get();
// Calculate adaptive TTL based on access patterns
Duration ttl = calculateAdaptiveTTL(key);
// Store in cache
redisTemplate.opsForValue().set(key, value, ttl);
return value;
}
private Duration calculateAdaptiveTTL(String key) {
// Get access frequency for this key pattern
AccessPattern pattern = metricsCollector.getAccessPattern(key);
if (pattern.getHitRate() > 0.8) {
// High hit rate - extend TTL
return Duration.ofHours(24);
} else if (pattern.getHitRate() > 0.5) {
// Medium hit rate
return Duration.ofHours(6);
} else {
// Low hit rate - short TTL
return Duration.ofMinutes(30);
}
}
}
}
Embedding Generation Pipeline
Optimizing embedding generation for large document volumes:
@Service
public class BatchEmbeddingService {
private final OpenAIService openAIService;
private final RetryTemplate retryTemplate;
private final ExecutorService embeddingExecutor;
private final CircuitBreaker circuitBreaker;
public BatchEmbeddingService(OpenAIService openAIService, RetryTemplate retryTemplate) {
this.openAIService = openAIService;
this.retryTemplate = retryTemplate;
this.embeddingExecutor = Executors.newFixedThreadPool(
Runtime.getRuntime().availableProcessors() * 2,
new ThreadFactoryBuilder()
.setNameFormat("embedding-worker-%d")
.build()
);
this.circuitBreaker = CircuitBreaker.ofDefaults("embedding-service");
}
public CompletableFuture<List<DocumentEmbedding>> generateEmbeddings(
List<Document> documents) {
// Chunk documents for optimal batch size
List<List<Document>> batches = Lists.partition(documents, 100);
// Process batches in parallel with circuit breaker
List<CompletableFuture<List<DocumentEmbedding>>> futures =
batches.stream()
.map(batch -> CompletableFuture.supplyAsync(
() -> circuitBreaker.executeSupplier(
() -> processBatch(batch)
),
embeddingExecutor
))
.collect(Collectors.toList());
// Combine results
return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
.thenApply(v -> futures.stream()
.map(CompletableFuture::join)
.flatMap(List::stream)
.collect(Collectors.toList()));
}
private List<DocumentEmbedding> processBatch(List<Document> batch) {
// Prepare texts for embedding
List<String> texts = batch.stream()
.map(doc -> truncateToTokenLimit(doc.getContent(), 8192))
.collect(Collectors.toList());
// Call embedding API with retry logic
EmbeddingResponse response = retryTemplate.execute(context ->
openAIService.createEmbeddings(
EmbeddingRequest.builder()
.model("text-embedding-3-large")
.input(texts)
.dimensions(1536) // Optimal for performance/quality
.build()
)
);
// Map embeddings back to documents
List<DocumentEmbedding> embeddings = new ArrayList<>();
for (int i = 0; i < batch.size(); i++) {
embeddings.add(DocumentEmbedding.builder()
.documentId(batch.get(i).getId())
.embedding(response.getData().get(i).getEmbedding())
.model("text-embedding-3-large")
.dimensions(1536)
.generatedAt(Instant.now())
.build());
}
return embeddings;
}
}
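truncateToTokenLimit is not shown. Without pulling in the model's exact tokenizer, a conservative character-based approximation (roughly four characters per token for English text) works as a sketch:
private String truncateToTokenLimit(String content, int maxTokens) {
// Rough heuristic: ~4 characters per token for English text.
// A production implementation would use the embedding model's own tokenizer.
int maxChars = maxTokens * 4;
return content.length() <= maxChars ? content : content.substring(0, maxChars);
}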
Query Optimization Strategies
Advanced query optimization for complex searches:
@Service
public class QueryOptimizer {
private final QueryPlanCache queryPlanCache;
private final QueryStatistics queryStats;
public OptimizedQuery optimize(SearchQuery query, SearchContext context) {
// Check if we have a cached query plan
String planKey = generatePlanKey(query, context);
QueryPlan cachedPlan = queryPlanCache.get(planKey);
if (cachedPlan != null && !cachedPlan.isStale()) {
return executeWithPlan(query, cachedPlan);
}
// Analyze query characteristics
QueryAnalysis analysis = analyzeQuery(query);
// Generate optimized query plan
QueryPlan plan = QueryPlan.builder()
.useVectorSearch(analysis.hasSemanticsIntent())
.useKeywordSearch(analysis.hasExactMatchIntent())
.useGraphSearch(analysis.hasRelationshipIntent())
.vectorIndexHint(selectOptimalVectorIndex(analysis, context))
.prefetchRelated(analysis.needsRelatedDocuments())
.parallelism(calculateOptimalParallelism(analysis))
.build();
// Cache the plan
queryPlanCache.put(planKey, plan);
return executeWithPlan(query, plan);
}
private String selectOptimalVectorIndex(QueryAnalysis analysis,
SearchContext context) {
// Select index based on query characteristics
if (analysis.isShortQuery() && context.isHighPrecisionRequired()) {
return "high_precision_index"; // Smaller embeddings, more accurate
} else if (analysis.isLongQuery() && context.isSpeedCritical()) {
return "fast_index"; // Quantized embeddings
} else {
return "balanced_index"; // Default balanced approach
}
}
private int calculateOptimalParallelism(QueryAnalysis analysis) {
// Dynamic parallelism based on query complexity
int baseParallelism = Runtime.getRuntime().availableProcessors();
if (analysis.getComplexity() == QueryComplexity.HIGH) {
return Math.min(baseParallelism * 2, 16);
} else if (analysis.getComplexity() == QueryComplexity.LOW) {
return Math.max(baseParallelism / 2, 1);
} else {
return baseParallelism;
}
}
}
Index Management and Optimization
Managing vector indexes for optimal performance:
@Component
public class VectorIndexManager {
private final PineconeClient pineconeClient;
private final IndexMetricsCollector metricsCollector;
private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
@PostConstruct
public void initialize() {
// Schedule regular index optimization
scheduler.scheduleAtFixedRate(
this::optimizeIndexes,
1, 1, TimeUnit.HOURS
);
}
public void optimizeIndexes() {
List<IndexInfo> indexes = pineconeClient.listIndexes();
for (IndexInfo index : indexes) {
IndexMetrics metrics = metricsCollector.getMetrics(index.getName());
if (shouldOptimize(metrics)) {
optimizeIndex(index, metrics);
}
}
}
private boolean shouldOptimize(IndexMetrics metrics) {
return metrics.getFragmentation() > 0.3 ||
metrics.getQueryLatencyP99() > 200 ||
metrics.getIndexSizeRatio() > 1.5;
}
private void optimizeIndex(IndexInfo index, IndexMetrics metrics) {
log.info("Optimizing index: {} with metrics: {}",
index.getName(), metrics);
if (metrics.getFragmentation() > 0.3) {
// Trigger index compaction
pineconeClient.compactIndex(index.getName());
}
if (metrics.getQueryLatencyP99() > 200) {
// Adjust replica count based on load
int currentReplicas = index.getReplicas();
int targetReplicas = calculateOptimalReplicas(metrics);
if (targetReplicas != currentReplicas) {
pineconeClient.updateIndex(
index.getName(),
UpdateIndexRequest.builder()
.replicas(targetReplicas)
.build()
);
}
}
if (metrics.getIndexSizeRatio() > 1.5) {
// Consider re-indexing with better parameters
scheduleReindexing(index, metrics);
}
}
private int calculateOptimalReplicas(IndexMetrics metrics) {
double qps = metrics.getQueriesPerSecond();
double avgLatency = metrics.getAvgQueryLatency();
// Calculate required replicas for target latency
double targetLatency = 50.0; // 50ms target
int replicas = (int) Math.ceil(qps * avgLatency / targetLatency / 1000);
// Apply bounds
return Math.max(1, Math.min(replicas, 10));
}
}
Real-Time vs Batch Processing Architectures
Enterprise RAG systems must balance real-time responsiveness with efficient batch processing for large-scale operations.
Event-Driven Architecture for Real-Time Updates
@Configuration
@EnableKafka
public class RealTimeProcessingConfig {
@Bean
public KafkaTemplate<String, DocumentEvent> kafkaTemplate() {
return new KafkaTemplate<>(producerFactory());
}
@Component
public class DocumentEventProcessor {
private final KafkaTemplate<String, DocumentEvent> kafkaTemplate;
private final EmbeddingService embeddingService;
private final IndexingService indexingService;
@EventListener
@Async
public void handleDocumentCreated(DocumentCreatedEvent event) {
// Send to Kafka for processing
DocumentEvent kafkaEvent = DocumentEvent.builder()
.eventType(EventType.CREATED)
.documentId(event.getDocumentId())
.tenantId(event.getTenantId())
.content(event.getContent())
.metadata(event.getMetadata())
.timestamp(Instant.now())
.build();
kafkaTemplate.send("document-events",
event.getTenantId(),
kafkaEvent);
}
@KafkaListener(topics = "document-events",
containerFactory = "kafkaListenerContainerFactory")
public void processDocumentEvent(DocumentEvent event) {
MDC.put("tenantId", event.getTenantId());
MDC.put("documentId", event.getDocumentId());
try {
switch (event.getEventType()) {
case CREATED:
case UPDATED:
processDocumentUpdate(event);
break;
case DELETED:
processDocumentDeletion(event);
break;
}
} catch (Exception e) {
log.error("Failed to process document event", e);
// Send to DLQ for retry
sendToDeadLetterQueue(event, e);
} finally {
MDC.clear();
}
}
private void processDocumentUpdate(DocumentEvent event) {
// Generate embedding in real-time
float[] embedding = embeddingService.generateEmbedding(
event.getContent()
);
// Update vector index immediately
indexingService.upsertVector(
event.getDocumentId(),
embedding,
event.getMetadata(),
event.getTenantId()
);
// Update search index asynchronously
CompletableFuture.runAsync(() ->
indexingService.updateSearchIndex(event)
);
// Update knowledge graph if needed
if (event.getMetadata().containsKey("entities")) {
updateKnowledgeGraph(event);
}
}
}
}
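sendToDeadLetterQueue is referenced in the listener but not defined. A minimal sketch that republishes the failed event to a dead-letter topic (the topic name and header key are assumptions):
private void sendToDeadLetterQueue(DocumentEvent event, Exception error) {
// Republish the failed event so it can be inspected and replayed later.
ProducerRecord<String, DocumentEvent> record =
new ProducerRecord<>("document-events-dlq", event.getTenantId(), event);
record.headers().add("x-error-message",
String.valueOf(error.getMessage()).getBytes(StandardCharsets.UTF_8));
kafkaTemplate.send(record);
}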
Batch Processing Pipeline
Implementing efficient batch processing for bulk operations:
@Configuration
@EnableBatchProcessing
public class BatchProcessingConfig {
@Bean
public Job documentProcessingJob(JobBuilderFactory jobBuilderFactory,
StepBuilderFactory stepBuilderFactory) {
return jobBuilderFactory.get("documentProcessingJob")
.incrementer(new RunIdIncrementer())
.listener(jobExecutionListener())
.flow(documentExtractionStep(stepBuilderFactory))
.next(embeddingGenerationStep(stepBuilderFactory))
.next(indexingStep(stepBuilderFactory))
.next(validationStep(stepBuilderFactory))
.end()
.build();
}
@Bean
public Step embeddingGenerationStep(StepBuilderFactory stepBuilderFactory) {
return stepBuilderFactory.get("embeddingGenerationStep")
.<ExtractedDocument, EmbeddedDocument>chunk(1000)
.reader(documentReader())
.processor(embeddingProcessor())
.writer(embeddingWriter())
.faultTolerant()
.retryLimit(3)
.retry(TransientException.class)
.skipLimit(100)
.skip(CorruptedDocumentException.class)
.listener(stepExecutionListener())
.taskExecutor(taskExecutor())
.throttleLimit(10) // Limit concurrent API calls
.build();
}
@Component
public class BatchEmbeddingProcessor
implements ItemProcessor<ExtractedDocument, EmbeddedDocument> {
private final EmbeddingService embeddingService;
private final RateLimiter rateLimiter;
private final SentenceDetector sentenceDetector;
private final Tokenizer tokenizer;
public BatchEmbeddingProcessor(EmbeddingService embeddingService,
SentenceDetector sentenceDetector, Tokenizer tokenizer) {
this.embeddingService = embeddingService;
this.sentenceDetector = sentenceDetector;
this.tokenizer = tokenizer;
// 100 requests per second rate limit
this.rateLimiter = RateLimiter.create(100.0);
}
@Override
public EmbeddedDocument process(ExtractedDocument document) {
long startTime = System.currentTimeMillis();
// Rate limit API calls
rateLimiter.acquire();
try {
// Chunk document if too large
List<String> chunks = chunkDocument(document.getContent());
// Generate embeddings for each chunk
List<float[]> embeddings = chunks.parallelStream()
.map(chunk -> embeddingService.generateEmbedding(chunk))
.collect(Collectors.toList());
// Aggregate embeddings (weighted average based on chunk importance)
float[] aggregatedEmbedding = aggregateEmbeddings(
embeddings,
chunks
);
return EmbeddedDocument.builder()
.documentId(document.getId())
.tenantId(document.getTenantId())
.embedding(aggregatedEmbedding)
.chunks(chunks.size())
.processingTime(System.currentTimeMillis() - startTime)
.build();
} catch (Exception e) {
log.error("Failed to process document: {}", document.getId(), e);
throw new ProcessingException("Embedding generation failed", e);
}
}
private List<String> chunkDocument(String content) {
// Smart chunking with overlap
int chunkSize = 1000; // tokens
int overlap = 200; // tokens
List<String> chunks = new ArrayList<>();
List<String> sentences = sentenceDetector.detect(content);
StringBuilder currentChunk = new StringBuilder();
int currentTokens = 0;
for (String sentence : sentences) {
int sentenceTokens = tokenizer.countTokens(sentence);
if (currentTokens + sentenceTokens > chunkSize) {
chunks.add(currentChunk.toString());
// Start new chunk with overlap
currentChunk = new StringBuilder();
currentTokens = 0;
// Add overlapping content
int overlapStart = Math.max(0, sentences.indexOf(sentence) - 2);
for (int i = overlapStart; i < sentences.indexOf(sentence); i++) {
currentChunk.append(sentences.get(i)).append(" ");
currentTokens += tokenizer.countTokens(sentences.get(i));
}
}
currentChunk.append(sentence).append(" ");
currentTokens += sentenceTokens;
}
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString());
}
return chunks;
}
}
}
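aggregateEmbeddings is left undefined. One simple sketch weights each chunk embedding by the chunk's share of the total text length and averages them; the weighting scheme is an assumption:
private float[] aggregateEmbeddings(List<float[]> embeddings, List<String> chunks) {
int dimensions = embeddings.get(0).length;
float[] aggregated = new float[dimensions];
// Weight each chunk embedding by the chunk's share of the total text length.
double totalLength = chunks.stream().mapToInt(String::length).sum();
for (int i = 0; i < embeddings.size(); i++) {
double weight = chunks.get(i).length() / totalLength;
for (int d = 0; d < dimensions; d++) {
aggregated[d] += (float) (embeddings.get(i)[d] * weight);
}
}
return aggregated;
}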
Hybrid Processing Strategy
Implementing intelligent routing between real-time and batch processing:
@Service
public class ProcessingOrchestrator {
private final RealTimeProcessor realTimeProcessor;
private final BatchProcessor batchProcessor;
private final ProcessingMetrics metrics;
@Value("${processing.realtime.threshold:10}")
private int realtimeThreshold;
@Value("${processing.batch.window:300}")
private int batchWindowSeconds;
public void processDocuments(List<Document> documents, ProcessingContext context) {
// Analyze processing requirements
ProcessingStrategy strategy = determineStrategy(documents, context);
switch (strategy) {
case REAL_TIME:
processRealTime(documents, context);
break;
case BATCH:
processBatch(documents, context);
break;
case HYBRID:
processHybrid(documents, context);
break;
}
}
private ProcessingStrategy determineStrategy(List<Document> documents,
ProcessingContext context) {
int documentCount = documents.size();
boolean isUrgent = context.getPriority() == Priority.HIGH;
double systemLoad = metrics.getCurrentLoad();
if (documentCount <= realtimeThreshold && isUrgent) {
return ProcessingStrategy.REAL_TIME;
} else if (documentCount > 1000 || systemLoad > 0.8) {
return ProcessingStrategy.BATCH;
} else {
return ProcessingStrategy.HYBRID;
}
}
private void processHybrid(List<Document> documents, ProcessingContext context) {
// Separate documents by priority
Map<Priority, List<Document>> priorityGroups = documents.stream()
.collect(Collectors.groupingBy(doc ->
determinePriority(doc, context)
));
// Process high priority in real-time
if (priorityGroups.containsKey(Priority.HIGH)) {
realTimeProcessor.process(
priorityGroups.get(Priority.HIGH),
context
);
}
// Queue others for batch processing
List<Document> batchDocuments = new ArrayList<>();
batchDocuments.addAll(priorityGroups.getOrDefault(Priority.MEDIUM,
Collections.emptyList()));
batchDocuments.addAll(priorityGroups.getOrDefault(Priority.LOW,
Collections.emptyList()));
if (!batchDocuments.isEmpty()) {
batchProcessor.queueForProcessing(batchDocuments, context);
}
}
}
Advanced Retrieval Strategies
Enterprise retrieval needs to move beyond simple similarity search toward sophisticated, multi-stage strategies.
Multi-Stage Retrieval Pipeline
@Service
public class MultiStageRetriever {
private final InitialRetriever initialRetriever;
private final ReRanker reRanker;
private final ContextEnricher contextEnricher;
private final RelevanceFeedback relevanceFeedback;
public RetrievalResult retrieve(Query query, RetrievalContext context) {
// Stage 1: Initial broad retrieval
List<Document> candidates = initialRetriever.retrieve(
query,
context.withTopK(1000) // Cast wide net
);
// Stage 2: Re-ranking with cross-encoder
List<RankedDocument> reranked = reRanker.rerank(
query,
candidates,
context.withTopK(100) // Narrow down
);
// Stage 3: Context enrichment
List<EnrichedDocument> enriched = contextEnricher.enrich(
reranked,
context
);
// Stage 4: Final scoring with relevance feedback
List<ScoredDocument> finalResults = relevanceFeedback.applyFeedback(
enriched,
query,
context
);
return RetrievalResult.builder()
.documents(finalResults.stream()
.limit(context.getTopK())
.collect(Collectors.toList()))
.totalCandidates(candidates.size())
.retrievalStrategy("multi-stage")
.stages(List.of("initial", "rerank", "enrich", "feedback"))
.build();
}
}
@Component
public class CrossEncoderReRanker implements ReRanker {
private final CrossEncoderModel model;
private final ExecutorService executorService;
@Override
public List<RankedDocument> rerank(Query query,
List<Document> candidates,
RetrievalContext context) {
// Batch candidates for efficient processing
List<List<Document>> batches = Lists.partition(candidates, 32);
// Process batches in parallel
List<CompletableFuture<List<RankedDocument>>> futures =
batches.stream()
.map(batch -> CompletableFuture.supplyAsync(
() -> rerankBatch(query, batch),
executorService
))
.collect(Collectors.toList());
// Combine and sort results
return futures.stream()
.map(CompletableFuture::join)
.flatMap(List::stream)
.sorted(Comparator.comparing(RankedDocument::getScore).reversed())
.collect(Collectors.toList());
}
private List<RankedDocument> rerankBatch(Query query, List<Document> batch) {
// Prepare input pairs
List<String[]> pairs = batch.stream()
.map(doc -> new String[]{query.getText(), doc.getContent()})
.collect(Collectors.toList());
// Get cross-encoder scores
float[] scores = model.predict(pairs);
// Create ranked documents
List<RankedDocument> ranked = new ArrayList<>();
for (int i = 0; i < batch.size(); i++) {
ranked.add(RankedDocument.builder()
.document(batch.get(i))
.score(scores[i])
.rankingMethod("cross-encoder")
.build());
}
return ranked;
}
}
Contextual Query Expansion
Implementing sophisticated query expansion techniques:
@Service
public class ContextualQueryExpander {
private final WordEmbeddingService wordEmbeddings;
private final ConceptGraphService conceptGraph;
private final QueryHistoryService queryHistory;
public ExpandedQuery expand(Query originalQuery, ExpansionContext context) {
// Extract key terms from query
List<Term> keyTerms = extractKeyTerms(originalQuery);
// Expand using multiple strategies
Set<String> expansions = new HashSet<>();
// 1. Synonym expansion using word embeddings
expansions.addAll(expandWithSynonyms(keyTerms));
// 2. Concept expansion using knowledge graph
expansions.addAll(expandWithConcepts(keyTerms, context));
// 3. Historical expansion based on user behavior
expansions.addAll(expandWithHistory(originalQuery, context));
// 4. Domain-specific expansion
expansions.addAll(expandWithDomainKnowledge(keyTerms, context));
// Build expanded query with weights
return buildExpandedQuery(originalQuery, expansions, context);
}
private Set<String> expandWithConcepts(List<Term> terms, ExpansionContext context) {
Set<String> conceptExpansions = new HashSet<>();
for (Term term : terms) {
// Find related concepts in knowledge graph
List<Concept> relatedConcepts = conceptGraph.findRelated(
term.getValue(),
context.getDomain(),
RelationType.IS_A,
RelationType.PART_OF,
RelationType.RELATES_TO
);
// Add high-confidence expansions
relatedConcepts.stream()
.filter(concept -> concept.getConfidence() > 0.7)
.map(Concept::getName)
.forEach(conceptExpansions::add);
}
return conceptExpansions;
}
private ExpandedQuery buildExpandedQuery(Query original,
Set<String> expansions,
ExpansionContext context) {
// Calculate weights for expansions
Map<String, Double> weightedExpansions = new HashMap<>();
for (String expansion : expansions) {
double weight = calculateExpansionWeight(
expansion,
original,
context
);
weightedExpansions.put(expansion, weight);
}
// Build expanded query string
StringBuilder expandedText = new StringBuilder(original.getText());
// Add top expansions with boost syntax
weightedExpansions.entrySet().stream()
.sorted(Map.Entry.<String, Double>comparingByValue().reversed())
.limit(10)
.forEach(entry -> {
expandedText.append(" OR (")
.append(entry.getKey())
.append("^")
.append(String.format("%.2f", entry.getValue()))
.append(")");
});
return ExpandedQuery.builder()
.originalQuery(original)
.expandedText(expandedText.toString())
.expansions(weightedExpansions)
.expansionStrategy(context.getStrategy())
.build();
}
}
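calculateExpansionWeight is not shown. A sketch that favors expansions semantically close to the original query, assuming the word-embedding service can return a cosine similarity (the method name and constants are assumptions):
private double calculateExpansionWeight(String expansion,
Query original,
ExpansionContext context) {
// Damp expansion weights so they never outweigh the user's own terms.
double similarity = wordEmbeddings.cosineSimilarity(expansion, original.getText());
return Math.min(0.9, Math.max(0.1, similarity));
}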
Adaptive Retrieval Strategy
Implementing retrieval strategies that adapt based on query characteristics:
@Service
public class AdaptiveRetrievalStrategy {
private final QueryClassifier queryClassifier;
private final Map<QueryType, RetrievalStrategy> strategies;
private final PerformanceMonitor performanceMonitor;
public AdaptiveRetrievalStrategy(QueryClassifier queryClassifier,
PerformanceMonitor performanceMonitor) {
this.queryClassifier = queryClassifier;
this.performanceMonitor = performanceMonitor;
this.strategies = initializeStrategies();
}
public RetrievalResult retrieve(Query query, RetrievalContext context) {
// Classify query type
QueryType queryType = queryClassifier.classify(query);
// Get performance history for this query type
PerformanceStats stats = performanceMonitor.getStats(queryType);
// Select optimal strategy based on query type and performance
RetrievalStrategy strategy = selectStrategy(queryType, stats, context);
// Execute retrieval with monitoring
long startTime = System.currentTimeMillis();
RetrievalResult result = strategy.retrieve(query, context);
long latency = System.currentTimeMillis() - startTime;
// Record performance for adaptive optimization
performanceMonitor.record(queryType, strategy, latency, result);
return result;
}
private RetrievalStrategy selectStrategy(QueryType queryType,
PerformanceStats stats,
RetrievalContext context) {
// Get base strategy for query type
RetrievalStrategy baseStrategy = strategies.get(queryType);
// Adapt based on performance requirements
if (context.getMaxLatency() < 100) {
// Need fast response - use cached or approximate strategy
return new CachedRetrievalStrategy(baseStrategy);
} else if (stats.getAverageRelevance() < 0.7) {
// Low relevance - enhance with query expansion
return new EnhancedRetrievalStrategy(baseStrategy);
} else {
return baseStrategy;
}
}
private Map<QueryType, RetrievalStrategy> initializeStrategies() {
Map<QueryType, RetrievalStrategy> strategies = new HashMap<>();
// Factual queries - prioritize keyword search
strategies.put(QueryType.FACTUAL, new HybridStrategy(
0.3, // vector weight
0.7 // keyword weight
));
// Conceptual queries - prioritize semantic search
strategies.put(QueryType.CONCEPTUAL, new HybridStrategy(
0.7, // vector weight
0.3 // keyword weight
));
// Navigational queries - use metadata and structure
strategies.put(QueryType.NAVIGATIONAL, new MetadataFirstStrategy());
// Complex queries - use multi-stage retrieval
strategies.put(QueryType.COMPLEX, new MultiStageStrategy());
return strategies;
}
}
Monitoring and Observability for RAG Systems
Comprehensive monitoring is crucial for maintaining RAG system performance in production.
Metrics Collection Framework
@Configuration
public class MetricsConfiguration {
@Bean
public MeterRegistry meterRegistry() {
CompositeMeterRegistry registry = new CompositeMeterRegistry();
registry.add(new PrometheusMeterRegistry(PrometheusConfig.DEFAULT));
// cloudWatchConfig() and cloudWatchAsyncClient() are assumed to be defined elsewhere
registry.add(new CloudWatchMeterRegistry(cloudWatchConfig(), Clock.SYSTEM, cloudWatchAsyncClient()));
return registry;
}
@Component
public class RAGMetricsCollector {
private final MeterRegistry registry;
// Query metrics
private final Timer queryLatency;
private final Counter queryCount;
private final Gauge activeQueries;
// Retrieval metrics
private final Timer retrievalLatency;
private final DistributionSummary retrievalResultCount;
private final Counter cacheHits;
private final Counter cacheMisses;
// Embedding metrics
private final Timer embeddingLatency;
private final Counter embeddingTokens;
private final Gauge embeddingQueueSize;
// Quality metrics
private final DistributionSummary relevanceScore;
private final Counter feedbackPositive;
private final Counter feedbackNegative;
public RAGMetricsCollector(MeterRegistry registry) {
this.registry = registry;
// Initialize metrics (the remaining counters and gauges are built the same way; omitted for brevity)
this.queryLatency = Timer.builder("rag.query.latency")
.description("Query processing latency")
.tags("component", "query")
.register(registry);
this.retrievalResultCount = DistributionSummary
.builder("rag.retrieval.result_count")
.description("Number of documents retrieved")
.baseUnit("documents")
.register(registry);
this.relevanceScore = DistributionSummary
.builder("rag.quality.relevance_score")
.description("Relevance scores of retrieved documents")
.scale(100) // Convert to percentage
.register(registry);
}
public void recordQuery(QueryMetrics metrics) {
queryLatency.record(metrics.getLatency(), TimeUnit.MILLISECONDS);
queryCount.increment();
registry.gauge("rag.query.active",
Tags.of("tenant", metrics.getTenantId()),
metrics.getActiveCount());
// Record detailed breakdown
metrics.getLatencyBreakdown().forEach((stage, latency) -> {
registry.timer("rag.query.stage.latency",
Tags.of("stage", stage))
.record(latency, TimeUnit.MILLISECONDS);
});
}
public void recordRetrievalQuality(String queryId,
List<RetrievalResult> results) {
results.forEach(result -> {
relevanceScore.record(result.getRelevanceScore());
// Track result position vs relevance
registry.counter("rag.quality.relevance_by_position",
Tags.of(
"position", String.valueOf(result.getPosition()),
"relevant", String.valueOf(result.getRelevanceScore() > 0.7)
)).increment();
});
}
}
}
Distributed Tracing Implementation
@Configuration
@EnableTracing
public class TracingConfiguration {
@Bean
public Tracer tracer() {
return OpenTelemetrySdk.builder()
.setTracerProvider(
SdkTracerProvider.builder()
.addSpanProcessor(
BatchSpanProcessor.builder(
OtlpGrpcSpanExporter.builder()
.setEndpoint("http://otel-collector:4317")
.build()
).build()
)
.build()
)
.build()
.getTracer("rag-system");
}
@Component
public class TracedRAGService {
private final Tracer tracer;
private final EmbeddingService embeddingService;
private final RetrievalService retrievalService;
private final GenerationService generationService;
private final TokenCounter tokenCounter;
public QueryResponse query(QueryRequest request) {
Span span = tracer.spanBuilder("rag.query")
.setAttribute("query.text", request.getQuery())
.setAttribute("query.tenant", request.getTenantId())
.setAttribute("query.user", request.getUserId())
.startSpan();
try (Scope scope = span.makeCurrent()) {
// Trace embedding generation
Span embeddingSpan = tracer.spanBuilder("rag.embedding")
.startSpan();
float[] embedding;
try (Scope embeddingScope = embeddingSpan.makeCurrent()) {
embedding = embeddingService.generate(request.getQuery());
embeddingSpan.setAttribute("embedding.dimensions",
embedding.length);
} finally {
embeddingSpan.end();
}
// Trace retrieval
Span retrievalSpan = tracer.spanBuilder("rag.retrieval")
.startSpan();
List<Document> documents;
try (Scope retrievalScope = retrievalSpan.makeCurrent()) {
documents = retrievalService.retrieve(embedding, request);
retrievalSpan.setAttribute("retrieval.count",
documents.size());
} finally {
retrievalSpan.end();
}
// Trace generation
Span generationSpan = tracer.spanBuilder("rag.generation")
.startSpan();
String response;
try (Scope generationScope = generationSpan.makeCurrent()) {
response = generationService.generate(
request.getQuery(),
documents
);
generationSpan.setAttribute("generation.tokens",
tokenCounter.count(response));
} finally {
generationSpan.end();
}
span.setAttribute("query.success", true);
return QueryResponse.success(response);
} catch (Exception e) {
span.setStatus(StatusCode.ERROR, e.getMessage());
span.recordException(e);
throw e;
} finally {
span.end();
}
}
}
}
Performance Monitoring Dashboard
@RestController
@RequestMapping("/api/monitoring")
public class MonitoringController {
private final MetricsAggregator metricsAggregator;
private final HealthChecker healthChecker;
@GetMapping("/dashboard")
public DashboardData getDashboard(@RequestParam(defaultValue = "1h") String timeRange) {
TimeRange range = TimeRange.parse(timeRange);
return DashboardData.builder()
.systemHealth(healthChecker.checkHealth())
.queryMetrics(metricsAggregator.getQueryMetrics(range))
.retrievalMetrics(metricsAggregator.getRetrievalMetrics(range))
.qualityMetrics(metricsAggregator.getQualityMetrics(range))
.resourceMetrics(metricsAggregator.getResourceMetrics(range))
.alerts(metricsAggregator.getActiveAlerts())
.build();
}
@Component
public class MetricsAggregator {
private final MeterRegistry registry;
private final TimeSeriesDatabase tsdb;
public QueryMetricsSummary getQueryMetrics(TimeRange range) {
return QueryMetricsSummary.builder()
.totalQueries(getMetricSum("rag.query.count", range))
.averageLatency(getMetricAverage("rag.query.latency", range))
.p99Latency(getMetricPercentile("rag.query.latency", 0.99, range))
.throughput(calculateThroughput("rag.query.count", range))
.errorRate(calculateErrorRate(range))
.latencyBreakdown(getLatencyBreakdown(range))
.build();
}
public QualityMetricsSummary getQualityMetrics(TimeRange range) {
return QualityMetricsSummary.builder()
.averageRelevance(getMetricAverage("rag.quality.relevance_score", range))
.topKAccuracy(calculateTopKAccuracy(range))
.userSatisfaction(calculateUserSatisfaction(range))
.retrievalPrecision(calculatePrecision(range))
.retrievalRecall(calculateRecall(range))
.build();
}
private Map<String, Double> getLatencyBreakdown(TimeRange range) {
Map<String, Double> breakdown = new HashMap<>();
List<String> stages = List.of("embedding", "retrieval", "reranking", "generation");
for (String stage : stages) {
double latency = getMetricAverage(
"rag.query.stage.latency",
Tags.of("stage", stage),
range
);
breakdown.put(stage, latency);
}
return breakdown;
}
}
}
Alerting and Anomaly Detection
@Component
public class RAGAlertingService {
private final AlertManager alertManager;
private final AnomalyDetector anomalyDetector;
private final NotificationService notificationService;
@Scheduled(fixedRate = 60000) // Every minute
public void checkSystemHealth() {
// Check query latency
checkQueryLatency();
// Check retrieval quality
checkRetrievalQuality();
// Check resource usage
checkResourceUsage();
// Detect anomalies
detectAnomalies();
}
private void checkQueryLatency() {
double p99Latency = metricsCollector.getP99Latency("1m");
double avgLatency = metricsCollector.getAvgLatency("1m");
if (p99Latency > 500) { // 500ms threshold
Alert alert = Alert.builder()
.severity(Severity.HIGH)
.title("High query latency detected")
.description(String.format(
"P99 latency: %.2fms, Average: %.2fms",
p99Latency, avgLatency
))
.metric("rag.query.latency.p99")
.value(p99Latency)
.threshold(500.0)
.build();
alertManager.triggerAlert(alert);
notificationService.notifyOncall(alert);
}
}
private void detectAnomalies() {
// Get recent metrics
TimeSeriesData queryVolume = metricsCollector.getTimeSeries(
"rag.query.count",
"1h",
"1m"
);
// Detect anomalies using statistical methods
List<Anomaly> anomalies = anomalyDetector.detect(queryVolume);
for (Anomaly anomaly : anomalies) {
if (anomaly.getSeverity() > 0.8) {
Alert alert = Alert.builder()
.severity(Severity.MEDIUM)
.title("Anomaly detected in query volume")
.description(String.format(
"Detected %s anomaly: expected %.2f, actual %.2f",
anomaly.getType(),
anomaly.getExpectedValue(),
anomaly.getActualValue()
))
.anomalyScore(anomaly.getSeverity())
.build();
alertManager.triggerAlert(alert);
}
}
}
}
Cost Optimization at Scale
Managing costs effectively while maintaining performance is crucial for enterprise RAG systems.
Intelligent Caching Strategy
@Service
public class CostOptimizedCachingService {
private final RedisTemplate<String, CachedResult> redisTemplate;
private final CostCalculator costCalculator;
private final CacheAnalytics cacheAnalytics;
private final QueryPatternAnalyzer queryPatternAnalyzer;
public Optional<QueryResult> getCachedResult(Query query, CacheContext context) {
String cacheKey = generateCacheKey(query, context);
// Check if caching is cost-effective for this query
if (!shouldCache(query, context)) {
return Optional.empty();
}
CachedResult cached = redisTemplate.opsForValue().get(cacheKey);
if (cached != null) {
// Update access statistics
cacheAnalytics.recordHit(cacheKey, cached.getCost());
// Check if cache entry is still cost-effective
if (isCostEffective(cached)) {
return Optional.of(cached.getResult());
} else {
// Evict non-cost-effective entry
redisTemplate.delete(cacheKey);
}
}
cacheAnalytics.recordMiss(cacheKey);
return Optional.empty();
}
public void cacheResult(Query query, QueryResult result, CacheContext context) {
String cacheKey = generateCacheKey(query, context);
// Calculate cost of generating this result
double generationCost = costCalculator.calculateGenerationCost(query, result);
// Predict cache value based on query patterns
CacheValuePrediction prediction = predictCacheValue(query, context);
if (prediction.getExpectedValue() > generationCost * 1.5) {
CachedResult cachedResult = CachedResult.builder()
.result(result)
.cost(generationCost)
.timestamp(Instant.now())
.expectedHits(prediction.getExpectedHits())
.build();
// Set TTL based on predicted usage
Duration ttl = calculateOptimalTTL(prediction, generationCost);
redisTemplate.opsForValue().set(cacheKey, cachedResult, ttl);
cacheAnalytics.recordCaching(cacheKey, generationCost, ttl);
}
}
private CacheValuePrediction predictCacheValue(Query query, CacheContext context) {
// Analyze historical query patterns
QueryPattern pattern = queryPatternAnalyzer.analyze(query);
// Get historical statistics for similar queries
QueryStatistics stats = cacheAnalytics.getStatistics(pattern);
// Predict future access patterns
double expectedHits = stats.getAverageHits() *
pattern.getPopularityScore() *
context.getUserCount();
double expectedValue = expectedHits *
costCalculator.getAverageQueryCost(pattern);
return CacheValuePrediction.builder()
.expectedHits(expectedHits)
.expectedValue(expectedValue)
.confidence(stats.getConfidence())
.build();
}
}
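calculateOptimalTTL is referenced but not shown. A sketch that keeps expensive, frequently reused results around longer (the thresholds are illustrative):
private Duration calculateOptimalTTL(CacheValuePrediction prediction, double generationCost) {
// Expensive results that are predicted to be reused often are worth caching longer.
if (prediction.getExpectedHits() > 100 || generationCost > 0.10) {
return Duration.ofHours(24);
}
if (prediction.getExpectedHits() > 10) {
return Duration.ofHours(4);
}
return Duration.ofMinutes(30);
}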
Embedding Optimization
@Service
public class EmbeddingOptimizationService {
private final Map<EmbeddingModel, EmbeddingProvider> providers;
private final CostTracker costTracker;
private final QualityMonitor qualityMonitor;
public EmbeddingStrategy selectOptimalStrategy(Document document,
EmbeddingContext context) {
// Analyze document characteristics
DocumentAnalysis analysis = analyzeDocument(document);
// Get quality requirements
QualityRequirements requirements = context.getQualityRequirements();
// Evaluate strategies
List<EmbeddingOption> options = evaluateOptions(analysis, requirements);
// Select optimal strategy based on cost/quality trade-off
return selectOptimalOption(options, context);
}
private List<EmbeddingOption> evaluateOptions(DocumentAnalysis analysis,
QualityRequirements requirements) {
List<EmbeddingOption> options = new ArrayList<>();
// Option 1: Full document embedding with large model
options.add(EmbeddingOption.builder()
.model("text-embedding-3-large")
.dimensions(3072)
.strategy(EmbeddingStrategy.FULL_DOCUMENT)
.estimatedCost(calculateCost(analysis.getTokenCount(), 0.00013))
.estimatedQuality(0.95)
.build());
// Option 2: Chunked embedding with medium model
options.add(EmbeddingOption.builder()
.model("text-embedding-3-small")
.dimensions(1536)
.strategy(EmbeddingStrategy.CHUNKED)
.estimatedCost(calculateCost(analysis.getTokenCount(), 0.00002))
.estimatedQuality(0.85)
.build());
// Option 3: Summary-based embedding
if (analysis.getTokenCount() > 8000) {
options.add(EmbeddingOption.builder()
.model("text-embedding-3-small")
.dimensions(1536)
.strategy(EmbeddingStrategy.SUMMARY_BASED)
.estimatedCost(
calculateCost(2000, 0.00002) + // Summary generation
calculateCost(500, 0.00002) // Embedding
)
.estimatedQuality(0.75)
.build());
}
// Option 4: Cached embedding reuse
if (hasSimilarEmbedding(analysis)) {
options.add(EmbeddingOption.builder()
.model("cached")
.dimensions(1536)
.strategy(EmbeddingStrategy.CACHED_REUSE)
.estimatedCost(0.0)
.estimatedQuality(0.80)
.build());
}
return options;
}
@Component
public class AdaptiveEmbeddingPipeline {
private final EmbeddingCache embeddingCache;
private final BatchProcessor batchProcessor;
public void processDocuments(List<Document> documents) {
// Group documents by optimal strategy
Map<EmbeddingStrategy, List<Document>> strategyGroups =
documents.stream()
.collect(Collectors.groupingBy(doc ->
selectOptimalStrategy(doc, getContext())
));
// Process each group with its optimal strategy
strategyGroups.forEach((strategy, docs) -> {
switch (strategy) {
case FULL_DOCUMENT:
processFullDocuments(docs);
break;
case CHUNKED:
processChunked(docs);
break;
case SUMMARY_BASED:
processSummaryBased(docs);
break;
case CACHED_REUSE:
processCachedReuse(docs);
break;
}
});
}
private void processChunked(List<Document> documents) {
// Intelligent chunking to minimize API calls
ChunkingStrategy chunkingStrategy = ChunkingStrategy.builder()
.maxChunkSize(2000) // tokens
.overlap(200) // tokens
.boundaryDetection(true)
.semanticClustering(true)
.build();
for (Document doc : documents) {
List<Chunk> chunks = chunker.chunk(doc, chunkingStrategy);
// Batch chunks for efficient API usage
List<List<Chunk>> batches = Lists.partition(chunks, 100);
for (List<Chunk> batch : batches) {
List<float[]> embeddings = embeddingService.embedBatch(
batch.stream()
.map(Chunk::getText)
.collect(Collectors.toList())
);
// Store embeddings with metadata
for (int i = 0; i < batch.size(); i++) {
embeddingCache.store(
batch.get(i).getId(),
embeddings.get(i),
batch.get(i).getMetadata()
);
}
}
}
}
}
}
Resource Pooling and Allocation
@Service
public class ResourcePoolManager {
private final Map<ResourceType, ResourcePool> resourcePools = new ConcurrentHashMap<>();
private final ResourceMonitor resourceMonitor;
private final CostAllocator costAllocator;
@PostConstruct
public void initialize() {
// Initialize resource pools
resourcePools.put(ResourceType.EMBEDDING_API,
new RateLimitedPool(1000, TimeUnit.MINUTES)); // 1000 requests/min
resourcePools.put(ResourceType.VECTOR_SEARCH,
new CapacityPool(100)); // 100 concurrent searches
resourcePools.put(ResourceType.LLM_INFERENCE,
new TokenBucketPool(1_000_000, TimeUnit.HOURS)); // 1M tokens/hour
}
public <T> CompletableFuture<T> allocateAndExecute(
ResourceRequest request,
Function<ResourceAllocation, T> task) {
// Calculate priority based on tenant SLA and request characteristics
int priority = calculatePriority(request);
// Try to allocate resources
CompletableFuture<ResourceAllocation> allocationFuture =
allocateResources(request, priority);
return allocationFuture.thenApply(allocation -> {
try {
// Track resource usage for cost allocation
resourceMonitor.startTracking(allocation);
// Execute task with allocated resources
T result = task.apply(allocation);
// Record usage for billing
ResourceUsage usage = resourceMonitor.stopTracking(allocation);
costAllocator.allocateCost(request.getTenantId(), usage);
return result;
} finally {
// Release resources
releaseResources(allocation);
}
});
}
private CompletableFuture<ResourceAllocation> allocateResources(
ResourceRequest request,
int priority) {
return CompletableFuture.supplyAsync(() -> {
ResourceAllocation.Builder builder = ResourceAllocation.builder();
for (ResourceRequirement requirement : request.getRequirements()) {
ResourcePool pool = resourcePools.get(requirement.getType());
// Wait for available resources with priority queue
Resource resource = pool.acquire(
requirement.getAmount(),
priority,
request.getMaxWaitTime()
);
builder.addResource(requirement.getType(), resource);
}
return builder.build();
});
}
@Component
public class DynamicResourceScaler {
private final CloudProvider cloudProvider;
private final CostPredictor costPredictor;
@Scheduled(fixedRate = 300000) // Every 5 minutes
public void adjustResourcePools() {
// Get current usage metrics
Map<ResourceType, UsageMetrics> currentUsage =
resourceMonitor.getCurrentUsage();
// Predict future usage
Map<ResourceType, UsagePrediction> predictions =
currentUsage.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> costPredictor.predictUsage(
entry.getKey(),
entry.getValue()
)
));
// Adjust pool sizes based on predictions
predictions.forEach((type, prediction) -> {
ResourcePool pool = resourcePools.get(type);
if (prediction.getExpectedUsage() > pool.getCapacity() * 0.8) {
// Scale up
int newCapacity = (int) (prediction.getExpectedUsage() * 1.2);
pool.resize(newCapacity);
// Provision additional cloud resources if needed
if (type == ResourceType.VECTOR_SEARCH) {
cloudProvider.scaleVectorIndexReplicas(
calculateRequiredReplicas(newCapacity)
);
}
} else if (prediction.getExpectedUsage() < pool.getCapacity() * 0.3) {
// Scale down to save costs
int newCapacity = Math.max(
(int) (prediction.getExpectedUsage() * 1.5),
pool.getMinimumCapacity()
);
pool.resize(newCapacity);
}
});
}
}
}
Migration Strategies from Simple to Advanced RAG
Migrating from a simple RAG implementation to an enterprise-grade system requires careful planning and execution.
Phased Migration Approach
@Service
public class RAGMigrationOrchestrator {
private final MigrationPhaseExecutor phaseExecutor;
private final MigrationValidator validator;
private final RollbackManager rollbackManager;
public MigrationResult executeMigration(MigrationPlan plan) {
MigrationContext context = MigrationContext.builder()
.plan(plan)
.startTime(Instant.now())
.build();
try {
// Phase 1: Data preparation and validation
executePhase(Phase.DATA_PREPARATION, context);
// Phase 2: Parallel infrastructure setup
executePhase(Phase.INFRASTRUCTURE_SETUP, context);
// Phase 3: Incremental data migration
executePhase(Phase.DATA_MIGRATION, context);
// Phase 4: Feature parity testing
executePhase(Phase.FEATURE_PARITY, context);
// Phase 5: Performance optimization
executePhase(Phase.PERFORMANCE_TUNING, context);
// Phase 6: Cutover with rollback capability
executePhase(Phase.CUTOVER, context);
return MigrationResult.success(context);
} catch (MigrationException e) {
log.error("Migration failed at phase: {}", context.getCurrentPhase(), e);
// Execute rollback
rollbackManager.rollback(context);
return MigrationResult.failure(context, e);
}
}
private void executePhase(Phase phase, MigrationContext context) {
log.info("Starting migration phase: {}", phase);
context.setCurrentPhase(phase);
// Validate prerequisites
ValidationResult validation = validator.validatePhase(phase, context);
if (!validation.isValid()) {
throw new MigrationException(
"Phase validation failed: " + validation.getErrors()
);
}
// Execute phase
PhaseResult result = phaseExecutor.execute(phase, context);
context.addPhaseResult(phase, result);
// Verify phase completion
if (!result.isSuccessful()) {
throw new MigrationException(
"Phase execution failed: " + result.getError()
);
}
log.info("Completed migration phase: {} in {}ms",
phase, result.getDuration());
}
}
@Component
public class DataMigrationPhase implements PhaseExecutor {
private final DocumentMigrationService documentMigration;
private final EmbeddingMigrationService embeddingMigration;
private final IndexMigrationService indexMigration;
@Override
public PhaseResult execute(MigrationContext context) {
// Set up parallel processing
ExecutorService executor = Executors.newFixedThreadPool(
context.getParallelism()
);
try {
// Migrate documents in batches
CompletableFuture<MigrationStats> docsFuture =
CompletableFuture.supplyAsync(
() -> migrateDocuments(context),
executor
);
// Regenerate embeddings with new model
CompletableFuture<MigrationStats> embeddingsFuture =
CompletableFuture.supplyAsync(
() -> migrateEmbeddings(context),
executor
);
// Rebuild indexes
CompletableFuture<MigrationStats> indexesFuture =
CompletableFuture.supplyAsync(
() -> migrateIndexes(context),
executor
);
// Wait for all migrations to complete
CompletableFuture.allOf(docsFuture, embeddingsFuture, indexesFuture)
.join();
return PhaseResult.success(Map.of(
"documents", docsFuture.join(),
"embeddings", embeddingsFuture.join(),
"indexes", indexesFuture.join()
));
} finally {
executor.shutdown();
}
}
private MigrationStats migrateDocuments(MigrationContext context) {
long startTime = System.currentTimeMillis();
long totalDocuments = documentMigration.countDocuments(context);
long migratedCount = 0;
// Process in chunks to avoid memory issues
int chunkSize = 10000;
for (long offset = 0; offset < totalDocuments; offset += chunkSize) {
List<Document> batch = documentMigration.fetchBatch(
offset,
chunkSize,
context
);
// Transform documents to new format
List<EnhancedDocument> transformed = batch.stream()
.map(doc -> transformDocument(doc, context))
.collect(Collectors.toList());
// Store in new system
documentMigration.storeBatch(transformed, context);
migratedCount += transformed.size();
// Report progress
double progress = (double) migratedCount / totalDocuments * 100;
context.reportProgress(Phase.DATA_MIGRATION, progress);
}
return MigrationStats.builder()
.totalItems(totalDocuments)
.successfulItems(migratedCount)
.failedItems(0)
.duration(System.currentTimeMillis() - startTime)
.build();
}
}
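As a hedged example of how the orchestrator might be driven, the sketch below assembles a plan and runs the phased migration; the MigrationPlan builder fields and the result accessors are assumptions, since only the execute path is shown above.
// Hypothetical driver code for the phased migration.
MigrationPlan plan = MigrationPlan.builder()
    .sourceSystem("legacy-rag")            // assumed field
    .targetSystem("hybrid-rag-v2")         // assumed field
    .parallelism(8)                        // thread count consumed by DataMigrationPhase
    .build();

MigrationResult result = migrationOrchestrator.executeMigration(plan);
if (!result.isSuccessful()) {              // assumed accessor
    log.error("Migration rolled back after phase failure: {}", result.getFailedPhase()); // assumed accessor
}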
Zero-Downtime Migration Strategy
@Service
public class ZeroDowntimeMigrationService {
private final DualSystemRouter router;
private final ShadowModeValidator validator;
private final TrafficShifter trafficShifter;
public void executeZeroDowntimeMigration(MigrationConfig config) {
// Step 1: Set up shadow mode
enableShadowMode(config);
// Step 2: Gradual traffic shifting
executeGradualShift(config);
// Step 3: Full cutover with instant rollback
executeCutover(config);
}
private void enableShadowMode(MigrationConfig config) {
// Configure dual writes
router.enableDualWrites(config);
// Start shadow validation
validator.startValidation(ValidationConfig.builder()
.compareResults(true)
.logDiscrepancies(true)
.alertThreshold(0.01) // 1% discrepancy threshold
.build());
// Monitor for stability
Duration stabilityPeriod = Duration.ofHours(24);
monitorShadowMode(stabilityPeriod);
}
private void executeGradualShift(MigrationConfig config) {
List<TrafficShiftStage> stages = List.of(
new TrafficShiftStage(1, Duration.ofHours(1)), // 1% for 1 hour
new TrafficShiftStage(5, Duration.ofHours(2)), // 5% for 2 hours
new TrafficShiftStage(10, Duration.ofHours(4)), // 10% for 4 hours
new TrafficShiftStage(25, Duration.ofHours(8)), // 25% for 8 hours
new TrafficShiftStage(50, Duration.ofHours(12)), // 50% for 12 hours
new TrafficShiftStage(90, Duration.ofHours(24)) // 90% for 24 hours
);
for (TrafficShiftStage stage : stages) {
// Shift traffic
trafficShifter.shiftTraffic(stage.getPercentage());
// Monitor metrics
MonitoringResult result = monitorStage(stage);
// Validate performance
if (!result.meetsSLA()) {
log.warn("SLA violation at {}% traffic. Rolling back.",
stage.getPercentage());
trafficShifter.rollback();
throw new MigrationException("SLA violation during traffic shift");
}
}
}
}
@Component
public class DualSystemRouter {
private final RAGService legacySystem;
private final RAGService newSystem;
private final ComparisonService comparisonService;
public QueryResponse route(QueryRequest request) {
// Determine routing based on configuration
RoutingDecision decision = makeRoutingDecision(request);
switch (decision.getMode()) {
case LEGACY_ONLY:
return legacySystem.query(request);
case NEW_ONLY:
return newSystem.query(request);
case SHADOW:
return executeShadowMode(request);
case PERCENTAGE_BASED:
return executePercentageBased(request, decision);
default:
throw new IllegalStateException("Unsupported routing mode: " + decision.getMode());
}
}
private QueryResponse executeShadowMode(QueryRequest request) {
// Execute on legacy system (primary)
QueryResponse legacyResponse = legacySystem.query(request);
// Execute on new system asynchronously
CompletableFuture.runAsync(() -> {
try {
QueryResponse newResponse = newSystem.query(request);
// Compare results
ComparisonResult comparison = comparisonService.compare(
legacyResponse,
newResponse
);
// Log discrepancies
if (comparison.hasDiscrepancies()) {
logDiscrepancy(request, comparison);
}
// Record metrics
recordComparisonMetrics(comparison);
} catch (Exception e) {
log.error("Shadow execution failed", e);
}
});
// Return legacy response
return legacyResponse;
}
}
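The router above references a percentage-based mode without showing it. One plausible implementation, sketched here as an assumption and intended to sit inside DualSystemRouter, is a deterministic hash split so that a given query consistently hits the same system while traffic is being shifted; the getRequestId and getNewSystemPercentage accessors are hypothetical.
// Sketch of the PERCENTAGE_BASED branch: stable hash split across the two systems.
private QueryResponse executePercentageBased(QueryRequest request, RoutingDecision decision) {
    // Hash a stable request attribute so retries of the same query stay on one system.
    int bucket = Math.floorMod(request.getRequestId().hashCode(), 100);  // assumed accessor
    if (bucket < decision.getNewSystemPercentage()) {                    // assumed accessor
        return newSystem.query(request);
    }
    return legacySystem.query(request);
}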
Performance Validation Framework
@Service
public class MigrationPerformanceValidator {
private final LoadGenerator loadGenerator;
private final MetricsCollector metricsCollector;
private final SLAValidator slaValidator;
public ValidationReport validatePerformance(
ValidationConfig config,
SystemUnderTest system) {
// Generate representative load
LoadProfile loadProfile = generateLoadProfile(config);
// Execute performance tests
PerformanceTestResult result = executePerformanceTest(
loadProfile,
system
);
// Validate against SLAs
SLAValidationResult slaResult = slaValidator.validate(
result,
config.getSLAs()
);
// Generate comprehensive report
return generateValidationReport(result, slaResult);
}
private PerformanceTestResult executePerformanceTest(
LoadProfile profile,
SystemUnderTest system) {
List<TestScenario> scenarios = List.of(
// Simple queries
TestScenario.builder()
.name("simple_queries")
.queryComplexity(QueryComplexity.SIMPLE)
.concurrentUsers(100)
.duration(Duration.ofMinutes(10))
.build(),
// Complex queries with high load
TestScenario.builder()
.name("complex_queries_high_load")
.queryComplexity(QueryComplexity.COMPLEX)
.concurrentUsers(500)
.duration(Duration.ofMinutes(30))
.build(),
// Burst traffic
TestScenario.builder()
.name("burst_traffic")
.trafficPattern(TrafficPattern.BURST)
.peakUsers(1000)
.duration(Duration.ofMinutes(5))
.build()
);
Map<String, ScenarioResult> results = new HashMap<>();
for (TestScenario scenario : scenarios) {
ScenarioResult scenarioResult = loadGenerator.execute(
scenario,
system
);
results.put(scenario.getName(), scenarioResult);
}
return PerformanceTestResult.builder()
.scenarios(results)
.aggregateMetrics(calculateAggregateMetrics(results))
.build();
}
private ValidationReport generateValidationReport(
PerformanceTestResult testResult,
SLAValidationResult slaResult) {
return ValidationReport.builder()
.timestamp(Instant.now())
.overallStatus(slaResult.allPassed() ? Status.PASSED : Status.FAILED)
.performanceMetrics(testResult.getAggregateMetrics())
.slaValidations(slaResult.getValidations())
.recommendations(generateRecommendations(testResult, slaResult))
.riskAssessment(assessRisks(testResult, slaResult))
.build();
}
}
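Before cutover, the validator can act as a release gate. The snippet below is a minimal sketch under the assumption that ValidationConfig carries the SLA list consumed by getSLAs() and that SystemUnderTest wraps the candidate system; the SLA factory methods are hypothetical.
// Hypothetical pre-cutover gate built on the validator above.
ValidationConfig config = ValidationConfig.builder()
    .slas(List.of(
        SLA.of("p99_latency_ms", 200),     // assumed SLA representation
        SLA.of("error_rate", 0.001)))      // assumed SLA representation
    .build();

ValidationReport report = performanceValidator.validatePerformance(
    config,
    SystemUnderTest.of(newSystem)          // assumed factory
);

if (report.getOverallStatus() == Status.FAILED) {
    throw new MigrationException("New system failed SLA validation before cutover");
}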
Conclusion
Building enterprise-scale RAG systems requires far more than connecting an embedding model to a vector database. The patterns and architectures we've explored in this guide represent battle-tested solutions to real production challenges.
Key takeaways for implementing advanced RAG architectures:
- Hybrid Search is Essential: Pure vector search isn't enough. Combine vector, keyword, and graph-based search for comprehensive retrieval.
- Security Cannot Be an Afterthought: Implement multi-tenant isolation and document-level access control from the beginning.
- Performance at Scale Requires Architecture: Intelligent caching, batch processing, and resource pooling are crucial for handling millions of documents.
- Observability Drives Reliability: Comprehensive monitoring, tracing, and alerting are essential for maintaining production SLAs.
- Cost Optimization is Critical: Implement adaptive strategies for embedding generation, caching, and resource allocation to control costs.
- Migration Requires Planning: Zero-downtime migration strategies with shadow mode and gradual traffic shifting minimize risk.
Remember that these patterns are not one-size-fits-all solutions. Adapt them to your specific requirements, constraints, and scale. The key is to build a flexible architecture that can evolve with your needs while maintaining the performance, security, and reliability that enterprise systems demand.
As you implement these patterns, continuously measure and optimize. The difference between a good RAG system and a great one lies in the details of implementation and the relentless pursuit of improvement based on real-world usage patterns.