"""RAG endpoints: store uploaded documents as embedded chunks in per-user
LanceDB tables and expose retrieval, listing and deletion over them.

Every collection is stored in a table named ``"{user_id}_{collection_name}"``
(see :func:`get_user_collection`).
"""

import hashlib
import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Tuple

import lancedb
import pandas as pd
from fastapi import UploadFile, File, Form, HTTPException, APIRouter
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from pydantic import BaseModel

from utils import process_pdf_to_chunks

# Create router
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Initialize LanceDB and embedding model
db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(
    name="Snowflake/snowflake-arctic-embed-xs",
    device="cpu"
)


def get_user_collection(user_id: str, collection_name: str) -> str:
    """Generate user-specific collection name"""
    return f"{user_id}_{collection_name}"


def _sql_string_literal(value: str) -> str:
    """Quote *value* as a SQL string literal, doubling embedded single quotes."""
    return "'" + value.replace("'", "''") + "'"


def _file_extension(filename: Optional[str]) -> str:
    """Return the lower-cased text after the last '.' of *filename* ('' if None)."""
    return (filename or "").split('.')[-1].lower()


class DocumentChunk(LanceModel):
    """Row schema of a collection table: one row per document chunk."""

    # `text` is the embedding source; `vector` is filled in from it by `model`.
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()
    document_id: str
    chunk_index: int
    file_name: str
    file_type: str
    created_date: str
    collection_id: str
    user_id: str
    # JSON-serialised copy of the chunk's full metadata dict.
    metadata_json: str
    char_start: int
    char_end: int
    page_numbers: List[int]
    images: List[str]


class QueryInput(BaseModel):
    """Request body for the query endpoints."""

    collection_id: str
    query: str
    top_k: Optional[int] = 3
    user_id: str


class SearchResult(BaseModel):
    """One search hit: chunk text, its distance, and the decoded metadata."""

    text: str
    distance: float
    metadata: Dict  # Added metadata field


class SearchResponse(BaseModel):
    """Response body of ``/query_collection``."""

    results: List[SearchResult]


async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Process single file and return chunks with metadata.

    PDFs are delegated to ``process_pdf_to_chunks``; ``.txt`` files become a
    single chunk covering the whole text. Any other extension yields
    ``([], "")`` so the caller can decide how to treat it.

    Args:
        file: The uploaded file.
        collection_id: Unused here; kept for interface compatibility.
        user_id: Recorded in the chunk metadata of text files.

    Returns:
        ``(chunks, document_id)``.

    Raises:
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    content = await file.read()
    file_type = _file_extension(file.filename)

    chunks = []
    doc_id = ""

    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
    elif file_type == 'txt':
        # NOTE(review): a 4-hex-char id has only 65536 values, so unrelated
        # text files can collide; left unchanged because lengthening it would
        # change the ids handed out to clients - confirm before widening.
        doc_id = hashlib.sha256(content).hexdigest()[:4]
        text_content = content.decode('utf-8')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]

    return chunks, doc_id


@router.post("/upload_files")
async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    """Chunk the uploaded files and append them to the user's collection.

    The collection table is created on first use. When ``collection_name`` is
    omitted a random ``col_<8 hex>`` name is generated.

    Returns:
        A dict with the full table name (``collection_id``), the bare
        ``collection_name`` expected by the other endpoints, the number of
        stored chunks, a ``document_id -> file name`` map, and the names of
        files that produced no chunks (``skipped_files``).

    Raises:
        HTTPException: 400 if a file cannot be processed or nothing storable
            was uploaded; 500 on database errors.
    """
    try:
        bare_name = collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        collection_id = get_user_collection(user_id, bare_name)

        all_chunks = []
        doc_ids = {}
        skipped_files = []

        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)

                if not chunks:
                    # Unsupported type or empty document: nothing would be
                    # stored, so do not advertise a document id for it.
                    logging.warning("No chunks produced for file %s; skipping", file.filename)
                    skipped_files.append(file.filename)
                    continue

                file_type = _file_extension(file.filename)
                for chunk in chunks:
                    metadata = chunk["metadata"]
                    location = metadata["location"]
                    all_chunks.append({
                        "text": chunk["text"],
                        "document_id": metadata["document_id"],
                        "chunk_index": location["chunk_index"],
                        "file_name": metadata["file_name"],
                        "file_type": file_type,
                        "created_date": metadata["created_date"],
                        "collection_id": collection_id,
                        "user_id": user_id,
                        "metadata_json": json.dumps(metadata),
                        "char_start": location["char_start"],
                        "char_end": location["char_end"],
                        "page_numbers": location["pages"],
                        "images": metadata.get("images", [])
                    })

                doc_ids[doc_id] = file.filename

            except Exception as e:
                logging.error("Error processing file %s: %s", file.filename, e)
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                )

        if not all_chunks:
            # Fail before touching the database so no empty table is created.
            raise HTTPException(
                status_code=400,
                detail="No supported content found in uploaded files"
            )

        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error("Error opening table: %s", e)
            try:
                table = db.create_table(
                    collection_id,
                    schema=DocumentChunk,
                    mode="create"
                )
                # Create FTS index on the text column for hybrid search support
                # table.create_fts_index(
                #     field_names="text",
                #     replace=True,
                #     tokenizer_name="en_stem",  # Use English stemming
                #     lower_case=True,  # Convert text to lowercase
                #     remove_stop_words=True,  # Remove common words like "the", "is", "at"
                #     writer_heap_size=1024 * 1024 * 1024  # 1GB heap size
                # )
            except Exception as e:
                logging.error("Error creating table: %s", e)
                raise HTTPException(
                    status_code=500,
                    detail=f"Error creating database table: {str(e)}"
                )

        try:
            df = pd.DataFrame(all_chunks)
            table.add(data=df)
        except Exception as e:
            logging.error("Error adding data to table: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            )

        processed_count = len(files) - len(skipped_files)
        return {
            "message": f"Successfully processed {processed_count} files",
            "collection_id": collection_id,
            # Bare name: the other endpoints prepend the user id themselves.
            "collection_name": bare_name,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids,
            "skipped_files": skipped_files
        }

    except HTTPException:
        raise
    except Exception as e:
        logging.error("Unexpected error during file upload: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )


@router.get("/get_document/{collection_id}/{document_id}")
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    """Return every chunk of one document, ordered by ``chunk_index``.

    Raises:
        HTTPException: 404 if the collection or document does not exist;
            500 on any other failure.
    """
    try:
        table = db.open_table(get_user_collection(user_id, collection_id))
    except Exception as e:
        logging.error("Error opening table: %s", e)
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(e)}"
        )

    try:
        chunks = table.to_pandas()
        doc_chunks = chunks[
            (chunks['document_id'] == document_id) &
            (chunks['user_id'] == user_id)
        ].sort_values('chunk_index')

        if doc_chunks.empty:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )

        return {
            "document_id": document_id,
            "file_name": doc_chunks.iloc[0]['file_name'],
            "chunks": [
                {
                    "text": row['text'],
                    "metadata": json.loads(row['metadata_json'])
                }
                for _, row in doc_chunks.iterrows()
            ]
        }

    except HTTPException:
        raise
    except Exception as e:
        logging.error("Error retrieving document: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(e)}"
        )


@router.post("/query_collection", response_model=SearchResponse)
async def query_collection(input_data: QueryInput):
    """Search the user's collection and return the ``top_k`` closest chunks.

    Raises:
        HTTPException: 404 if the collection does not exist; 500 if the
            search or result decoding fails.
    """
    try:
        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)

        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error("Error opening table: %s", e)
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            )

        try:
            # user_id comes straight from the request body, so it is quoted
            # as a SQL literal rather than interpolated raw into the filter.
            user_filter = f"user_id = {_sql_string_literal(input_data.user_id)}"
            results = (
                table.search(input_data.query)
                .where(user_filter)
                .limit(input_data.top_k)
                .to_list()
            )
        except Exception as e:
            logging.error("Error searching collection: %s", e)
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            )

        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])

    except HTTPException:
        raise
    except Exception as e:
        logging.error("Unexpected error during query: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )


@router.get("/list_collections")
async def list_collections(user_id: str):
    """List the user's collections with the documents stored in each.

    Collections that fail to load are logged and left out of the result.

    Raises:
        HTTPException: 500 if the table names cannot be listed.
    """
    try:
        prefix = f"{user_id}_"
        # NOTE(review): some LanceDB versions page table_names(); confirm all
        # tables are returned by a single call in the deployed version.
        user_collections = [
            c for c in db.table_names()
            if c.startswith(prefix)
        ]

        # Get documents for each collection
        collections_info = []
        for collection_name in user_collections:
            try:
                df = db.open_table(collection_name).to_pandas()

                # The name prefix alone is ambiguous (user "a" also matches
                # tables of user "a_b"), so rows are checked as well.
                owned = df[df['user_id'] == user_id]
                if not df.empty and owned.empty:
                    continue

                # Group by document_id to get unique documents
                documents = owned.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()

                collections_info.append({
                    # Strip only the leading prefix; the collection's own
                    # name may contain the same text again.
                    "collection_id": collection_name[len(prefix):],
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error("Error processing collection %s: %s", collection_name, e)
                continue

        return {"collections": collections_info}
    except Exception as e:
        logging.error("Error listing collections: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/delete_collection/{collection_id}")
async def delete_collection(collection_id: str, user_id: str):
    """Drop the user's collection table.

    The table name is always derived from ``user_id``, so a caller can only
    ever address tables under its own prefix.

    Raises:
        HTTPException: 404 if the collection does not exist; 500 if
            dropping it fails.
    """
    try:
        full_collection_id = get_user_collection(user_id, collection_id)

        # Check if collection exists
        try:
            db.open_table(full_collection_id)
        except Exception as e:
            logging.error("Collection not found: %s", e)
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            )

        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error("Error deleting collection %s: %s", collection_id, e)
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            )

        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }

    except HTTPException:
        raise
    except Exception as e:
        logging.error("Unexpected error deleting collection %s: %s", collection_id, e)
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )


@router.post("/get_collection_files")
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Get list of files in the specified collection.

    Returns:
        The distinct file names joined with ``", "``. On failure an
        ``"Error getting files: ..."`` string is returned instead of raising.
    """
    try:
        # Get the full collection name
        collection_name = get_user_collection(user_id, collection_id)

        # Open the table and convert to pandas
        table = db.open_table(collection_name)
        df = table.to_pandas()
        logging.info(f"fetched chunks {str(df.head())}")

        # Only this user's rows, matching get_document / query_collection.
        df = df[df['user_id'] == user_id]

        # Get unique file names
        unique_files = df['file_name'].unique()

        # Join the file names into a string
        return ", ".join(unique_files)

    except Exception as e:
        logging.error("Error getting collection files: %s", e)
        return f"Error getting files: {str(e)}"


@router.post("/query_collection_tool")
async def query_collection_tool(input_data: QueryInput):
    """Run :func:`query_collection` and return a trimmed, stringified result.

    Each hit keeps only its text, distance, document id and chunk index.

    Raises:
        HTTPException: Whatever ``query_collection`` raised (e.g. 404 for a
            missing collection), or 500 on any other failure.
    """
    try:
        response = await query_collection(input_data)

        # Access response directly since it's a Pydantic model
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            for r in response.results
        ]

        return str(results)

    except HTTPException:
        # Keep the original status code instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logging.error("Unexpected error during query: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )