Part 3: RAG Implementation with ChromaDB

3 June 2025 · netologist · 3 min, 618 words

Why RAG + Local Embeddings?

RAG (Retrieval-Augmented Generation) works around the knowledge limitations of LLMs (training cutoff, no access to your private data) by:

- Retrieving documents relevant to the query from a vector store
- Injecting the retrieved text into the prompt as context
- Letting the model generate an answer grounded in that context

Pairing this with local embeddings keeps the pipeline private: both the embedding model (served by Ollama) and the vector database (ChromaDB) run entirely on your machine.

Vector Database Setup

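The vector store below pulls its persistence path from a config module that isn't shown in this part. Here is a minimal sketch of what it could look like; the module layout and the directory path are assumptions, so adjust them to your project.

# config.py (hypothetical sketch -- only CHROMA_PERSIST_DIR is used below)
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    # Directory where ChromaDB persists its collections (assumed path)
    CHROMA_PERSIST_DIR: str = "./data/chroma"

config = Config()
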
# rag/vector_store.py
import chromadb
from chromadb.config import Settings
import asyncio
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime
import logging
from models.ollama_manager import OllamaManager
from config import config

class VectorStore:
    def __init__(self):
        # Initialize ChromaDB with persistence
        self.client = chromadb.PersistentClient(
            path=config.CHROMA_PERSIST_DIR,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )
        
        self.ollama_manager = OllamaManager()
        
        # Create collections for different data types
        self.collections = {
            'documents': self._get_or_create_collection('documents'),
            'conversations': self._get_or_create_collection('conversations'),
            'web_searches': self._get_or_create_collection('web_searches'),
            'personal_notes': self._get_or_create_collection('personal_notes')
        }
    
    def _get_or_create_collection(self, name: str):
        """Get an existing collection or create it with cosine similarity"""
        # get_or_create_collection avoids depending on the exact exception type
        # raised by get_collection, which differs between ChromaDB versions.
        # Embeddings are supplied explicitly (generated via Ollama), so no
        # embedding_function is attached to the collection.
        return self.client.get_or_create_collection(
            name=name,
            metadata={"hnsw:space": "cosine"}
        )
    
    async def add_document(
        self, 
        content: str, 
        metadata: Dict[str, Any],
        collection_name: str = 'documents'
    ) -> str:
        """Add document to vector store"""
        
        # Generate embedding using local model
        embeddings = await self.ollama_manager.generate_embeddings([content])
        
        # Create unique ID
        doc_id = str(uuid.uuid4())
        
        # Add timestamp
        metadata['timestamp'] = datetime.now().isoformat()
        metadata['content_length'] = len(content)
        
        # Store in ChromaDB
        self.collections[collection_name].add(
            documents=[content],
            embeddings=embeddings,
            metadatas=[metadata],
            ids=[doc_id]
        )
        
        logging.info(f"Added document {doc_id} to {collection_name}")
        return doc_id
    
    async def similarity_search(
        self,
        query: str,
        collection_name: str = 'documents',
        n_results: int = 5,
        filter_metadata: Optional[Dict] = None
    ) -> List[Dict[str, Any]]:
        """Search for similar documents"""
        
        # Generate query embedding
        query_embeddings = await self.ollama_manager.generate_embeddings([query])
        
        # Search in ChromaDB
        results = self.collections[collection_name].query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=filter_metadata
        )
        
        # Format results
        formatted_results = []
        for i in range(len(results['ids'][0])):
            formatted_results.append({
                'id': results['ids'][0][i],
                'content': results['documents'][0][i],
                'metadata': results['metadatas'][0][i],
                'similarity_score': 1 - results['distances'][0][i]  # Convert distance to similarity
            })
        
        return formatted_results
    
    async def get_relevant_context(
        self,
        query: str,
        max_tokens: int = 2000
    ) -> str:
        """Get relevant context for RAG"""
        
        # Search across all collections
        all_results = []
        
        for collection_name in self.collections.keys():
            try:
                results = await self.similarity_search(
                    query=query,
                    collection_name=collection_name,
                    n_results=3
                )
                
                for result in results:
                    result['collection'] = collection_name
                    all_results.append(result)
            except Exception as e:
                logging.warning(f"Search failed in {collection_name}: {e}")
        
        # Sort by similarity score
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)
        
        # Build context within token limit
        context_parts = []
        current_tokens = 0
        
        for result in all_results:
            content = result['content']
            # Rough token estimation (1 token ≈ 4 characters)
            content_tokens = len(content) // 4
            
            if current_tokens + content_tokens > max_tokens:
                break
            
            source_info = f"[Source: {result['collection']}]"
            context_parts.append(f"{source_info}\n{content}")
            current_tokens += content_tokens
        
        return "\n\n".join(context_parts)

# RAG Integration
class RAGGenerator:
    def __init__(self):
        self.vector_store = VectorStore()
        self.ollama_manager = OllamaManager()
    
    async def generate_with_context(
        self,
        query: str,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate response using RAG"""
        
        # Get relevant context
        context = await self.vector_store.get_relevant_context(query)
        
        # Build prompt with context
        if context:
            enhanced_prompt = f"""Context Information:
{context}

User Query: {query}

Please provide a comprehensive answer based on the context information provided. If the context doesn't contain relevant information, clearly state that and provide a general response."""
        else:
            enhanced_prompt = f"""User Query: {query}

No specific context information was found. Please provide a helpful general response."""
        
        # Generate response
        response = await self.ollama_manager.generate_response(
            prompt=enhanced_prompt,
            system_prompt=system_prompt
        )
        
        return response

# Usage example
async def demo_rag():
    rag_gen = RAGGenerator()
    
    # Add some sample documents
    await rag_gen.vector_store.add_document(
        content="AlexAI is a personal assistant that prioritizes user privacy by running all AI models locally. It uses Ollama for LLM hosting and ChromaDB for vector storage.",
        metadata={"type": "system_info", "category": "documentation"}
    )
    
    # Query with RAG
    response = await rag_gen.generate_with_context(
        "What is AlexAI and how does it protect privacy?"
    )
    
    print("RAG Response:", response)

if __name__ == "__main__":
    asyncio.run(demo_rag())
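
The similarity_search method also accepts a metadata filter, which the demo above doesn't exercise; ChromaDB applies it as a where clause. A quick sketch, reusing the "category" field set in demo_rag (the query text is made up):

# Example: restrict the search to documents tagged as documentation
async def demo_filtered_search():
    store = VectorStore()
    results = await store.similarity_search(
        query="How does AlexAI store embeddings?",
        collection_name="documents",
        n_results=3,
        filter_metadata={"category": "documentation"}
    )
    for result in results:
        print(f"{result['similarity_score']:.3f}  {result['content'][:80]}")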

Why this RAG implementation?