from fastapi import FastAPI, Request, Query, HTTPException
from fastapi.templating import Jinja2Templates
from fastapi import File, UploadFile
from fastapi.responses import FileResponse, Response
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json

app = FastAPI()

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Dimensionality of the vectors produced by paraphrase-MiniLM-L6-v2.
EMBEDDING_DIM = 384

# Exact (brute-force) L2 index holding one vector per embedded string.
# NOTE: only the embeddings are stored here; the original strings are not kept,
# so every "documents" payload below consists of reconstructed vectors.
index = faiss.IndexFlatL2(EMBEDDING_DIM)

templates = Jinja2Templates(directory=".")


class EmbedRequest(BaseModel):
    texts: list[str]


class SearchRequest(BaseModel):
    text: str
    n: int = 5


def _as_float32_matrix(embeddings) -> np.ndarray:
    """Return *embeddings* as the C-contiguous float32 array FAISS requires."""
    return np.ascontiguousarray(embeddings, dtype=np.float32)


@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` from the working directory."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the given strings and append their vectors to the index.

    Returns a message reporting how many strings were received and the new
    total number of vectors in the index.
    """
    new_documents = request.texts
    # An empty batch may not encode to a (0, dim) matrix, which index.add()
    # would reject, so skip the model and the index when there is nothing to add.
    if new_documents:
        index.add(_as_float32_matrix(model.encode(new_documents)))
    new_size = index.ntotal
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return up to ``n`` stored vectors nearest (L2) to the embedding of ``text``.

    The response has parallel lists ``distances``, ``indices`` and
    ``documents`` (the reconstructed embeddings of the hits). Fewer than ``n``
    hits are returned when the index holds fewer vectors. An empty index or a
    non-positive ``n`` yields empty lists.
    """
    # When k > ntotal, FAISS pads the missing neighbours with id -1. It also does
    # not accept k <= 0. Clamp k so that the search is always valid.
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    query = _as_float32_matrix(model.encode([request.text]))
    distances, indices = index.search(query, k)

    # Drop any -1 padding while keeping distances and ids aligned.
    hits = [(float(d), int(i)) for d, i in zip(distances[0], indices[0]) if i != -1]
    return {
        "distances": [d for d, _ in hits],
        "indices": [i for _, i in hits],
        # reconstruct_n(n0, ni) takes a contiguous id range, not arbitrary ids,
        # so each hit is rebuilt individually.
        "documents": [index.reconstruct(i).tolist() for _, i in hits],
    }


#########################
## database management ##
#########################


@app.get("/admin/database/length")
def get_database_length():
    """Return the number of vectors currently stored in the index."""
    return {"length": index.ntotal}


@app.post("/admin/database/reset")
def reset_database():
    """Remove every vector from the index."""
    index.reset()
    return {"message": "Database reset"}


@app.get("/admin/documents/download")
def download_documents():
    """Download every stored vector as a JSON list of float lists (``documents.json``)."""
    # The index keeps only embeddings, so the "documents" are the reconstructed vectors.
    documents = index.reconstruct_n(0, index.ntotal).tolist() if index.ntotal else []
    documents_json = json.dumps(documents)
    # The Content-Disposition header makes browsers download the JSON as a file.
    return Response(
        content=documents_json,
        media_type="application/json",
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )


@app.get("/admin/database/download")
def download_database():
    """Download the FAISS index, serialized in ``faiss.write_index`` format, as ``database.index``."""
    # Serialize in memory rather than writing a shared "database.index" file in the
    # working directory, which concurrent requests could overwrite mid-stream.
    payload = faiss.serialize_index(index).tobytes()
    return Response(
        content=payload,
        media_type="application/octet-stream",
        headers={"Content-Disposition": "attachment; filename=database.index"},
    )


@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory index with an uploaded, serialized FAISS index.

    The upload is expected in the format produced by ``/admin/database/download``.

    Raises:
        HTTPException: 400 if the bytes are not a FAISS index, or if the index
            dimension does not match the embedding model.
    """
    # Without this, the assignment below would only bind a local name.
    global index
    contents = file.file.read()
    # read_index_binary() expects a file *name* and loads binary (Hamming) indexes.
    # The uploaded payload is the raw bytes of a float index, so deserialize it instead.
    try:
        new_index = faiss.deserialize_index(np.frombuffer(contents, dtype=np.uint8))
    except RuntimeError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err
    # Every later /embed and /search call feeds EMBEDDING_DIM-wide vectors.
    if new_index.d != EMBEDDING_DIM:
        raise HTTPException(
            status_code=400,
            detail=f"Uploaded index has dimension {new_index.d}, expected {EMBEDDING_DIM}",
        )
    index = new_index
    return {"message": "Database uploaded"}