Retrieval-Augmented Generation (RAG) has become the go-to architecture for building AI applications that need access to external knowledge. But there’s a huge gap between a simple RAG demo and a production system that serves millions of users reliably. This guide covers everything you need to build RAG systems that actually work in the real world.
RAG combines retrieval and generation in a simple but powerful pattern:
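At its simplest: embed the question, pull the nearest chunks from a vector store, and hand them to an LLM as context. A minimal sketch of that loop (the embed, vector_store, and llm objects here are generic placeholders, not a specific library):

def basic_rag(question: str) -> str:
    # embed / vector_store / llm are placeholders for your own components
    # 1. Embed the question
    query_vector = embed(question)
    # 2. Retrieve the most similar chunks
    docs = vector_store.search(query_vector, top_k=5)
    # 3. Generate an answer grounded in the retrieved context
    context = "\n\n".join(doc.text for doc in docs)
    prompt = f"Answer using only this context:\n{context}\n\nQuestion: {question}"
    return llm.generate(prompt)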
But production RAG systems require much more sophistication:
import time

# The component classes below (EmbeddingModel, VectorDatabase, etc.) stand in
# for your own implementations.
class ProductionRAGSystem:
def __init__(self):
self.embedder = EmbeddingModel()
self.vector_store = VectorDatabase()
self.reranker = RerankerModel()
self.llm = LanguageModel()
self.cache = RedisCache()
self.monitoring = MonitoringService()
async def query(self, question: str) -> str:
# Multi-stage retrieval with monitoring
start_time = time.time()
try:
# Check cache first
cached_result = await self.cache.get(question)
if cached_result:
self.monitoring.log_cache_hit(question)
return cached_result
# Retrieve candidates
candidates = await self.retrieve_candidates(question)
# Rerank for relevance
relevant_docs = await self.rerank_documents(question, candidates)
# Generate response
response = await self.generate_response(question, relevant_docs)
# Cache and monitor
await self.cache.set(question, response)
self.monitoring.log_query(question, response, time.time() - start_time)
return response
except Exception as e:
self.monitoring.log_error(question, e)
raise
Pinecone: Best for getting started
import pinecone

# Easy setup, managed service (older SDK style; recent versions use
# `from pinecone import Pinecone` and a client object instead of init())
pinecone.init(api_key="your-key")
index = pinecone.Index("your-index")
# But limited customization and can be expensive at scale
Weaviate: Best for flexibility
import weaviate
# Rich query capabilities, GraphQL API
client = weaviate.Client("http://localhost:8080")
# Good for complex data relationships
result = client.query.get("Document") \
.with_near_text({"concepts": ["machine learning"]}) \
.with_additional(["certainty"]) \
.with_limit(10) \
.do()
Qdrant: Best for performance
from qdrant_client import QdrantClient
# High performance, written in Rust
client = QdrantClient("localhost", port=6333)
# Excellent for high-throughput applications
search_result = client.search(
collection_name="documents",
query_vector=embedding,
limit=10
)
Chroma: Best for development
import chromadb
# Simple, embeddable database
client = chromadb.Client()
collection = client.create_collection("documents")
# Great for prototyping and small applications
results = collection.query(
query_texts=["What is machine learning?"],
n_results=10
)
# HNSW (Hierarchical Navigable Small World) configuration
index_config = {
"index_type": "HNSW",
"params": {
"M": 16, # Number of connections (16-64 for most cases)
"efConstruction": 200, # Higher = better quality, slower indexing
"ef": 100, # Higher = better recall, slower search
}
}
# IVF (Inverted File) for larger datasets
index_config = {
"index_type": "IVF_FLAT",
"params": {
"nlist": 1024, # Number of clusters (sqrt(n) is good starting point)
"nprobe": 10, # Number of clusters to search
}
}
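As a worked example of the sqrt(n) heuristic, here is a small sketch that derives starting values for nlist and nprobe from corpus size; treat these as starting points to benchmark against your own data, not tuned settings:

import math

def suggest_ivf_params(num_vectors: int) -> dict:
    # sqrt(n) clusters, rounded to the nearest power of two
    nlist = 2 ** round(math.log2(max(1, math.isqrt(num_vectors))))
    # Probing ~1% of clusters is a common recall/latency starting point
    nprobe = max(1, nlist // 100)
    return {"nlist": nlist, "nprobe": nprobe}

# e.g. 1,000,000 vectors -> sqrt = 1000 -> nlist=1024, nprobe=10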
import numpy as np
from sentence_transformers import SentenceTransformer

class OptimizedEmbedder:
def __init__(self):
# Use smaller, faster models for retrieval
self.retrieval_model = SentenceTransformer('all-MiniLM-L6-v2') # 384 dims
# Use larger models for final ranking if needed
self.ranking_model = SentenceTransformer('all-mpnet-base-v2') # 768 dims
def embed_for_retrieval(self, text: str) -> np.ndarray:
# Fast embedding for initial retrieval
return self.retrieval_model.encode(text)
def embed_for_ranking(self, text: str) -> np.ndarray:
# High-quality embedding for reranking
return self.ranking_model.encode(text)
import asyncio
from typing import Dict, List

class HybridRetriever:
    def __init__(self, vector_store, keyword_index):
        self.vector_store = vector_store
        self.keyword_index = keyword_index  # Elasticsearch, etc.

    async def retrieve(self, query: str, top_k: int = 20) -> List[Document]:
        # Parallel retrieval from both indexes
        vector_results, keyword_results = await asyncio.gather(
            self.vector_search(query, top_k),
            self.keyword_search(query, top_k)
        )
        # Fuse the two ranked lists; deduplication happens inside the fusion
        return self.rerank_fusion([vector_results, keyword_results], top_k)

    def rerank_fusion(self, ranked_lists: List[List[Document]], top_k: int, k: int = 60) -> List[Document]:
        """Reciprocal Rank Fusion (RRF) across multiple ranked lists."""
        score_dict: Dict[str, Dict] = {}
        for ranked in ranked_lists:
            for rank, doc in enumerate(ranked):
                entry = score_dict.setdefault(doc.id, {"doc": doc, "score": 0.0})
                # RRF formula: 1 / (k + rank), with ranks starting at 1.
                # E.g. with k=60, a doc ranked 1st in one list and 3rd in the
                # other scores 1/61 + 1/63.
                entry["score"] += 1 / (k + rank + 1)
        # Sort by fused score and keep the best top_k
        sorted_results = sorted(score_dict.values(), key=lambda x: x["score"], reverse=True)
        return [item["doc"] for item in sorted_results[:top_k]]
class MultiQueryRetriever:
def __init__(self, retriever, query_generator):
self.retriever = retriever
self.query_generator = query_generator
async def retrieve(self, original_query: str, top_k: int = 10) -> List[Document]:
# Generate multiple variations of the query
query_variations = await self.query_generator.generate_variations(original_query)
# Retrieve for each variation
all_results = []
for query in query_variations:
results = await self.retriever.retrieve(query, top_k)
all_results.extend(results)
# Deduplicate and rank
return self.deduplicate_and_rank(all_results, top_k)
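    # deduplicate_and_rank is referenced above but left undefined — a minimal
    # sketch that keeps the first occurrence of each document (hits surfaced
    # by several variations could instead be boosted with the RRF scoring
    # shown earlier):
    def deduplicate_and_rank(self, results: List[Document], top_k: int) -> List[Document]:
        seen, unique = set(), []
        for doc in results:
            if doc.id not in seen:  # assumes a stable Document.id
                seen.add(doc.id)
                unique.append(doc)
        return unique[:top_k]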
class QueryGenerator:
def __init__(self, llm):
self.llm = llm
async def generate_variations(self, query: str) -> List[str]:
prompt = f"""
Generate 3 different ways to ask this question that would help find relevant information:
Original question: {query}
Variations:
1.
2.
3.
"""
response = await self.llm.generate(prompt)
return self.parse_variations(response)
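    # parse_variations is left undefined above — a minimal sketch assuming
    # the model answers in the numbered-list format the prompt requests:
    def parse_variations(self, response: str) -> List[str]:
        import re
        # Pull out lines shaped like "1. ..." from the LLM response
        variations = re.findall(r'^\s*\d+\.\s*(.+)$', response, flags=re.MULTILINE)
        return [v.strip() for v in variations if v.strip()]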
class HierarchicalRetriever:
def __init__(self, summary_index, detail_index):
self.summary_index = summary_index # Document summaries
self.detail_index = detail_index # Full document chunks
async def retrieve(self, query: str, top_k: int = 10) -> List[Document]:
# First, find relevant documents via summaries
relevant_summaries = await self.summary_index.search(query, top_k=50)
# Extract document IDs
doc_ids = [summary.document_id for summary in relevant_summaries]
# Then search within those documents for specific chunks
detailed_results = await self.detail_index.search(
query,
filter={"document_id": {"$in": doc_ids}},
top_k=top_k
)
return detailed_results
import hashlib
import re
from typing import Dict, List
class IntelligentChunker:
def __init__(self, chunk_size: int = 512, overlap: int = 50):
self.chunk_size = chunk_size
self.overlap = overlap
    def chunk_document(self, text: str, metadata: Dict = None) -> List[Dict]:
        """Chunk document with awareness of structure"""
        metadata = metadata or {}  # avoid None when building per-chunk metadata
        # Detect document type and apply appropriate strategy
if self.is_code_document(text):
return self.chunk_code(text, metadata)
elif self.is_structured_document(text):
return self.chunk_structured(text, metadata)
else:
return self.chunk_semantic(text, metadata)
def chunk_code(self, text: str, metadata: Dict) -> List[Dict]:
"""Chunk code while preserving function/class boundaries"""
chunks = []
# Split by functions/classes first
function_pattern = r'(def\s+\w+|class\s+\w+|async\s+def\s+\w+)'
sections = re.split(function_pattern, text)
current_chunk = ""
for section in sections:
if len(current_chunk + section) > self.chunk_size:
if current_chunk:
chunks.append(self.create_chunk(current_chunk, metadata))
current_chunk = section
else:
current_chunk += section
if current_chunk:
chunks.append(self.create_chunk(current_chunk, metadata))
return chunks
def chunk_structured(self, text: str, metadata: Dict) -> List[Dict]:
"""Chunk structured documents (markdown, etc.)"""
# Split by headers first
header_pattern = r'(#{1,6}\s+.*?)(?=\n)'
sections = re.split(header_pattern, text)
chunks = []
current_section = ""
current_header = ""
for i, section in enumerate(sections):
if re.match(header_pattern, section):
current_header = section
else:
# Chunk within section
section_chunks = self.chunk_text_by_sentences(section)
for chunk_text in section_chunks:
chunk_metadata = {**metadata, "header": current_header}
chunks.append(self.create_chunk(chunk_text, chunk_metadata))
return chunks
def chunk_semantic(self, text: str, metadata: Dict) -> List[Dict]:
"""Semantic chunking using sentence boundaries"""
sentences = self.split_sentences(text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk + sentence) > self.chunk_size:
if current_chunk:
chunks.append(self.create_chunk(current_chunk, metadata))
current_chunk = sentence
else:
current_chunk += " " + sentence
if current_chunk:
chunks.append(self.create_chunk(current_chunk, metadata))
return chunks
def create_chunk(self, text: str, metadata: Dict) -> Dict:
return {
"text": text.strip(),
"metadata": metadata,
"chunk_id": hashlib.md5(text.encode()).hexdigest()
}
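    # split_sentences / chunk_text_by_sentences / is_code_document are left
    # undefined above. As one illustration, a naive sentence splitter could
    # look like this (a real system might use nltk or spaCy instead):
    def split_sentences(self, text: str) -> List[str]:
        # Split on ., !, ? followed by whitespace; breaks on abbreviations
        # like "e.g.", which is why a proper tokenizer is preferable.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        return [s for s in sentences if s]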
from typing import List, Tuple

from sentence_transformers import CrossEncoder
class Reranker:
def __init__(self):
# Specialized reranking model
self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
def rerank(self, query: str, documents: List[str], top_k: int = 10) -> List[Tuple[str, float]]:
"""Rerank documents using cross-encoder"""
# Create query-document pairs
pairs = [(query, doc) for doc in documents]
# Score all pairs
scores = self.cross_encoder.predict(pairs)
# Sort by score and return top_k
scored_docs = list(zip(documents, scores))
scored_docs.sort(key=lambda x: x[1], reverse=True)
return scored_docs[:top_k]
class ProductionReranker:
def __init__(self):
self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
self.cache = {}
async def rerank_batch(self, queries: List[str], document_lists: List[List[str]]) -> List[List[Tuple[str, float]]]:
"""Efficient batch reranking"""
all_pairs = []
query_doc_mapping = []
for i, (query, docs) in enumerate(zip(queries, document_lists)):
for j, doc in enumerate(docs):
all_pairs.append((query, doc))
query_doc_mapping.append((i, j))
# Batch scoring for efficiency
scores = self.cross_encoder.predict(all_pairs)
# Reorganize results
results = [[] for _ in queries]
for (query_idx, doc_idx), score in zip(query_doc_mapping, scores):
doc = document_lists[query_idx][doc_idx]
results[query_idx].append((doc, float(score)))
# Sort each query's results
for result in results:
result.sort(key=lambda x: x[1], reverse=True)
return results
import math
from datetime import datetime
from typing import Dict, List

class RAGMonitor:
def __init__(self):
self.metrics_store = MetricsStore()
def evaluate_retrieval(self, queries: List[str], retrieved_docs: List[List[str]], ground_truth: List[List[str]]) -> Dict:
"""Evaluate retrieval quality"""
metrics = {
"precision_at_k": [],
"recall_at_k": [],
"mrr": [], # Mean Reciprocal Rank
"ndcg": [] # Normalized Discounted Cumulative Gain
}
        for query, retrieved, truth in zip(queries, retrieved_docs, ground_truth):
            # Precision@10
            relevant_retrieved = set(retrieved[:10]) & set(truth)
            precision = len(relevant_retrieved) / min(10, len(retrieved)) if retrieved else 0
            metrics["precision_at_k"].append(precision)
            # Recall@10
            recall = len(relevant_retrieved) / len(truth) if truth else 0
            metrics["recall_at_k"].append(recall)
            # MRR: reciprocal rank of the first relevant result
            for i, doc in enumerate(retrieved):
                if doc in truth:
                    metrics["mrr"].append(1 / (i + 1))
                    break
            else:
                metrics["mrr"].append(0)
            # NDCG@10 (helper defined below)
            metrics["ndcg"].append(self.ndcg_at_k(retrieved, truth, k=10))
# Average metrics
return {k: sum(v) / len(v) for k, v in metrics.items()}
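    # ndcg_at_k as used above — a hedged sketch assuming binary relevance
    # (graded relevance labels would need the full DCG formulation):
    def ndcg_at_k(self, retrieved: List[str], truth: List[str], k: int = 10) -> float:
        # DCG: a relevant doc at (1-indexed) rank r contributes 1/log2(r + 1)
        dcg = sum(1 / math.log2(i + 2) for i, doc in enumerate(retrieved[:k]) if doc in truth)
        # Ideal DCG: all relevant docs ranked first
        ideal = sum(1 / math.log2(i + 2) for i in range(min(k, len(truth))))
        return dcg / ideal if ideal > 0 else 0.0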
def evaluate_generation(self, questions: List[str], generated_answers: List[str], reference_answers: List[str]) -> Dict:
"""Evaluate generation quality"""
from rouge_score import rouge_scorer
from bert_score import score
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
rouge_scores = []
for gen, ref in zip(generated_answers, reference_answers):
scores = scorer.score(ref, gen)
rouge_scores.append({
'rouge1': scores['rouge1'].fmeasure,
'rouge2': scores['rouge2'].fmeasure,
'rougeL': scores['rougeL'].fmeasure
})
# BERTScore for semantic similarity
P, R, F1 = score(generated_answers, reference_answers, lang="en")
return {
"rouge1": sum(s['rouge1'] for s in rouge_scores) / len(rouge_scores),
"rouge2": sum(s['rouge2'] for s in rouge_scores) / len(rouge_scores),
"rougeL": sum(s['rougeL'] for s in rouge_scores) / len(rouge_scores),
"bert_score": F1.mean().item()
}
async def log_query_metrics(self, query: str, retrieved_docs: List[str], generated_answer: str, user_feedback: float = None):
"""Log real-time metrics"""
metrics = {
"timestamp": datetime.utcnow(),
"query": query,
"num_retrieved": len(retrieved_docs),
"answer_length": len(generated_answer),
"user_feedback": user_feedback
}
await self.metrics_store.log(metrics)
import hashlib
import time
from datetime import datetime
from typing import Dict

import numpy as np

class RAGABTesting:
def __init__(self):
self.experiments = {}
def create_experiment(self, name: str, control_config: Dict, treatment_config: Dict, traffic_split: float = 0.5):
"""Create A/B test experiment"""
self.experiments[name] = {
"control": control_config,
"treatment": treatment_config,
"traffic_split": traffic_split,
"metrics": {"control": [], "treatment": []}
}
async def query_with_experiment(self, experiment_name: str, query: str, user_id: str) -> str:
"""Route query to control or treatment"""
experiment = self.experiments[experiment_name]
        # Deterministic bucket assignment based on user_id
        hash_value = int(hashlib.md5(user_id.encode()).hexdigest(), 16)
is_treatment = (hash_value % 100) < (experiment["traffic_split"] * 100)
config = experiment["treatment"] if is_treatment else experiment["control"]
variant = "treatment" if is_treatment else "control"
# Run query with selected configuration
start_time = time.time()
result = await self.run_rag_query(query, config)
latency = time.time() - start_time
# Log metrics
experiment["metrics"][variant].append({
"latency": latency,
"user_id": user_id,
"query": query,
"timestamp": datetime.utcnow()
})
return result
def analyze_experiment(self, experiment_name: str) -> Dict:
"""Analyze A/B test results"""
experiment = self.experiments[experiment_name]
control_metrics = experiment["metrics"]["control"]
treatment_metrics = experiment["metrics"]["treatment"]
# Statistical analysis
control_latencies = [m["latency"] for m in control_metrics]
treatment_latencies = [m["latency"] for m in treatment_metrics]
        from scipy import stats
        # Welch's t-test (doesn't assume equal variances between variants)
        t_stat, p_value = stats.ttest_ind(control_latencies, treatment_latencies, equal_var=False)
return {
"control_mean_latency": np.mean(control_latencies),
"treatment_mean_latency": np.mean(treatment_latencies),
"latency_improvement": (np.mean(control_latencies) - np.mean(treatment_latencies)) / np.mean(control_latencies),
"statistical_significance": p_value < 0.05,
"p_value": p_value,
"sample_sizes": {
"control": len(control_metrics),
"treatment": len(treatment_metrics)
}
}
# Kubernetes deployment for the RAG service
apiVersion: apps/v1
kind: Deployment
metadata:
name: rag-service
spec:
replicas: 3
selector:
matchLabels:
app: rag-service
template:
metadata:
labels:
app: rag-service
spec:
containers:
- name: rag-service
image: your-rag-service:latest
ports:
- containerPort: 8000
env:
- name: VECTOR_DB_URL
value: "http://qdrant-service:6333"
- name: LLM_API_KEY
valueFrom:
secretKeyRef:
name: llm-api-secret
key: api-key
resources:
requests:
memory: "2Gi"
cpu: "1000m"
limits:
memory: "4Gi"
cpu: "2000m"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /ready
port: 8000
initialDelaySeconds: 5
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: rag-service
spec:
selector:
app: rag-service
ports:
- port: 80
targetPort: 8000
type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: rag-service-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: rag-service
minReplicas: 3
maxReplicas: 20
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: 80
import hashlib
import json
from datetime import datetime
from typing import Dict, Optional

import redis
class RAGCache:
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        # Note: this uses redis-py's synchronous client; swap in redis.asyncio
        # if the async methods below must be truly non-blocking.
        self.redis_client = redis.from_url(redis_url)
        self.default_ttl = 3600  # 1 hour
def _make_key(self, query: str, config_hash: str) -> str:
"""Create cache key from query and configuration"""
content = f"{query}:{config_hash}"
return f"rag:{hashlib.md5(content.encode()).hexdigest()}"
async def get_cached_response(self, query: str, config: Dict) -> Optional[str]:
"""Get cached response if available"""
config_hash = hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()
key = self._make_key(query, config_hash)
cached = self.redis_client.get(key)
if cached:
return json.loads(cached)["response"]
return None
async def cache_response(self, query: str, config: Dict, response: str, ttl: int = None):
"""Cache response with TTL"""
config_hash = hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()
key = self._make_key(query, config_hash)
cache_data = {
"response": response,
"timestamp": datetime.utcnow().isoformat(),
"config_hash": config_hash
}
self.redis_client.setex(
key,
ttl or self.default_ttl,
json.dumps(cache_data)
)
    def invalidate_pattern(self, pattern: str):
        """Invalidate cache entries matching pattern"""
        # SCAN (vs. KEYS) avoids blocking Redis on large keyspaces
        keys = list(self.redis_client.scan_iter(match=f"rag:{pattern}"))
        if keys:
            self.redis_client.delete(*keys)
from typing import List

import tiktoken

class TokenOptimizer:
    def __init__(self, model_name: str):
        self.model_name = model_name
        # get_input_cost / get_output_cost are assumed to look up your
        # provider's current per-token prices (implementations not shown)
        self.input_cost_per_token = self.get_input_cost(model_name)
        self.output_cost_per_token = self.get_output_cost(model_name)
def estimate_cost(self, prompt: str, max_tokens: int) -> float:
"""Estimate cost for a query"""
        encoder = tiktoken.encoding_for_model(self.model_name)
input_tokens = len(encoder.encode(prompt))
input_cost = input_tokens * self.input_cost_per_token
output_cost = max_tokens * self.output_cost_per_token
return input_cost + output_cost
    def optimize_context(self, query: str, documents: List[str], max_tokens: int = 4000) -> List[str]:
        """Select documents (highest-relevance first) to fit a token budget"""
        encoder = tiktoken.encoding_for_model(self.model_name)
query_tokens = len(encoder.encode(query))
available_tokens = max_tokens - query_tokens - 500 # Reserve for response
selected_docs = []
total_tokens = 0
        # Documents are assumed to arrive sorted by relevance (e.g. from the reranker)
for doc in documents:
doc_tokens = len(encoder.encode(doc))
if total_tokens + doc_tokens <= available_tokens:
selected_docs.append(doc)
total_tokens += doc_tokens
else:
# Try to fit partial document
remaining_tokens = available_tokens - total_tokens
if remaining_tokens > 100: # Minimum useful chunk
truncated_doc = self.truncate_document(doc, remaining_tokens, encoder)
selected_docs.append(truncated_doc)
break
return selected_docs
def truncate_document(self, document: str, max_tokens: int, encoder) -> str:
"""Intelligently truncate document to fit token budget"""
sentences = document.split('. ')
truncated = ""
for sentence in sentences:
test_text = truncated + sentence + ". "
if len(encoder.encode(test_text)) <= max_tokens:
truncated = test_text
else:
break
return truncated.strip()
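A quick usage sketch — the subclass, model name, and per-token prices below are made-up placeholders, not real rates:

# Hypothetical example: supply prices via a subclass
class ExampleOptimizer(TokenOptimizer):
    def get_input_cost(self, model_name: str) -> float:
        return 0.01 / 1000   # placeholder $ per input token
    def get_output_cost(self, model_name: str) -> float:
        return 0.03 / 1000   # placeholder $ per output token

optimizer = ExampleOptimizer("gpt-4")
print(optimizer.estimate_cost("What is reciprocal rank fusion?", max_tokens=500))
context = optimizer.optimize_context(
    "What is reciprocal rank fusion?",
    ["chunk one ...", "chunk two ...", "chunk three ..."],
    max_tokens=4000,
)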
Building production RAG systems requires attention to many details beyond the basic retrieval-generation pattern. The key areas to focus on:

- Retrieval quality: hybrid search with rank fusion, multi-query and hierarchical retrieval, and cross-encoder reranking
- Data preparation: structure-aware chunking matched to the document type
- Evaluation: retrieval metrics (precision, recall, MRR, NDCG), generation metrics (ROUGE, BERTScore), and A/B testing of changes
- Operations: deployment with health checks and autoscaling, caching, and token/cost optimization
The RAG pattern is powerful, but production systems need robust engineering around the core concept. Start simple, measure everything, and iterate based on real user feedback and performance data.
Remember: the best RAG system is one that reliably serves your users’ needs, not the one with the highest benchmark scores. Focus on end-to-end user experience and build incrementally toward that goal.