Skip to main content

Advanced RAG Overview

Beyond basic retrieval and generation, advanced RAG techniques address common limitations:
  • Corrective RAG (CRAG): Self-correcting retrieval with relevance grading
  • Hybrid Search: Combining semantic and keyword search
  • Knowledge Graphs: Multi-hop reasoning with structured knowledge
  • Query Transformation: Improving retrieval through query optimization
  • Reranking: Refining retrieved results for better context

Corrective RAG (CRAG)

CRAG implements a multi-stage workflow that grades document relevance and corrects poor retrievals:

Complete CRAG Implementation

from typing import Any, Dict, List, TypedDict

import streamlit as st
from langchain.schema import Document
from langchain_anthropic import ChatAnthropic
from langchain_community.tools import TavilySearchResults
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import StateGraph, END
from qdrant_client import QdrantClient

class GraphState(TypedDict):
    """Shared state passed between CRAG workflow nodes.

    Attributes:
        keys: Mutable bag of workflow data ("question", "documents",
            "run_web_search", "generation") exchanged between nodes.
    """
    # Fixed: the original annotation was Dict[str, any], which references the
    # builtin function `any`, not the `typing.Any` type.
    keys: Dict[str, Any]

# Initialize components
# NOTE(review): all credentials are read from Streamlit session state, so this
# module assumes the user already entered API keys in the UI before this runs.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key=st.session_state.openai_api_key
)

# Client for the remote Qdrant vector database.
client = QdrantClient(
    url=st.session_state.qdrant_url,
    api_key=st.session_state.qdrant_api_key
)

# Vector store over the pre-existing "documents" collection.
vectorstore = Qdrant(
    client=client,
    collection_name="documents",
    embeddings=embeddings
)

# Deterministic (temperature=0) LLM shared by the grading, query-rewriting,
# and answer-generation nodes below.
llm = ChatAnthropic(
    model="claude-sonnet-4-5",
    api_key=st.session_state.anthropic_api_key,
    temperature=0
)

def retrieve(state):
    """Fetch the top-5 candidate documents for the question from the vector store."""
    print("~-retrieve-~")
    question = state["keys"]["question"]

    docs = (
        vectorstore
        .as_retriever(search_kwargs={'k': 5})
        .get_relevant_documents(question)
    )

    return {"keys": {"documents": docs, "question": question}}

def grade_documents(state):
    """Grade each retrieved document's relevance and flag whether web search is needed.

    For every document, asks the LLM for a {"score": "yes"/"no"} verdict.
    Relevant documents are kept; any irrelevant document sets the
    ``run_web_search`` flag to "Yes" so the graph routes through query
    transformation and web search. Documents whose grade cannot be parsed
    are kept (fail open).
    """
    # Hoisted out of the per-document loop: the original re-imported json/re
    # and re-scanned the pattern source on every iteration.
    import json
    import re

    print("~-grade documents-~")
    question = state["keys"]["question"]
    documents = state["keys"]["documents"]

    # Prompt for grading
    prompt = PromptTemplate(
        template="""You are grading document relevance.
        
        Document: {context}
        Question: {question}
        
        Return ONLY a JSON object: {{"score": "yes"}} or {{"score": "no"}}
        
        Rules:
        - "yes" if document contains relevant information
        - "no" if document is clearly irrelevant
        - Be lenient - only filter obvious mismatches
        """,
        input_variables=["context", "question"]
    )

    chain = prompt | llm
    json_pattern = re.compile(r'\{.*\}')

    filtered_docs = []
    search_needed = "No"

    for doc in documents:
        # Grade each document
        response = chain.invoke({
            "question": question,
            "context": doc.page_content
        })

        try:
            json_match = json_pattern.search(response.content)
            if json_match is None:
                # Fixed: the original silently dropped the document when no
                # JSON was found; keep it instead, matching the except path.
                filtered_docs.append(doc)
                continue

            score = json.loads(json_match.group())
            if score.get("score") == "yes":
                print("~-document relevant-~")
                filtered_docs.append(doc)
            else:
                print("~-document not relevant-~")
                search_needed = "Yes"
        except Exception as e:
            print(f"Error grading: {e}")
            filtered_docs.append(doc)  # Keep on error

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search_needed
        }
    }

def transform_query(state):
    """Rewrite the user's question into a search-engine-friendly query."""
    print("~-transform query-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]

    rewrite_prompt = PromptTemplate(
        template="""Generate an optimized search query.
        
        Original question: {question}
        
        Create a better search query by:
        - Identifying key concepts
        - Adding relevant synonyms
        - Removing ambiguity
        
        Return only the improved query:
        """,
        input_variables=["question"]
    )

    # The rewritten query replaces the original question in the state.
    rewritten = (rewrite_prompt | llm).invoke({"question": question})

    return {"keys": {"documents": documents, "question": rewritten.content}}

def web_search(state):
    """Augment the document set with Tavily web-search results for the question."""
    print("~-web search-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]

    # Initialize Tavily search
    tavily = TavilySearchResults(
        api_key=st.session_state.tavily_api_key,
        max_results=3,
        search_depth="advanced"
    )

    try:
        hits = tavily.invoke({"query": question})

        # Wrap each hit as a Document tagged with its provenance.
        web_docs = [
            Document(
                page_content=(
                    f"Title: {hit.get('title', '')}\n"
                    f"Content: {hit.get('content', '')}"
                ),
                metadata={"source": "web_search"}
            )
            for hit in hits
        ]

        documents.extend(web_docs)
        st.success(f"Added {len(web_docs)} web results")

    except Exception as e:
        st.error(f"Web search error: {e}")

    return {"keys": {"documents": documents, "question": question}}

def generate(state):
    """Produce the final answer from all accumulated context documents."""
    print("~-generate-~")
    keys = state["keys"]
    question = keys["question"]
    documents = keys["documents"]

    # Concatenate every document body into a single context string.
    context = "\n\n".join(doc.page_content for doc in documents)

    answer_prompt = PromptTemplate(
        template="""Answer the question based on the context.
        
        Context: {context}
        
        Question: {question}
        
        Provide a comprehensive answer with:
        - Key points from the context
        - Specific details and examples
        - Clear and concise language
        
        Answer:
        """,
        input_variables=["context", "question"]
    )

    result = (answer_prompt | llm).invoke({"context": context, "question": question})

    return {
        "keys": {
            "documents": documents,
            "question": question,
            "generation": result.content
        }
    }

def decide_to_generate(state):
    """Route after grading: rewrite + web-search when any document was
    irrelevant, otherwise go straight to answer generation."""
    if state["keys"]["run_web_search"] == "Yes":
        return "transform_query"
    return "generate"

# Build workflow
workflow = StateGraph(GraphState)

# Add nodes — one per CRAG stage defined above.
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)

# Build graph: retrieve -> grade; grading then decides whether to answer
# directly or to rewrite the query and supplement with a web search first.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")

workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate"
    }
)

# Web-search branch always rejoins the main path at generation.
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile the graph into a runnable app.
app = workflow.compile()

# Use in Streamlit
st.title("🔄 Corrective RAG")

query = st.text_input("Ask a question:")

if st.button("Submit") and query:
    inputs = {"keys": {"question": query}}
    
    # Stream node-by-node so each CRAG stage can be inspected in the UI.
    for output in app.stream(inputs):
        for key, value in output.items():
            with st.expander(f"Step: {key}"):
                st.json(value["keys"])
    
    # Show final answer
    # NOTE(review): `value` is the loop variable left over from the stream
    # above; if the stream yields nothing this raises NameError — consider
    # initializing `value` before the loop.
    if "generation" in value["keys"]:
        st.subheader("Answer:")
        st.write(value["keys"]["generation"])

Hybrid Search RAG

Hybrid search combines semantic (vector) search with keyword (BM25) search for better retrieval. Example (hybrid_search.py):
import streamlit as st
from raglite import (
    RAGLiteConfig,
    insert_document,
    hybrid_search,
    retrieve_chunks,
    rerank_chunks,
    rag
)
from rerankers import Reranker
from pathlib import Path
import anthropic

st.title("👀 Hybrid Search RAG")

# Configuration
# NOTE(review): `cohere_key` is referenced here but never defined in this
# snippet — it must be supplied (e.g. from st.session_state) before this runs.
config = RAGLiteConfig(
    db_url="postgresql://user:pass@host/db",
    llm="claude-3-opus-20240229",
    embedder="text-embedding-3-large",
    embedder_normalize=True,
    chunk_max_size=2000,
    reranker=Reranker("cohere", api_key=cohere_key, lang="en")
)

# Upload documents
uploaded_file = st.file_uploader("Upload PDF", type="pdf")

if uploaded_file:
    # Save and process
    # NOTE(review): assumes a writable ./temp directory already exists — confirm.
    path = Path(f"temp/{uploaded_file.name}")
    path.write_bytes(uploaded_file.getvalue())
    
    # Insert into RAGLite (chunks, embeds, and indexes the document).
    insert_document(path, config=config)
    st.success("Document processed!")

# Query
query = st.text_input("Ask a question:")

if st.button("Search") and query:
    # Hybrid search: combines vector + BM25
    chunk_ids, scores = hybrid_search(
        query,
        num_results=10,
        config=config
    )
    
    if chunk_ids:
        # Retrieve full chunks
        chunks = retrieve_chunks(chunk_ids, config=config)
        
        # Rerank for best results
        reranked = rerank_chunks(query, chunks, config=config)
        
        # Show results (top 5 after reranking)
        st.subheader("Retrieved Chunks:")
        for i, chunk in enumerate(reranked[:5], 1):
            with st.expander(f"Result {i}"):
                st.write(chunk['text'])
                st.caption(f"Score: {chunk['score']:.3f}")
        
        # Generate answer
        with st.spinner("Generating answer..."):
            answer = rag(query, config=config)
            st.subheader("Answer:")
            st.write(answer)
    else:
        # Fallback to general LLM when nothing matched in the knowledge base.
        # NOTE(review): `anthropic_key` is never defined in this snippet —
        # it must be supplied before this branch runs.
        client = anthropic.Anthropic(api_key=anthropic_key)
        message = client.messages.create(
            model="claude-3-sonnet-20240229",
            max_tokens=1024,
            messages=[{"role": "user", "content": query}]
        )
        st.write(message.content[0].text)

How Hybrid Search Works

1. Vector Search — semantic similarity using embeddings:

   vector_results = vectorstore.similarity_search(query, k=20)

2. Keyword Search — BM25 keyword matching:

   from rank_bm25 import BM25Okapi

   bm25 = BM25Okapi(tokenized_corpus)
   keyword_results = bm25.get_top_n(tokenized_query, documents, n=20)

3. Combine Results — merge and deduplicate the two candidate lists:

   combined = list(set(vector_results + keyword_results))

4. Rerank — use a reranking model for the final ordering:

   reranker = Reranker("cohere")
   final_results = reranker.rank(query, combined)

Knowledge Graph RAG

Use graph structure for multi-hop reasoning:
import streamlit as st
import ollama
from neo4j import GraphDatabase
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class Entity:
    """A node extracted from a document for the knowledge graph."""
    id: str            # stable identifier (lowercased name, spaces -> underscores)
    name: str          # display name as extracted by the LLM
    entity_type: str   # e.g. PERSON / ORG / CONCEPT (from the extraction prompt)
    description: str   # short LLM-generated description
    source_doc: str    # name of the document the entity came from
    source_chunk: str  # leading slice of the source text, kept for provenance

@dataclass
class Relationship:
    """A directed edge between two named entities in the knowledge graph."""
    source: str         # name of the source entity
    target: str         # name of the target entity
    relation_type: str  # e.g. WORKS_FOR / CREATED (from the extraction prompt)
    description: str    # LLM-generated description of the relationship
    source_doc: str     # name of the document the relationship came from

class KnowledgeGraphManager:
    """Thin wrapper around a Neo4j driver for entity/relationship upserts
    and multi-hop traversal queries."""

    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Release the driver's connection pool (call when done)."""
        self.driver.close()

    def add_entity(self, entity: Entity):
        """Upsert an entity node keyed by its id."""
        with self.driver.session() as session:
            session.run("""
                MERGE (e:Entity {id: $id})
                SET e.name = $name,
                    e.type = $entity_type,
                    e.description = $description,
                    e.source_doc = $source_doc
            """, **entity.__dict__)

    def add_relationship(self, rel: Relationship):
        """Upsert a RELATES_TO edge between two existing named entities."""
        with self.driver.session() as session:
            # Fixed: the query previously referenced $rel_type, but the
            # parameters come from rel.__dict__ whose key is `relation_type`,
            # so Neo4j raised a missing-parameter error on every call.
            session.run("""
                MATCH (a:Entity {name: $source})
                MATCH (b:Entity {name: $target})
                MERGE (a)-[r:RELATES_TO {type: $relation_type}]->(b)
                SET r.description = $description
            """, **rel.__dict__)

    def find_related_entities(self, entity_name: str, hops: int = 2):
        """Multi-hop traversal to find entities related to `entity_name`.

        Returns up to 20 dicts with name/description/source and the
        descriptions of the relationships along each path.
        """
        # `hops` must be interpolated into the Cypher text (variable-length
        # bounds cannot be parameterized), so coerce it to int to keep the
        # query injection-safe.
        hops = int(hops)
        with self.driver.session() as session:
            result = session.run(f"""
                MATCH path = (start:Entity)-[*1..{hops}]-(related:Entity)
                WHERE toLower(start.name) CONTAINS toLower($name)
                RETURN related.name as name,
                       related.description as description,
                       related.source_doc as source,
                       [r in relationships(path) | r.description] as path_descriptions
                LIMIT 20
            """, name=entity_name)
            return [dict(record) for record in result]

def extract_entities_with_llm(text: str, source_doc: str) -> tuple:
    """Extract entities and relationships from `text` using a local LLM.

    Args:
        text: Raw document text to analyze.
        source_doc: Name of the originating document, stored for provenance.

    Returns:
        A (entities, relationships) tuple of Entity / Relationship lists.

    Raises:
        ValueError: if the model response contains no JSON object.
    """
    import json
    import re

    prompt = f"""
    Extract entities and relationships from this text.
    
    Text: {text}
    
    Return JSON format:
    {{
        "entities": [
            {{"name": "...", "type": "PERSON|ORG|CONCEPT|...", "description": "..."}}
        ],
        "relationships": [
            {{"source": "...", "target": "...", "type": "WORKS_FOR|CREATED|...", "description": "..."}}
        ]
    }}
    """

    response = ollama.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": prompt}]
    )

    # Robustness fix: models often wrap the JSON in prose or code fences,
    # which made the original json.loads(raw_content) crash. Pull out the
    # outermost {...} span before parsing (same approach as the CRAG grader).
    raw = response['message']['content']
    match = re.search(r'\{.*\}', raw, re.DOTALL)
    if match is None:
        raise ValueError(f"No JSON object found in model response: {raw[:200]!r}")
    data = json.loads(match.group())

    entities = [
        Entity(
            id=e['name'].lower().replace(' ', '_'),
            name=e['name'],
            entity_type=e['type'],
            description=e['description'],
            source_doc=source_doc,
            source_chunk=text[:200]
        )
        for e in data['entities']
    ]

    relationships = [
        Relationship(
            source=r['source'],
            target=r['target'],
            relation_type=r['type'],
            description=r['description'],
            source_doc=source_doc
        )
        for r in data['relationships']
    ]

    return entities, relationships

def generate_answer_with_citations(question: str, kg: KnowledgeGraphManager):
    """Answer a question from multi-hop graph context, requesting citations.

    Returns the model's answer text plus the raw list of related-entity
    records used as sources.
    """
    # Pull entities within two hops of anything matching the question text.
    related = kg.find_related_entities(question, hops=2)

    # Flatten the graph records into a plain-text context block.
    context = "\n\n".join(
        f"{entry['name']}: {entry['description']}\nSource: {entry['source']}"
        for entry in related
    )

    prompt = f"""
    Answer using the knowledge graph context below.
    Include [1], [2] citations for each claim.
    
    Context:
    {context}
    
    Question: {question}
    
    Answer with citations:
    """

    reply = ollama.chat(
        model="llama3.2",
        messages=[{"role": "user", "content": prompt}]
    )

    return reply['message']['content'], related

# Streamlit UI
st.title("🔍 Knowledge Graph RAG")

# Neo4j connection
uri = st.text_input("Neo4j URI", "bolt://localhost:7687")
user = st.text_input("Username", "neo4j")
password = st.text_input("Password", type="password")

if uri and user and password:
    # NOTE(review): a new driver is created on every Streamlit rerun and is
    # never closed — consider caching the manager (st.cache_resource) and
    # closing the driver on teardown.
    kg = KnowledgeGraphManager(uri, user, password)
    
    # Document upload
    with st.sidebar:
        st.header("Add Document")
        doc_text = st.text_area("Paste document text")
        doc_name = st.text_input("Document name")
        
        if st.button("Extract & Add to Graph"):
            if doc_text:
                # LLM pass: pull entities + relationships out of the pasted text.
                entities, relationships = extract_entities_with_llm(
                    doc_text,
                    doc_name
                )
                
                # Add to graph (entities first, since relationships MATCH on
                # entity names).
                for entity in entities:
                    kg.add_entity(entity)
                
                for rel in relationships:
                    kg.add_relationship(rel)
                
                st.success(f"Added {len(entities)} entities and {len(relationships)} relationships")
    
    # Query
    question = st.text_input("Ask a question:")
    
    if st.button("Answer") and question:
        with st.spinner("Traversing knowledge graph..."):
            answer, sources = generate_answer_with_citations(question, kg)
            
            st.subheader("Answer:")
            st.write(answer)
            
            # Surface the graph records used as context, one expander each.
            st.subheader("Knowledge Graph Sources:")
            for i, source in enumerate(sources, 1):
                with st.expander(f"[{i}] {source['name']}"):
                    st.write(source['description'])
                    st.caption(f"Source: {source['source']}")

Query Transformation Techniques

from langchain.retrievers.multi_query import MultiQueryRetriever

# Generate multiple query variations
# NOTE(review): `vectorstore` and `ChatOpenAI` are assumed to be defined /
# imported elsewhere — this snippet is not self-contained.
retriever = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(),
    llm=ChatOpenAI(model="gpt-4")
)

# Automatically generates variations like:
# Original: "What is RAG?"
# Variation 1: "Explain Retrieval-Augmented Generation"
# Variation 2: "How does RAG work in AI?"
# Variation 3: "RAG technique definition"

# Each variation is searched and the results are merged.
docs = retriever.get_relevant_documents("What is RAG?")

Reranking Strategies

from rerankers import Reranker

# NOTE(review): `cohere_key`, `retriever`, and `query` are assumed to be
# defined elsewhere — this snippet is not self-contained.
reranker = Reranker(
    "cohere",
    api_key=cohere_key,
    lang="en"
)

# Get initial results (over-fetch so the reranker has candidates to trim)
docs = retriever.get_relevant_documents(query, k=20)

# Rerank for quality
reranked = reranker.rank(
    query=query,
    docs=[doc.page_content for doc in docs]
)

# Use top results
top_docs = reranked[:5]
from sentence_transformers import CrossEncoder

# Local cross-encoder alternative to an API-based reranker.
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Score each doc against query
# NOTE(review): `query` and `docs` are assumed to be defined elsewhere.
scores = model.predict([
    [query, doc.page_content] for doc in docs
])

# Sort by score
# zip pairs each score with its doc; the key compares only the score so
# Document objects are never compared directly.
ranked_docs = [
    doc for _, doc in sorted(
        zip(scores, docs),
        key=lambda x: x[0],
        reverse=True
    )
]
def rerank_with_llm(query: str, docs: List[Document]) -> List[Document]:
    """Use the LLM to score (0-10) and rerank documents by relevance.

    Args:
        query: The user question to score against.
        docs: Candidate documents (only the first 500 chars of each are sent).

    Returns:
        The same documents, sorted most-relevant first.
    """
    import re

    scored_docs = []

    for doc in docs:
        prompt = f"""
        Rate relevance (0-10) of this document to the query.
        
        Query: {query}
        
        Document: {doc.page_content[:500]}
        
        Return only the number.
        """

        # Robustness fix: the original did int(llm.invoke(prompt).content),
        # which raises ValueError whenever the model adds any prose around
        # the number. Salvage the first integer instead, defaulting to 0.
        reply = llm.invoke(prompt).content
        match = re.search(r'\d+', reply)
        score = int(match.group()) if match else 0
        scored_docs.append((score, doc))

    # Highest score first; the key compares only the score, so Document
    # objects are never compared on ties.
    scored_docs.sort(reverse=True, key=lambda x: x[0])
    return [doc for _, doc in scored_docs]

Best Practices

Combine Techniques

Use multiple techniques together:
  • Hybrid search + reranking
  • Query transformation + CRAG
  • Knowledge graphs + vector search

Evaluation

Measure performance:
  • Retrieval precision/recall
  • Answer quality metrics
  • Latency and cost
  • A/B test variations

Fallback Strategies

Always have fallbacks:
  • Web search when KB fails
  • General LLM for out-of-scope
  • Human escalation paths

Monitor Quality

Track metrics:
  • Document relevance scores
  • User feedback
  • Failure cases
  • Source attribution

Next Steps

Local RAG

Build privacy-focused RAG with Ollama and local models

Build docs developers (and LLMs) love