from fastapi import APIRouter, HTTPException, UploadFile, File
import lancedb
from functools import lru_cache
from pydantic import BaseModel
from typing import Optional, Dict, List, Any
import uuid
import csv
import json
import os
import shutil

# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.schema import TextNode
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding

router = APIRouter(prefix="/rag", tags=["rag"])

# Configure the global LlamaIndex embedding model
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

# On-disk JSON registry mapping each user to the tables they have created
tables_file_path = "./data/tables.json"


# Database connection dependency (cached so repeated calls reuse one connection)
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    return lancedb.connect(db_path)


# Pydantic response models
class CreateTableResponse(BaseModel):
    table_id: str
    message: str
    status: str
    table_name: str


class QueryTableResponse(BaseModel):
    results: Dict[str, Any]
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None,
) -> CreateTableResponse:
    """Create a table and load embeddings from uploaded files using LlamaIndex."""
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}

    # Validate every upload before touching the filesystem
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"File type {file_extension} is not allowed. "
                    f"Supported file types are: {', '.join(allowed_extensions)}."
                ),
            )
    # Default identifiers when the caller did not provide them
    if table_id is None:
        table_id = str(uuid.uuid4())
    if not table_name:
        table_name = f"knowledge-base-{str(uuid.uuid4())[:4]}"

    # Create a directory for the uploaded files
    directory_path = f"./data/{table_id}"
    os.makedirs(directory_path, exist_ok=True)

    # Save each uploaded file to the data directory; basename() guards against
    # path components smuggled into the client-supplied filename
    for file in files:
        file_path = os.path.join(directory_path, os.path.basename(file.filename))
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    try:
        # Set up the LanceDB vector store
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid",
        )

        # Load documents using SimpleDirectoryReader
        documents = SimpleDirectoryReader(directory_path).load_data()

        # Build the index on top of the LanceDB store. The store must be passed
        # through a StorageContext; a bare `vector_store=` keyword is silently
        # ignored by from_documents(), so the embeddings would otherwise land in
        # the default in-memory store.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        # Record the new table in the per-user JSON registry
        try:
            os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)

            # Load existing tables, or start fresh if the file is missing or invalid
            try:
                with open(tables_file_path, "r") as f:
                    tables = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                tables = {}

            # Append the new table unless it is already registered
            if user_id not in tables:
                tables[user_id] = []
            if table_id not in [table["table_id"] for table in tables[user_id]]:
                tables[user_id].append({"table_id": table_id, "table_name": table_name})

            # Write the updated registry back to disk
            with open(tables_file_path, "w") as f:
                json.dump(tables, f)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success",
            table_name=table_name,
        )
    except HTTPException:
        raise  # don't re-wrap the registry error above as a generic 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")


@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: int = 10,
) -> QueryTableResponse:
    """Query a table's persisted index using LlamaIndex."""
    try:
        # Load the persisted index and build a retriever over it
        storage_context = StorageContext.from_defaults(persist_dir=f"./lancedb/index/{table_id}")
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)

        # Retrieve the most similar nodes for the query
        response = retriever.retrieve(query)

        # Format results
        results = [{"text": node.text, "score": node.score} for node in response]

        return QueryTableResponse(results={"data": results}, total_results=len(results))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")


@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """Get all tables registered for a user."""
    try:
        # Load the registry and return the entries for this user
        with open(tables_file_path, "r") as f:
            tables = json.load(f)
        return tables.get(user_id, [])
    except (FileNotFoundError, json.JSONDecodeError):
        return []  # No registry yet, or the file is invalid
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")


@router.get("/health")
async def health_check():
    return {"status": "healthy"}
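# The endpoints above coordinate through the registry at ./data/tables.json.
# A sketch of its expected layout, inferred from the read/write code above
# (IDs and user names are illustrative, not real data):
#
#     {
#         "demo-user": [
#             {"table_id": "3f2a9c10-...", "table_name": "knowledge-base-3f2a"}
#         ],
#         "digiyatra": [
#             {"table_id": "digiyatra", "table_name": "digiyatra"}
#         ]
#     }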
@router.on_event("startup")
async def startup():
    print("RAG Router started")

    table_name = "digiyatra"
    csv_path = "combined_digi_yatra.csv"

    # Skip seeding when the bundled CSV is not present, instead of failing startup
    if not os.path.exists(csv_path):
        print(f"{csv_path} not found; skipping DigiYatra seed index")
        return

    vector_store = LanceDBVectorStore(
        uri="./lancedb/dev",
        table_name=table_name,
        mode="overwrite",
        query_type="hybrid",
    )

    # Load the DigiYatra CSV and create a node for each row
    nodes = []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        data = list(reader)
    for row in data[1:]:  # skip the header row
        nodes.append(TextNode(text=str(row), id_=str(uuid.uuid4())))

    # As in create_embedding_table, the store must go through a StorageContext
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(nodes, storage_context=storage_context)
    index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

    # Register the seed table without clobbering tables created earlier
    try:
        with open(tables_file_path, "r") as f:
            tables = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        tables = {}
    tables["digiyatra"] = [{"table_id": table_name, "table_name": table_name}]
    os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
    with open(tables_file_path, "w") as f:
        json.dump(tables, f)


@router.on_event("shutdown")
async def shutdown():
    print("RAG Router shutdown")
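# A minimal usage sketch, not part of the router itself: mount the router on a
# FastAPI app and exercise the endpoints in-process with Starlette's TestClient.
# The sample file contents, user_id, and query are illustrative assumptions;
# running this also downloads the FastEmbed model on first use.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(router)

    # Using the client as a context manager runs the startup/shutdown handlers
    with TestClient(app) as client:
        # Upload a small text file and index it
        resp = client.post(
            "/rag/create_table",
            params={"user_id": "demo-user"},
            files=[("files", ("notes.txt", b"LanceDB stores vector embeddings.", "text/plain"))],
        )
        print(resp.json())
        table_id = resp.json()["table_id"]

        # Retrieve the most relevant chunks for a question
        resp = client.post(
            f"/rag/query_table/{table_id}",
            params={"query": "What does LanceDB store?", "user_id": "demo-user"},
        )
        print(resp.json())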