Spaces:
Running
Running
File size: 6,422 Bytes
1a6d961 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil
# Router for all RAG endpoints; every path below is mounted under /rag.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)
# Configure global LlamaIndex settings.
# NOTE(review): this loads the FastEmbed model at import time, which may
# download weights on first run — confirm that's acceptable at startup.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# JSON registry mapping user_id -> list of table names owned by that user.
tables_file_path = './data/tables.json'
# Database connection dependency
# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for *db_path*, cached per path for reuse."""
    connection = lancedb.connect(db_path)
    return connection
# Pydantic models
class CreateTableResponse(BaseModel):
    """Response payload for POST /rag/create_table."""
    table_id: str  # the caller-supplied id, or a generated UUID4 string
    message: str   # human-readable outcome description
    status: str    # "success" on the happy path
class QueryTableResponse(BaseModel):
    """Response payload for POST /rag/query_table/{table_id}."""
    results: Dict[str, Any]  # {'data': [{'text': ..., 'score': ...}, ...]}
    total_results: int       # number of retrieved chunks in results['data']
@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
user_id: str,
files: List[UploadFile] = File(...),
table_id: Optional[str] = None
) -> CreateTableResponse:
"""Create a table and load embeddings from uploaded files using LlamaIndex."""
allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
for file in files:
if file.filename is None:
raise HTTPException(status_code=400, detail="File must have a valid name.")
file_extension = os.path.splitext(file.filename)[1].lower()
if file_extension not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
)
if table_id is None:
table_id = str(uuid.uuid4())
table_name = table_id #f"{user_id}__table__{table_id}"
# Create a directory for the uploaded files
directory_path = f"./data/{table_name}"
os.makedirs(directory_path, exist_ok=True)
# Save each uploaded file to the data directory
for file in files:
file_path = os.path.join(directory_path, file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Store user_id and table_name in a JSON file
try:
tables_file_path = './data/tables.json'
os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
# Load existing tables or create a new file if it doesn't exist
try:
with open(tables_file_path, 'r') as f:
tables = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
tables = {}
# Update the tables dictionary
if user_id not in tables:
tables[user_id] = []
if table_name not in tables[user_id]:
tables[user_id].append(table_name)
# Write the updated tables back to the JSON file
with open(tables_file_path, 'w') as f:
json.dump(tables, f)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")
try:
# Setup LanceDB vector store
vector_store = LanceDBVectorStore(
uri="./lancedb/dev",
table_name=table_name,
# mode="overwrite",
# query_type="vector"
)
# Load documents using SimpleDirectoryReader
documents = SimpleDirectoryReader(directory_path).load_data()
# Create the index
index = VectorStoreIndex.from_documents(
documents,
vector_store=vector_store
)
index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
return CreateTableResponse(
table_id=table_id,
message=f"Table created and documents indexed successfully",
status="success"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
table_id: str,
query: str,
user_id: str,
#db: Annotated[Any, Depends(get_db_connection)],
limit: Optional[int] = 10
) -> QueryTableResponse:
"""Query the database table using LlamaIndex."""
try:
table_name = table_id #f"{user_id}__table__{table_id}"
# load index and retriever
storage_context = StorageContext.from_defaults(persist_dir=f"./lancedb/index/{table_name}")
index = load_index_from_storage(storage_context)
retriever = index.as_retriever(similarity_top_k=limit)
# Get response
response = retriever.retrieve(query)
# Format results
results = [{
'text': node.text,
'score': node.score
} for node in response]
return QueryTableResponse(
results={'data': results},
total_results=len(results)
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
"""Get all tables for a user."""
tables_file_path = './data/tables.json'
try:
# Load existing tables from the JSON file
with open(tables_file_path, 'r') as f:
tables = json.load(f)
# Retrieve tables for the specified user
user_tables = tables.get(user_id, [])
return user_tables
except (FileNotFoundError, json.JSONDecodeError):
return [] # Return an empty list if the file doesn't exist or is invalid
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
@router.get("/health")
async def health_check():
return {"status": "healthy"} |