Advanced RAG Overview
Beyond basic retrieval and generation, advanced RAG techniques address common limitations:

- Corrective RAG (CRAG): Self-correcting retrieval with relevance grading
- Hybrid Search: Combining semantic and keyword search
- Knowledge Graphs: Multi-hop reasoning with structured knowledge
- Query Transformation: Improving retrieval through query optimization
- Reranking: Refining retrieved results for better context
Corrective RAG (CRAG)
CRAG implements a multi-stage workflow that grades document relevance and corrects poor retrievals.

Complete CRAG Implementation
from typing import Any, Dict, List, TypedDict

import streamlit as st
from langchain.schema import Document
from langchain_anthropic import ChatAnthropic
from langchain_community.tools import TavilySearchResults
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import StateGraph, END
from qdrant_client import QdrantClient
class GraphState(TypedDict):
    """State carried between CRAG workflow nodes.

    The single "keys" dict holds the question, retrieved documents, the
    run_web_search flag, and (after the generate node) the answer text.
    """
    # Was `Dict[str, any]` — `any` is the builtin function, not a type;
    # typing.Any is the correct annotation for heterogeneous values.
    keys: Dict[str, Any]
# Initialize components.
# NOTE(review): credentials are read from st.session_state — assumes the
# Streamlit app populated these keys earlier; verify against the UI code.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key=st.session_state.openai_api_key
)
# Client for the Qdrant vector database instance.
client = QdrantClient(
    url=st.session_state.qdrant_url,
    api_key=st.session_state.qdrant_api_key
)
# Vector store over the "documents" collection (assumed to exist already).
vectorstore = Qdrant(
    client=client,
    collection_name="documents",
    embeddings=embeddings
)
# temperature=0 keeps grading/rewriting/answering deterministic.
llm = ChatAnthropic(
    model="claude-sonnet-4-5",
    api_key=st.session_state.anthropic_api_key,
    temperature=0
)
def retrieve(state):
    """Fetch the top-5 candidate documents for the current question."""
    print("~-retrieve-~")
    question = state["keys"]["question"]
    # Build a retriever on the fly and pull the candidates in one step.
    docs = vectorstore.as_retriever(
        search_kwargs={'k': 5}
    ).get_relevant_documents(question)
    return {"keys": {"documents": docs, "question": question}}
def grade_documents(state):
    """Grade each retrieved document's relevance to the question.

    Keeps documents graded "yes", drops documents graded "no" and sets
    run_web_search="Yes" so the workflow can correct the retrieval. On
    any grading/parsing failure the document is kept (lenient policy).
    """
    # Hoisted out of the per-document loop (they were re-imported on
    # every iteration).
    import json
    import re
    print("~-grade documents-~")
    question = state["keys"]["question"]
    documents = state["keys"]["documents"]
    # Prompt for grading
    prompt = PromptTemplate(
        template="""You are grading document relevance.
Document: {context}
Question: {question}
Return ONLY a JSON object: {{"score": "yes"}} or {{"score": "no"}}
Rules:
- "yes" if document contains relevant information
- "no" if document is clearly irrelevant
- Be lenient - only filter obvious mismatches
""",
        input_variables=["context", "question"]
    )
    chain = prompt | llm
    filtered_docs = []
    search_needed = "No"
    for doc in documents:
        # Grade each document
        response = chain.invoke({
            "question": question,
            "context": doc.page_content
        })
        try:
            # Extract JSON from response
            json_match = re.search(r'\{.*\}', response.content)
            if json_match is None:
                # BUG FIX: a reply with no JSON used to drop the document
                # silently (neither kept nor flagged); treat it like a
                # grading error and keep the document.
                filtered_docs.append(doc)
                continue
            score = json.loads(json_match.group())
            if score.get("score") == "yes":
                print("~-document relevant-~")
                filtered_docs.append(doc)
            else:
                print("~-document not relevant-~")
                search_needed = "Yes"
        except Exception as e:
            print(f"Error grading: {e}")
            filtered_docs.append(doc)  # Keep on error
    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search_needed
        }
    }
def transform_query(state):
    """Rewrite the question into a search-friendlier form."""
    print("~-transform query-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]
    rewrite_prompt = PromptTemplate(
        template="""Generate an optimized search query.
Original question: {question}
Create a better search query by:
- Identifying key concepts
- Adding relevant synonyms
- Removing ambiguity
Return only the improved query:
""",
        input_variables=["question"]
    )
    # The rewritten question replaces the original for downstream nodes.
    improved = (rewrite_prompt | llm).invoke({"question": question})
    return {"keys": {"documents": documents, "question": improved.content}}
def web_search(state):
    """Augment the document set with Tavily web-search results."""
    print("~-web search-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]
    # Initialize Tavily search
    search_tool = TavilySearchResults(
        api_key=st.session_state.tavily_api_key,
        max_results=3,
        search_depth="advanced"
    )
    try:
        results = search_tool.invoke({"query": question})
        # Wrap each hit as a Document tagged with its provenance.
        web_docs = [
            Document(
                page_content=(
                    f"Title: {hit.get('title', '')}\n"
                    f"Content: {hit.get('content', '')}"
                ),
                metadata={"source": "web_search"},
            )
            for hit in results
        ]
        documents.extend(web_docs)
        st.success(f"Added {len(web_docs)} web results")
    except Exception as e:
        # Best-effort: surface the error but continue with what we have.
        st.error(f"Web search error: {e}")
    return {"keys": {"documents": documents, "question": question}}
def generate(state):
    """Produce the final answer from the question plus gathered context."""
    print("~-generate-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]
    # Concatenate all document bodies into one context string.
    context = "\n\n".join(doc.page_content for doc in documents)
    answer_prompt = PromptTemplate(
        template="""Answer the question based on the context.
Context: {context}
Question: {question}
Provide a comprehensive answer with:
- Key points from the context
- Specific details and examples
- Clear and concise language
Answer:
""",
        input_variables=["context", "question"]
    )
    result = (answer_prompt | llm).invoke(
        {"context": context, "question": question}
    )
    return {
        "keys": {
            "documents": documents,
            "question": question,
            "generation": result.content,
        }
    }
def decide_to_generate(state):
    """Route to query correction when any document was irrelevant,
    otherwise go straight to answer generation."""
    needs_search = state["keys"]["run_web_search"] == "Yes"
    return "transform_query" if needs_search else "generate"
# Build workflow
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)
# Build graph: retrieve -> grade, then either the corrective path
# (transform_query -> web_search -> generate) or generate directly.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate"
    }
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
# Use in Streamlit
st.title("🔄 Corrective RAG")
query = st.text_input("Ask a question:")
if st.button("Submit") and query:
    inputs = {"keys": {"question": query}}
    # Stream node-by-node so each step's intermediate state is visible.
    for output in app.stream(inputs):
        for key, value in output.items():
            with st.expander(f"Step: {key}"):
                st.json(value["keys"])
            # Show final answer (only the generate node adds "generation")
            if "generation" in value["keys"]:
                st.subheader("Answer:")
                st.write(value["keys"]["generation"])
Hybrid Search RAG
Combines semantic (vector) search with keyword (BM25) search for better retrieval.

hybrid_search.py
import streamlit as st
from raglite import (
RAGLiteConfig,
insert_document,
hybrid_search,
retrieve_chunks,
rerank_chunks,
rag
)
from rerankers import Reranker
from pathlib import Path
import anthropic
st.title("👀 Hybrid Search RAG")
# Configuration
# NOTE(review): `cohere_key` is not defined anywhere in this snippet —
# it must come from the surrounding app (e.g. st.session_state); confirm.
config = RAGLiteConfig(
    db_url="postgresql://user:pass@host/db",
    llm="claude-3-opus-20240229",
    embedder="text-embedding-3-large",
    embedder_normalize=True,
    chunk_max_size=2000,
    reranker=Reranker("cohere", api_key=cohere_key, lang="en")
)
# Upload documents
uploaded_file = st.file_uploader("Upload PDF", type="pdf")
if uploaded_file:
    # Save and process — assumes a temp/ directory exists; verify.
    path = Path(f"temp/{uploaded_file.name}")
    path.write_bytes(uploaded_file.getvalue())
    # Insert into RAGLite (chunking + embedding + indexing)
    insert_document(path, config=config)
    st.success("Document processed!")
# Query
query = st.text_input("Ask a question:")
if st.button("Search") and query:
    # Hybrid search: combines vector + BM25
    chunk_ids, scores = hybrid_search(
        query,
        num_results=10,
        config=config
    )
    if chunk_ids:
        # Retrieve full chunks
        chunks = retrieve_chunks(chunk_ids, config=config)
        # Rerank for best results
        reranked = rerank_chunks(query, chunks, config=config)
        # Show the top 5 reranked chunks with their scores
        st.subheader("Retrieved Chunks:")
        for i, chunk in enumerate(reranked[:5], 1):
            with st.expander(f"Result {i}"):
                st.write(chunk['text'])
                st.caption(f"Score: {chunk['score']:.3f}")
        # Generate answer
        with st.spinner("Generating answer..."):
            answer = rag(query, config=config)
            st.subheader("Answer:")
            st.write(answer)
    else:
        # Fallback to a general LLM when nothing matched in the KB.
        # NOTE(review): `anthropic_key` is likewise undefined here; confirm.
        client = anthropic.Anthropic(api_key=anthropic_key)
        message = client.messages.create(
            model="claude-3-sonnet-20240229",
            max_tokens=1024,
            messages=[{"role": "user", "content": query}]
        )
        st.write(message.content[0].text)
How Hybrid Search Works
Vector Search
Semantic similarity using embeddings
# Dense retrieval: rank by embedding similarity, over-fetch 20 candidates.
vector_results = vectorstore.similarity_search(query, k=20)
Keyword Search
BM25 keyword matching
from rank_bm25 import BM25Okapi
# Sparse retrieval: BM25 over a pre-tokenized corpus.
# NOTE(review): `documents` must align index-for-index with
# `tokenized_corpus`; neither is defined in this snippet — confirm.
bm25 = BM25Okapi(tokenized_corpus)
keyword_results = bm25.get_top_n(tokenized_query, documents, n=20)
Knowledge Graph RAG
Use graph structure for multi-hop reasoning:

import streamlit as st
import ollama
from neo4j import GraphDatabase
from typing import List, Dict
from dataclasses import dataclass
@dataclass
class Entity:
    # A graph node extracted from a document chunk by the LLM.
    id: str            # normalized name (lowercased, spaces -> underscores)
    name: str          # display name as extracted
    entity_type: str   # e.g. PERSON, ORG, CONCEPT
    description: str   # LLM-written description
    source_doc: str    # document the entity came from
    source_chunk: str  # leading slice of the originating text chunk
@dataclass
class Relationship:
    # A directed edge between two entities, matched by entity name.
    source: str         # name of the source entity
    target: str         # name of the target entity
    relation_type: str  # e.g. WORKS_FOR, CREATED
    description: str    # LLM-written description of the relation
    source_doc: str     # document the relation came from
class KnowledgeGraphManager:
    """Thin wrapper around a Neo4j driver for entity/relationship
    upserts and multi-hop neighborhood queries."""

    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Release the driver's connections (call on shutdown)."""
        self.driver.close()

    def add_entity(self, entity: Entity):
        """Upsert an entity node keyed by its normalized id."""
        with self.driver.session() as session:
            session.run("""
                MERGE (e:Entity {id: $id})
                SET e.name = $name,
                    e.type = $entity_type,
                    e.description = $description,
                    e.source_doc = $source_doc
            """, **entity.__dict__)

    def add_relationship(self, rel: Relationship):
        """Upsert a RELATES_TO edge between two entities matched by name."""
        with self.driver.session() as session:
            # BUG FIX: the query referenced $rel_type, but the parameters
            # come from Relationship.__dict__, whose key is
            # `relation_type` — Neo4j raised a missing-parameter error on
            # every call.
            session.run("""
                MATCH (a:Entity {name: $source})
                MATCH (b:Entity {name: $target})
                MERGE (a)-[r:RELATES_TO {type: $relation_type}]->(b)
                SET r.description = $description
            """, **rel.__dict__)

    def find_related_entities(self, entity_name: str, hops: int = 2):
        """Return up to 20 entities within `hops` of a name match,
        plus the relation descriptions along each path."""
        # Cypher cannot parameterize variable-length path bounds, so
        # `hops` is interpolated; coerce to int so a non-numeric value
        # cannot alter the query.
        hops = int(hops)
        with self.driver.session() as session:
            result = session.run(f"""
                MATCH path = (start:Entity)-[*1..{hops}]-(related:Entity)
                WHERE toLower(start.name) CONTAINS toLower($name)
                RETURN related.name as name,
                       related.description as description,
                       related.source_doc as source,
                       [r in relationships(path) | r.description] as path_descriptions
                LIMIT 20
            """, name=entity_name)
            return [dict(record) for record in result]
def extract_entities_with_llm(text: str, source_doc: str) -> tuple:
    """Extract entities and relationships from `text` via a local LLM.

    Returns (entities, relationships) as lists of Entity/Relationship.
    Raises ValueError when the model's reply contains no JSON object.
    """
    import json
    import re
    prompt = f"""
Extract entities and relationships from this text.
Text: {text}
Return JSON format:
{{
"entities": [
{{"name": "...", "type": "PERSON|ORG|CONCEPT|...", "description": "..."}}
],
"relationships": [
{{"source": "...", "target": "...", "type": "WORKS_FOR|CREATED|...", "description": "..."}}
]
}}
"""
    response = ollama.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": prompt}]
    )
    # LLMs often wrap JSON in prose or code fences; extract the outermost
    # object instead of parsing the raw reply (matches the CRAG grader's
    # extraction strategy elsewhere in this page).
    raw = response['message']['content']
    match = re.search(r'\{.*\}', raw, re.DOTALL)
    if match is None:
        raise ValueError("LLM response contained no JSON object")
    data = json.loads(match.group())
    entities = [
        Entity(
            id=e['name'].lower().replace(' ', '_'),
            name=e['name'],
            entity_type=e['type'],
            description=e['description'],
            source_doc=source_doc,
            source_chunk=text[:200]
        )
        for e in data['entities']
    ]
    relationships = [
        Relationship(
            source=r['source'],
            target=r['target'],
            relation_type=r['type'],
            description=r['description'],
            source_doc=source_doc
        )
        for r in data['relationships']
    ]
    return entities, relationships
def generate_answer_with_citations(question: str, kg: KnowledgeGraphManager):
    """Answer `question` from the knowledge graph, with [n] citations.

    Returns (answer_text, related_entities); the entity at index i-1
    corresponds to citation [i], matching the UI's numbered source list.
    """
    # Find relevant entities via multi-hop traversal
    related = kg.find_related_entities(question, hops=2)
    # BUG FIX: the prompt demands [1], [2] citations, but the context
    # entries were unnumbered, so the model had nothing to cite. Number
    # each entry to match the numbered source list shown to the user.
    context = "\n\n".join(
        f"[{i}] {r['name']}: {r['description']}\nSource: {r['source']}"
        for i, r in enumerate(related, 1)
    )
    # Generate answer
    prompt = f"""
Answer using the knowledge graph context below.
Include [1], [2] citations for each claim.
Context:
{context}
Question: {question}
Answer with citations:
"""
    response = ollama.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": prompt}]
    )
    return response['message']['content'], related
# Streamlit UI
st.title("🔍 Knowledge Graph RAG")
# Neo4j connection
uri = st.text_input("Neo4j URI", "bolt://localhost:7687")
user = st.text_input("Username", "neo4j")
password = st.text_input("Password", type="password")
if uri and user and password:
    # NOTE(review): a new driver is created on every Streamlit rerun;
    # consider st.cache_resource in the real app.
    kg = KnowledgeGraphManager(uri, user, password)
    # Document upload
    with st.sidebar:
        st.header("Add Document")
        doc_text = st.text_area("Paste document text")
        doc_name = st.text_input("Document name")
        if st.button("Extract & Add to Graph"):
            if doc_text:
                entities, relationships = extract_entities_with_llm(
                    doc_text,
                    doc_name
                )
                # Add to graph (entities first so relationship MATCHes succeed)
                for entity in entities:
                    kg.add_entity(entity)
                for rel in relationships:
                    kg.add_relationship(rel)
                st.success(f"Added {len(entities)} entities and {len(relationships)} relationships")
    # Query
    question = st.text_input("Ask a question:")
    if st.button("Answer") and question:
        with st.spinner("Traversing knowledge graph..."):
            answer, sources = generate_answer_with_citations(question, kg)
        st.subheader("Answer:")
        st.write(answer)
        # Sources are numbered [1], [2], ... matching citations in the answer
        st.subheader("Knowledge Graph Sources:")
        for i, source in enumerate(sources, 1):
            with st.expander(f"[{i}] {source['name']}"):
                st.write(source['description'])
                st.caption(f"Source: {source['source']}")
Query Transformation Techniques
- Multi-Query
- Step-Back Prompting
- Query Decomposition
from langchain.retrievers.multi_query import MultiQueryRetriever
# Generate multiple query variations with an LLM, retrieve for each,
# and return the union of the results.
# NOTE(review): ChatOpenAI is not imported in this snippet — the
# surrounding app needs `from langchain_openai import ChatOpenAI`.
retriever = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(),
    llm=ChatOpenAI(model="gpt-4")
)
# Automatically generates variations like:
# Original: "What is RAG?"
# Variation 1: "Explain Retrieval-Augmented Generation"
# Variation 2: "How does RAG work in AI?"
# Variation 3: "RAG technique definition"
docs = retriever.get_relevant_documents("What is RAG?")
def step_back_query(original_query: str) -> str:
    """Produce a broader, "step-back" version of the given question."""
    step_back_prompt = f"""
Given this specific question:
{original_query}
Generate a broader, more general question that would help
retrieve useful background information.
Return only the broader question.
"""
    # Single round-trip; the model's reply IS the broader question.
    return llm.invoke(step_back_prompt).content
# Example:
# Original: "How does GPT-4 handle context windows?"
# Step-back: "How do large language models manage context?"
# NOTE(review): `original_query` and `retriever` must be defined by the
# surrounding app; they are not defined in this snippet.
broad_query = step_back_query(original_query)
background_docs = retriever.get_relevant_documents(broad_query)
specific_docs = retriever.get_relevant_documents(original_query)
# Background context first, then specific hits; duplicates are kept.
all_docs = background_docs + specific_docs
def decompose_query(complex_query: str) -> List[str]:
    """Split a complex question into simpler sub-questions via the LLM.

    Returns the parsed list of sub-question strings. Raises ValueError
    when the model's reply contains no JSON list.
    """
    import json
    import re
    prompt = f"""
Break this complex question into simpler sub-questions:
{complex_query}
Return a JSON list of sub-questions.
"""
    response = llm.invoke(prompt)
    # The model may wrap the list in prose or code fences; parsing the
    # raw reply with json.loads crashed in that case — extract the first
    # bracketed list instead.
    match = re.search(r'\[.*\]', response.content, re.DOTALL)
    if match is None:
        raise ValueError("LLM response contained no JSON list")
    return json.loads(match.group())
# Example:
# Complex: "Compare RAG and fine-tuning for domain adaptation"
# Sub-questions:
# 1. "What is RAG?"
# 2. "What is fine-tuning?"
# 3. "How is RAG used for domain adaptation?"
# 4. "How is fine-tuning used for domain adaptation?"
# NOTE(review): `complex_query` and `retriever` come from the
# surrounding app; they are not defined in this snippet.
sub_questions = decompose_query(complex_query)
all_docs = []
# Retrieve per sub-question and pool the results (no deduplication).
for sub_q in sub_questions:
    docs = retriever.get_relevant_documents(sub_q)
    all_docs.extend(docs)
Reranking Strategies
Cohere Reranker
Cohere Reranker
from rerankers import Reranker
# Cohere's hosted reranking model scores query/document pairs.
# NOTE(review): `cohere_key`, `retriever` and `query` come from the
# surrounding app; they are not defined in this snippet.
reranker = Reranker(
    "cohere",
    api_key=cohere_key,
    lang="en"
)
# Get initial results (over-fetch so the reranker has candidates to prune)
docs = retriever.get_relevant_documents(query, k=20)
# Rerank for quality
reranked = reranker.rank(
    query=query,
    docs=[doc.page_content for doc in docs]
)
# Use top results
top_docs = reranked[:5]
Cross-Encoder Reranking
Cross-Encoder Reranking
from sentence_transformers import CrossEncoder
# A cross-encoder jointly encodes each (query, document) pair — slower
# than a bi-encoder but more accurate for reranking small candidate sets.
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Score each doc against query
scores = model.predict([
    [query, doc.page_content] for doc in docs
])
# Sort by score, highest (most relevant) first
ranked_docs = [
    doc for _, doc in sorted(
        zip(scores, docs),
        key=lambda x: x[0],
        reverse=True
    )
]
LLM-based Reranking
LLM-based Reranking
def rerank_with_llm(query: str, docs: List[Document]) -> List[Document]:
    """Use LLM to score and rerank documents.

    Each document is scored 0-10 for relevance to `query`. A reply whose
    score cannot be parsed falls back to 0 (ranked last) instead of
    crashing the whole rerank.
    """
    import re
    scored_docs = []
    for doc in docs:
        prompt = f"""
Rate relevance (0-10) of this document to the query.
Query: {query}
Document: {doc.page_content[:500]}
Return only the number.
"""
        reply = llm.invoke(prompt).content
        # BUG FIX: int(reply) raised ValueError whenever the model added
        # any extra text around the number; pull the first integer out of
        # the reply instead.
        match = re.search(r'\d+', reply)
        score = int(match.group()) if match else 0
        scored_docs.append((score, doc))
    # Sort by score descending; ties keep retrieval order (stable sort).
    scored_docs.sort(reverse=True, key=lambda x: x[0])
    return [doc for _, doc in scored_docs]
Best Practices
Combine Techniques
Use multiple techniques together:
- Hybrid search + reranking
- Query transformation + CRAG
- Knowledge graphs + vector search
Evaluation
Measure performance:
- Retrieval precision/recall
- Answer quality metrics
- Latency and cost
- A/B test variations
Fallback Strategies
Always have fallbacks:
- Web search when KB fails
- General LLM for out-of-scope
- Human escalation paths
Monitor Quality
Track metrics:
- Document relevance scores
- User feedback
- Failure cases
- Source attribution
Next Steps
Local RAG
Build privacy-focused RAG with Ollama and local models
