"""RAG endpoints: upload documents into per-table indexes and query them.

Table metadata lives in a SQLite database under ``./data``; uploaded files
are stored in ``./data/<table_id>`` and the LlamaIndex storage for each
table is persisted to ``./lancedb/index/<table_id>``.
"""
import asyncio
import csv
import io
import json
import os
import re
import shutil
import sqlite3
import uuid
from contextlib import closing
from functools import cached_property, lru_cache
from io import BytesIO
from typing import Annotated, Any, Dict, List, Optional

import lancedb
import pandas as pd
from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel, Field

# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.schema import TextNode
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore

router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

DATA_DIR = "./data"
LANCEDB_URI = "./lancedb/dev"
INDEX_ROOT = "./lancedb/index"
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".csv", ".txt", ".md"}

# Table ids are interpolated into filesystem paths, so restrict them to a
# small character set with no path separators (and forbid "." / "..").
_TABLE_ID_RE = re.compile(r"[A-Za-z0-9_.-]+")


# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = LANCEDB_URI):
    """Return a LanceDB connection for *db_path*, cached per path."""
    return lancedb.connect(db_path)


def get_db():
    """Open a new SQLite connection to the table registry.

    Rows are returned as ``sqlite3.Row``. The caller owns the connection and
    must close it, e.g. ``with closing(get_db()) as db:``.
    """
    conn = sqlite3.connect(f"{DATA_DIR}/tables.db")
    conn.row_factory = sqlite3.Row
    return conn


def init_db():
    """Create the data directory and the registry tables if they are missing."""
    # sqlite3.connect does not create missing parent directories.
    os.makedirs(DATA_DIR, exist_ok=True)
    with closing(get_db()) as db:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id)
            )
        ''')
        db.commit()


def _index_dir(table_id: str) -> str:
    """Return the directory the LlamaIndex storage for *table_id* is persisted to."""
    return f"{INDEX_ROOT}/{table_id}"


def _validate_table_id(table_id: str) -> str:
    """Return *table_id* unchanged, or raise 400 if it is unsafe to use in a path."""
    if not _TABLE_ID_RE.fullmatch(table_id) or table_id in {".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid table_id")
    return table_id


def _safe_upload_name(file: UploadFile) -> str:
    """Return a directory-free filename for *file*, validating its extension.

    Raises:
        HTTPException(400): the name is empty/unusable or its extension is not
            in ``ALLOWED_EXTENSIONS``.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    # Clients may send paths (including Windows-style ones); keep only the last
    # component so the upload cannot be written outside the table directory.
    name = os.path.basename(file.filename.replace("\\", "/"))
    if name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename")
    if os.path.splitext(name)[1].lower() not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Unsupported file type")
    return name


def _build_and_persist_index(directory_path: str, table_id: str) -> None:
    """Embed every document in *directory_path* and persist the index for *table_id*.

    Synchronous and CPU-heavy; callers on the event loop run it via
    ``asyncio.to_thread``.
    """
    vector_store = LanceDBVectorStore(
        uri=LANCEDB_URI,
        table_name=table_id,
        mode="overwrite",
        query_type="hybrid"
    )
    documents = SimpleDirectoryReader(directory_path).load_data()
    # NOTE(review): LlamaIndex documents attaching a vector store through
    # storage_context=StorageContext.from_defaults(vector_store=...); a bare
    # `vector_store=` keyword may be ignored here. query_table loads whatever
    # is persisted below, so confirm whether LanceDB is actually populated.
    index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
    index.storage_context.persist(persist_dir=_index_dir(table_id))


def _retrieve(table_id: str, query: str, limit: int) -> List[Dict[str, Any]]:
    """Load the persisted index for *table_id* and return the top *limit* matches."""
    storage_context = StorageContext.from_defaults(persist_dir=_index_dir(table_id))
    index = load_index_from_storage(storage_context)
    retriever = index.as_retriever(similarity_top_k=limit)
    return [
        {'text': node.text, 'score': node.score}
        for node in retriever.retrieve(query)
    ]


# Pydantic models
class CreateTableResponse(BaseModel):
    table_id: str
    message: str
    status: str
    table_name: str


class QueryTableResponse(BaseModel):
    results: Dict[str, Any]
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files under a table directory, index them, and register the table.

    If *table_id* is omitted a UUID is generated. If it names a table this
    user already owns, the table is re-indexed (including previously uploaded
    files in its directory) and its registry row is updated, not duplicated.

    Raises:
        HTTPException(400): invalid filename, unsupported extension, or invalid table_id.
        HTTPException(403): table_id is registered to a different user.
        HTTPException(500): indexing or registry update failed.
    """
    safe_names = [_safe_upload_name(file) for file in files]

    table_id = _validate_table_id(table_id) if table_id else str(uuid.uuid4())

    with closing(get_db()) as db:
        existing = db.execute(
            'SELECT user_id, table_name FROM tables WHERE table_id = ?',
            (table_id,)
        ).fetchall()
    if any(row['user_id'] != user_id for row in existing):
        raise HTTPException(status_code=403, detail="table_id belongs to another user")
    if not table_name:
        table_name = (
            existing[0]['table_name'] if existing
            else f"knowledge-base-{str(uuid.uuid4())[:4]}"
        )

    directory_path = f"{DATA_DIR}/{table_id}"
    os.makedirs(directory_path, exist_ok=True)
    for file, name in zip(files, safe_names):
        with open(os.path.join(directory_path, name), "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    try:
        # Embedding is synchronous and slow; keep it off the event loop.
        await asyncio.to_thread(_build_and_persist_index, directory_path, table_id)

        with closing(get_db()) as db:
            if existing:
                db.execute(
                    'UPDATE tables SET table_name = ? WHERE table_id = ?',
                    (table_name, table_id)
                )
            else:
                db.execute(
                    'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                    (user_id, table_id, table_name)
                )
            for name in safe_names:
                # Re-uploading a file with the same name replaces its record.
                db.execute(
                    'DELETE FROM table_files WHERE table_id = ? AND filename = ?',
                    (table_id, name)
                )
                db.execute(
                    'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                    (table_id, name, f"{directory_path}/{name}")
                )
            db.commit()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return CreateTableResponse(
        table_id=table_id,
        message="Success",
        status="success",
        table_name=table_name
    )


@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Returns up to *limit* retrieved nodes as ``{'text', 'score'}`` dicts under
    ``results['data']``.

    Raises:
        HTTPException(400): invalid table_id or non-positive limit.
        HTTPException(500): loading the index or retrieval failed.
    """
    _validate_table_id(table_id)
    if limit is None or limit < 1:
        raise HTTPException(status_code=400, detail="limit must be a positive integer")
    # NOTE(review): user_id is required but not checked against table
    # ownership, so any caller can query any table; confirm this is intended.
    try:
        results = await asyncio.to_thread(_retrieve, table_id, query, limit)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e

    return QueryTableResponse(
        results={'data': results},
        total_results=len(results)
    )


@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List the user's tables, each with a ``files`` list of filename/file_path dicts."""
    with closing(get_db()) as db:
        table_rows = db.execute(
            'SELECT * FROM tables WHERE user_id = ? ORDER BY id',
            (user_id,)
        ).fetchall()
        file_rows = db.execute('''
            SELECT table_id, filename, file_path
            FROM table_files
            WHERE table_id IN (SELECT table_id FROM tables WHERE user_id = ?)
            ORDER BY id
        ''', (user_id,)).fetchall()

    # One entry per table_id (first registry row wins if duplicates exist).
    tables: Dict[str, Dict[str, Any]] = {}
    for row in table_rows:
        tables.setdefault(row['table_id'], {**dict(row), 'files': []})
    for row in file_rows:
        tables[row['table_id']]['files'].append(
            {'filename': row['filename'], 'file_path': row['file_path']}
        )
    return list(tables.values())


@router.delete("/delete_table/{table_id}")
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index, and registry rows.

    Raises:
        HTTPException(400): invalid table_id.
        HTTPException(404): no such table owned by *user_id*.
        HTTPException(500): any other failure while deleting.
    """
    _validate_table_id(table_id)
    try:
        with closing(get_db()) as db:
            # Verify user owns the table
            table = db.execute(
                'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
                (table_id, user_id)
            ).fetchone()
            if not table:
                raise HTTPException(status_code=404, detail="Table not found or unauthorized")

            # Delete files from filesystem
            for path in (f"{DATA_DIR}/{table_id}", _index_dir(table_id)):
                if os.path.exists(path):
                    shutil.rmtree(path)

            # Delete from database
            db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
            db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
            db.commit()
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Table deleted successfully"}


@router.get("/health")
async def health_check():
    return {"status": "healthy"}


@router.on_event("startup")
async def startup():
    """Initialise the registry and (re)build the built-in ``digiyatra`` table.

    Reads ``combined_digi_yatra.csv`` (header row skipped), indexes one node per
    row, and re-registers the table so restarts do not accumulate duplicate rows.
    """
    init_db()
    print("RAG Router started")

    table_name = "digiyatra"
    user_id = "digiyatra"

    # Create vector store and index
    vector_store = LanceDBVectorStore(
        uri=LANCEDB_URI,
        table_name=table_name,
        mode="overwrite",
        query_type="hybrid"
    )

    # Load CSV and create nodes (one per data row; the header is skipped)
    with open('combined_digi_yatra.csv', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)
        nodes = [TextNode(text=str(row), id_=str(uuid.uuid4())) for row in reader]

    # Create and persist index
    # NOTE(review): as in _build_and_persist_index, the `vector_store=` keyword
    # may not route data into LanceDB; confirm against LlamaIndex docs.
    index = VectorStoreIndex(nodes, vector_store=vector_store)
    index.storage_context.persist(persist_dir=_index_dir(table_name))

    # Store in SQLite, replacing rows left by previous startups
    with closing(get_db()) as db:
        db.execute('DELETE FROM table_files WHERE table_id = ?', (table_name,))
        db.execute('DELETE FROM tables WHERE table_id = ?', (table_name,))
        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
        )
        db.commit()


@router.on_event("shutdown")
async def shutdown():
    print("RAG Router shutdown")