from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import json
import os
import logging
from txtai.embeddings import Embeddings

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS
# NOTE(review): the CORS spec forbids a wildcard origin together with
# credentials; Starlette's CORSMiddleware works around this by reflecting the
# caller's Origin, which effectively lets any site make credentialed requests.
# Nothing visible here uses cookies/auth, so consider allow_credentials=False
# or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# One process-wide Embeddings instance. Every request mutates it: indexing
# rebuilds it in memory, and loading replaces it with an index read from disk.
embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})

# Directory under which each index gets its own sub-folder. Overridable via
# the INDEX_ROOT environment variable; defaults to the original location.
INDEX_ROOT = os.environ.get("INDEX_ROOT", "/app/indexes")


# Request body for building an index from a list of raw text documents.
class DocumentRequest(BaseModel):
    index_id: str
    documents: List[str]


# Request body for searching a previously saved index.
class QueryRequest(BaseModel):
    index_id: str
    query: str
    num_results: int


def _index_folder(index_id: str) -> str:
    """Return the on-disk folder for ``index_id``, confined to INDEX_ROOT.

    ``index_id`` is client-supplied, so it is resolved and checked to make
    sure it cannot escape INDEX_ROOT. That rules out ``..`` segments, absolute
    paths, and symlinks pointing elsewhere. It also cannot name the root
    itself.

    Raises:
        HTTPException: 400 if ``index_id`` is empty, contains a NUL byte, or
            resolves outside (or exactly to) INDEX_ROOT.
    """
    # A NUL byte would make the os.path calls below raise ValueError.
    if not index_id or "\x00" in index_id:
        raise HTTPException(status_code=400, detail="Invalid index_id")
    root = os.path.realpath(INDEX_ROOT)
    folder = os.path.realpath(os.path.join(root, index_id))
    if folder == root or os.path.commonpath([root, folder]) != root:
        raise HTTPException(status_code=400, detail="Invalid index_id")
    return folder


def save_embeddings(index_id: str, document_list: List[str]) -> None:
    """Persist the global ``embeddings`` index and its documents for ``index_id``.

    Writes the txtai index to ``<folder>/embeddings`` via ``Embeddings.save``
    and the documents, as JSON, to ``<folder>/document_list.json``.
    ``<folder>`` is the validated sub-folder of INDEX_ROOT for ``index_id``.

    Args:
        index_id: Client-supplied index name.
        document_list: Documents to store alongside the index. Position ``i``
            is looked up later by the search result id, so order matters.

    Raises:
        HTTPException: 400 for an invalid ``index_id``; 500 if saving fails.
    """
    # Validate outside the try so a 400 is not re-wrapped as a 500.
    folder_path = _index_folder(index_id)
    try:
        os.makedirs(folder_path, exist_ok=True)
        # Save embeddings
        embeddings.save(os.path.join(folder_path, "embeddings"))
        # Save document_list
        with open(
            os.path.join(folder_path, "document_list.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(document_list, f)
    except Exception as e:
        logger.exception("Error saving embeddings for index_id %s: %s", index_id, e)
        raise HTTPException(
            status_code=500, detail=f"Error saving embeddings: {str(e)}"
        ) from e
    logger.info("Embeddings and document list saved for index_id: %s", index_id)


def load_embeddings(index_id: str) -> List[str]:
    """Load the saved index for ``index_id`` into the global ``embeddings``.

    Side effect: replaces whatever index the shared ``embeddings`` instance
    currently holds.

    Args:
        index_id: Client-supplied index name.

    Returns:
        The document list saved with the index, in its original order.

    Raises:
        HTTPException: 400 for an invalid ``index_id``; 404 if no index folder
            exists; 500 if loading fails.
    """
    folder_path = _index_folder(index_id)
    # Checked outside the try so the 404 reaches the client instead of being
    # converted into a 500 by the generic handler below.
    if not os.path.isdir(folder_path):
        logger.error("Index not found for index_id: %s", index_id)
        raise HTTPException(status_code=404, detail="Index not found")
    try:
        # Load embeddings
        embeddings.load(os.path.join(folder_path, "embeddings"))
        # Load document_list
        with open(
            os.path.join(folder_path, "document_list.json"), "r", encoding="utf-8"
        ) as f:
            document_list = json.load(f)
    except Exception as e:
        logger.exception("Error loading embeddings for index_id %s: %s", index_id, e)
        raise HTTPException(
            status_code=500, detail=f"Error loading embeddings: {str(e)}"
        ) from e
    logger.info("Embeddings and document list loaded for index_id: %s", index_id)
    return document_list
# NOTE: both handlers are `async def` and never await, so FastAPI runs each one
# to completion on the event loop, one request at a time. That serialisation is
# what keeps the shared `embeddings` object from being used by two requests at
# once. The cost is that the loop is blocked during indexing/search. Converting
# them to plain `def` would run them in a threadpool and would then require a
# lock around `embeddings`.


@app.post("/create_index/")
async def create_index(request: DocumentRequest):
    """Build an index from ``request.documents`` and persist it as ``request.index_id``.

    Each document is indexed with its list position as its id. ``query_index``
    later maps search results back to text through that id.

    Raises:
        HTTPException: propagated unchanged from ``save_embeddings`` (e.g. 400
            for an invalid index_id); 500 for any other failure.
    """
    try:
        document_list = [(i, text, None) for i, text in enumerate(request.documents)]
        embeddings.index(document_list)
        save_embeddings(request.index_id, request.documents)  # Save the original documents
    except HTTPException:
        # Already carries the right status/detail; don't re-wrap it as a 500.
        raise
    except Exception as e:
        logger.exception("Error creating index: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error creating index: {str(e)}"
        ) from e
    logger.info("Index created successfully for index_id: %s", request.index_id)
    return {"message": "Index created successfully"}


@app.post("/query_index/")
async def query_index(request: QueryRequest):
    """Search the saved index ``request.index_id`` and return the matching texts.

    Returns:
        ``{"queried_texts": [...]}`` with at most ``request.num_results`` texts,
        in the order the search returned them.

    Raises:
        HTTPException: 400 if ``num_results`` < 1; status from
            ``load_embeddings`` (400/404/500) propagated unchanged; 500 for any
            other failure.
    """
    if request.num_results < 1:
        raise HTTPException(status_code=400, detail="num_results must be at least 1")
    try:
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        # The first element of each result is used as the document id, i.e. the
        # list position assigned at index time. Presumably txtai yields
        # (id, score) tuples here because no content store is configured — confirm.
        queried_texts = [document_list[idx[0]] for idx in results]
    except HTTPException:
        # Keep 404 "Index not found" etc. instead of turning them into 500s.
        raise
    except Exception as e:
        logger.exception("Error querying index: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Error querying index: {str(e)}"
        ) from e
    logger.info("Query executed successfully for index_id: %s", request.index_id)
    return {"queried_texts": queried_texts}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)