# Standard library imports
import csv
import io
import json
import os
import shutil
import uuid
from functools import cached_property, lru_cache
from io import BytesIO
from typing import Annotated, Any, Dict, List, Optional

# Third-party imports
import lancedb
import pandas as pd
from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel, Field

# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
# Router for all RAG-related endpoints
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

# JSON file mapping each user_id to the tables that user owns
tables_file_path = './data/tables.json'
# Database connection dependency
def get_db_connection(db_path: str = "./lancedb/dev"):
    return lancedb.connect(db_path)

# Pydantic models
class CreateTableResponse(BaseModel):
    table_id: str
    message: str
    status: str

class QueryTableResponse(BaseModel):
    results: Dict[str, Any]
    total_results: int
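# NOTE (assumption): the original extract defines `router` but no route
# decorators survived. A plausible registration for this handler is sketched
# below; the path and HTTP method are placeholders, not confirmed by the source.
@router.post("/create-table", response_model=CreateTableResponse)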
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None
) -> CreateTableResponse:
    """Create a table and load embeddings from uploaded files using LlamaIndex."""
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
            )

    if table_id is None:
        table_id = str(uuid.uuid4())
    table_name = table_id  # f"{user_id}__table__{table_id}"

    # Create a directory for the uploaded files
    directory_path = f"./data/{table_name}"
    os.makedirs(directory_path, exist_ok=True)

    # Save each uploaded file to the data directory
    for file in files:
        file_path = os.path.join(directory_path, file.filename)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    # Record the user_id -> table_name mapping in the JSON file
    try:
        os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
        # Load existing tables, or start fresh if the file is missing or invalid
        try:
            with open(tables_file_path, 'r') as f:
                tables = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            tables = {}
        # Update the tables dictionary
        if user_id not in tables:
            tables[user_id] = []
        if table_name not in tables[user_id]:
            tables[user_id].append(table_name)
        # Write the updated tables back to the JSON file
        with open(tables_file_path, 'w') as f:
            json.dump(tables, f)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

    try:
        # Set up the LanceDB vector store
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            # mode="overwrite",
            # query_type="vector"
        )
        # Load documents using SimpleDirectoryReader
        documents = SimpleDirectoryReader(directory_path).load_data()
        # Build the index on top of the LanceDB store. The store must be
        # supplied via a StorageContext so that from_documents actually writes
        # the embeddings into LanceDB.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context
        )
        # Persist the index metadata (docstore / index store) alongside the vectors
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    # db: Annotated[Any, Depends(get_db_connection)],
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex."""
    try:
        table_name = table_id  # f"{user_id}__table__{table_id}"
        # Re-attach the LanceDB vector store and load the persisted index
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name
        )
        storage_context = StorageContext.from_defaults(
            persist_dir=f"./lancedb/index/{table_name}",
            vector_store=vector_store
        )
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)
        # Retrieve the most similar nodes for the query
        response = retriever.retrieve(query)
        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]
        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
async def get_tables(user_id: str):
    """Get all tables for a user."""
    try:
        # Load existing tables from the JSON file
        with open(tables_file_path, 'r') as f:
            tables = json.load(f)
        # Retrieve tables for the specified user
        user_tables = tables.get(user_id, [])
        return user_tables
    except (FileNotFoundError, json.JSONDecodeError):
        return []  # Return an empty list if the file doesn't exist or is invalid
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
async def health_check():
    return {"status": "healthy"}