File size: 6,422 Bytes
1a6d961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv

# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil

# API router: every endpoint in this module is mounted under the /rag prefix.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Configure global LlamaIndex settings.
# Process-wide embedding model used by every index built or queried in this
# module (FastEmbed wrapper around the BAAI/bge-small-en-v1.5 model).
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Path of the JSON registry mapping user_id -> list of table names.
tables_file_path = './data/tables.json'

# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for *db_path*, cached per path.

    Intended as a FastAPI dependency; ``lru_cache`` ensures each distinct
    database path yields a single shared connection object.
    """
    connection = lancedb.connect(db_path)
    return connection

# Pydantic models
class CreateTableResponse(BaseModel):
    """Response body for the /rag/create_table endpoint."""

    # Identifier of the created table (caller-supplied or generated UUID).
    table_id: str
    # Human-readable outcome description.
    message: str
    # Outcome flag; the endpoint sets "success" (failures raise HTTPException).
    status: str

class QueryTableResponse(BaseModel):
    """Response body for the /rag/query_table endpoint."""

    # Wrapper dict; the endpoint returns {'data': [{'text': ..., 'score': ...}, ...]}.
    results: Dict[str, Any]
    # Number of retrieved nodes in results['data'].
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
) -> CreateTableResponse:
    """Create a table and load embeddings from uploaded files using LlamaIndex.

    Saves the uploads under ``./data/<table_name>``, records ownership in the
    JSON registry, embeds the documents into a LanceDB-backed vector index,
    and persists the index metadata under ``./lancedb/index/<table_name>``.

    Args:
        user_id: Owner of the table; used as the registry key.
        files: Uploaded documents (.pdf, .docx, .csv, .txt, .md only).
        table_id: Optional explicit table id; a UUID4 is generated if omitted.

    Returns:
        CreateTableResponse with the (possibly generated) ``table_id``.

    Raises:
        HTTPException: 400 for missing names / unsupported extensions,
            500 for registry or indexing failures.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400, 
                detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
            )

    if table_id is None:
        table_id = str(uuid.uuid4())
    table_name = table_id  # previously f"{user_id}__table__{table_id}"

    # Create a directory for the uploaded files.
    directory_path = f"./data/{table_name}"
    os.makedirs(directory_path, exist_ok=True)

    # Save each uploaded file to the data directory.
    # basename() strips any client-supplied directory components so a crafted
    # filename like "../../x" cannot escape the table's directory.
    for file in files:
        file_path = os.path.join(directory_path, os.path.basename(file.filename))
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    # Record user_id -> table_name ownership in the JSON registry
    # (module-level tables_file_path; the old local shadow duplicated it).
    try:
        os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
        try:
            with open(tables_file_path, 'r') as f:
                tables = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # First table ever, or a corrupt registry: start fresh.
            tables = {}

        tables.setdefault(user_id, [])
        if table_name not in tables[user_id]:
            tables[user_id].append(table_name)

        with open(tables_file_path, 'w') as f:
            json.dump(tables, f)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

    try:
        # Setup LanceDB vector store.
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
        )

        # BUG FIX: VectorStoreIndex.from_documents() does not accept a bare
        # ``vector_store=`` keyword — it was silently ignored and the index
        # fell back to the default in-memory store, so the LanceDB table was
        # never populated. The store must be wired in via a StorageContext.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        # Load documents using SimpleDirectoryReader.
        documents = SimpleDirectoryReader(directory_path).load_data()

        # Create the index on top of the LanceDB store and persist its metadata.
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context
        )
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")

@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query a previously created table using LlamaIndex retrieval.

    Args:
        table_id: Table to query (same value returned by create_table).
        query: Natural-language query string to embed and retrieve against.
        user_id: Caller identity; currently unused but kept for API stability.
        limit: Maximum number of nodes to retrieve (similarity_top_k).

    Returns:
        QueryTableResponse with ``results['data']`` as a list of
        ``{'text', 'score'}`` dicts and ``total_results`` as its length.

    Raises:
        HTTPException: 404 if the table's index directory does not exist,
            500 for any retrieval failure.
    """
    table_name = table_id  # previously f"{user_id}__table__{table_id}"
    persist_dir = f"./lancedb/index/{table_name}"

    # BUG FIX: an unknown table used to surface as a generic 500 "Query
    # failed"; report it explicitly as 404 before attempting to load.
    if not os.path.isdir(persist_dir):
        raise HTTPException(status_code=404, detail=f"Table {table_id} not found")

    try:
        # Load the persisted index and build a retriever over it.
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)

        # Retrieve the top-k nodes for the query.
        response = retriever.retrieve(query)

        # Format results.
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")

@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """Return the list of table names registered for *user_id*.

    Reads the JSON registry written by create_table; a missing or corrupt
    registry file is treated as "no tables" rather than an error.
    """
    registry_path = './data/tables.json'
    try:
        with open(registry_path, 'r') as registry_file:
            registry = json.load(registry_file)
        # Unknown users simply have no tables yet.
        return registry.get(user_id, [])
    except (FileNotFoundError, json.JSONDecodeError):
        # No registry on disk (or unreadable JSON) — nothing to list.
        return []
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")

@router.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload