# rag_chat_with_analytics / rag_routerv2.py
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil
# FastAPI sub-router mounted under /rag; all endpoints share the "rag" tag.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)
# Configure global LlamaIndex settings
# NOTE(review): this runs at module import; FastEmbed may download/load the
# BGE model the first time, which can slow startup — confirm in deployment.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Path of the JSON registry mapping user_id -> list of table names.
tables_file_path = './data/tables.json'
# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for *db_path*, cached per distinct path."""
    connection = lancedb.connect(db_path)
    return connection
# Pydantic models
class CreateTableResponse(BaseModel):
    """Response body for POST /rag/create_table."""
    table_id: str  # id of the created table (caller-supplied or generated UUID4)
    message: str   # human-readable outcome description
    status: str    # "success" on completion; failures raise HTTPException instead
class QueryTableResponse(BaseModel):
    """Response body for POST /rag/query_table/{table_id}."""
    results: Dict[str, Any]  # {'data': [{'text': ..., 'score': ...}, ...]}
    total_results: int       # number of retrieved nodes in results['data']
@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None
) -> CreateTableResponse:
    """Create a table and load embeddings from uploaded files using LlamaIndex.

    Args:
        user_id: Owner of the table; used to group table names in tables.json.
        files: Uploaded documents (.pdf, .docx, .csv, .txt, .md).
        table_id: Optional explicit table id; a UUID4 is generated when omitted.

    Returns:
        CreateTableResponse carrying the table id and a success message.

    Raises:
        HTTPException: 400 for missing/unsupported filenames; 500 when the
            registry update or the indexing step fails.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}

    # Validate every file up front so nothing is written for a rejected batch.
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
            )

    if table_id is None:
        table_id = str(uuid.uuid4())
    table_name = table_id  # NOTE: previously f"{user_id}__table__{table_id}"

    # Save each uploaded file under ./data/<table_name>/
    directory_path = f"./data/{table_name}"
    os.makedirs(directory_path, exist_ok=True)
    for file in files:
        # basename() strips any client-supplied directory components,
        # preventing path traversal outside directory_path.
        safe_name = os.path.basename(file.filename)
        file_path = os.path.join(directory_path, safe_name)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    # Record ownership in the JSON registry: user_id -> [table_name, ...]
    try:
        os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
        try:
            with open(tables_file_path, 'r') as f:
                tables = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            tables = {}  # first run, or a corrupted registry: start fresh
        tables.setdefault(user_id, [])
        if table_name not in tables[user_id]:
            tables[user_id].append(table_name)
        with open(tables_file_path, 'w') as f:
            json.dump(tables, f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

    try:
        # Setup LanceDB vector store for this table.
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
        )
        # from_documents() must receive the store via a StorageContext;
        # a bare vector_store= kwarg is not picked up by the index builder.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Load documents using SimpleDirectoryReader
        documents = SimpleDirectoryReader(directory_path).load_data()
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context
        )
        # Persist index metadata so query_table can reload it later.
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query a previously indexed table using LlamaIndex.

    Args:
        table_id: Id returned by /create_table (doubles as the table name).
        query: Natural-language query text.
        user_id: Accepted for API symmetry; not used for lookup here.
        limit: Maximum number of results; an explicit None falls back to 10.

    Returns:
        QueryTableResponse with matched text chunks and similarity scores.

    Raises:
        HTTPException: 404 when the table's index directory does not exist,
            500 for any other retrieval failure.
    """
    # Guard against an explicit limit=None, which would break similarity_top_k.
    top_k = limit if limit is not None else 10
    table_name = table_id  # NOTE: previously f"{user_id}__table__{table_id}"
    persist_dir = f"./lancedb/index/{table_name}"
    # Missing index is a client error (unknown table), not a server fault.
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail=f"Table {table_id} not found")
    try:
        # Reload the persisted index and build a top-k retriever.
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=top_k)
        # Get response
        response = retriever.retrieve(query)
        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]
        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """Return the list of table names registered for *user_id*.

    Reads the module-level tables.json registry; an absent or corrupt file
    is treated as "no tables" rather than an error.

    Raises:
        HTTPException: 500 for unexpected I/O failures.
    """
    # Use the module-level registry path directly (the old local
    # re-assignment shadowed the same constant).
    try:
        with open(tables_file_path, 'r') as f:
            tables = json.load(f)
        # Retrieve tables for the specified user
        return tables.get(user_id, [])
    except (FileNotFoundError, json.JSONDecodeError):
        return []  # Return an empty list if the file doesn't exist or is invalid
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
@router.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload