from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json

app = FastAPI()

# Previous model:
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# embedding_dimension = 384  # 384 is the dimensionality of the MiniLM model

# 1. Specify preferred dimensions
embedding_dimension = 512
# 2. Load the model, truncating its output to that dimensionality
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=embedding_dimension)

index = faiss.IndexFlatL2(embedding_dimension)
documents = []
templates = Jinja2Templates(directory=".")


class EmbedRequest(BaseModel):
    texts: list[str]


class SearchRequest(BaseModel):
    text: str
    n: int = 5


@app.get("/")
def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    new_documents = request.texts
    print(f"Start embedding of {len(new_documents)} docs")

    batch_size = 20
    # Split new_documents into batches of at most batch_size documents;
    # the slicing also covers the final, possibly shorter, batch, so the
    # remainder needs no separate handling
    batches = [new_documents[i:i + batch_size] for i in range(0, len(new_documents), batch_size)]

    # Embed each batch and collect the results
    new_embeddings = []
    for batch in batches:
        batch_embeddings = model.encode(batch)
        new_embeddings.extend(batch_embeddings)
        print(f"embedded {len(batch)} docs")

    index.add(np.array(new_embeddings))
    new_size = index.ntotal
    documents.extend(new_documents)
    print(f"End embedding {len(new_documents)} docs, new DB size: {new_size}")
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }


# Earlier, unbatched version kept for reference (not registered as a route)
def embed_strings_v0(request: EmbedRequest):
    new_documents = request.texts
    new_embeddings = model.encode(new_documents)
    index.add(np.array(new_embeddings))
    new_size = index.ntotal
    documents.extend(new_documents)
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
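# Example client call for /embed (a sketch, not part of the app): assumes the
# server runs locally on port 8000, e.g. via `uvicorn main:app` -- the module
# name "main" is an assumption. Requires the `requests` package.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/embed",
#       json={"texts": ["FAISS is a vector search library.", "FastAPI serves APIs."]},
#   )
#   print(resp.json()["message"])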
@app.post("/search")
def search_string(request: SearchRequest):
    embedding = model.encode([request.text])
    distances, indices = index.search(np.array(embedding), request.n)

    # Look up the documents associated with the returned indices
    found_documents = [documents[i] for i in indices[0]]

    return {
        "distances": distances[0].tolist(),
        "indices": indices[0].tolist(),
        "documents": found_documents,
    }


#########################
## database management ##
#########################

@app.get("/admin/database/length")
def get_database_length():
    return {"length": index.ntotal}


@app.post("/admin/database/reset")
def reset_database():
    global index
    global documents
    index = faiss.IndexFlatL2(embedding_dimension)
    documents = []
    return {"message": "Database reset"}


@app.get("/admin/documents/download")
def download_documents():
    # Serialize the documents list to a JSON string
    documents_json = json.dumps({"texts": documents})

    # Create a response with the JSON string as the content
    response = Response(content=documents_json, media_type="application/json")

    # Set the content disposition header to trigger a download
    response.headers["Content-Disposition"] = "attachment; filename=documents.json"

    return response


@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
    # Read and parse the uploaded JSON file
    contents = file.file.read()
    data = json.loads(contents)

    # Add the new documents to the documents list
    # (this does not embed them; use /admin/documents/embed for that)
    new_documents = data["texts"]
    documents.extend(new_documents)

    return {"message": f"{len(new_documents)} new documents uploaded"}


@app.post("/admin/documents/embed")
def embed_documents(file: UploadFile = File(...)):
    # Read and parse the uploaded JSON file
    contents = file.file.read()
    data = json.loads(contents)
    new_documents = data["texts"]

    # Encode the new documents and add them to the FAISS index
    new_embeddings = model.encode(new_documents)
    index.add(np.array(new_embeddings))

    # Keep the documents list in sync with the index
    documents.extend(new_documents)

    return {"message": f"{len(new_documents)} new documents uploaded and embedded"}


@app.get("/admin/database/download")
def download_database():
    # Save the FAISS index to a file
    faiss.write_index(index, "database.index")

    # Create a response with the index file as the content
    response = FileResponse("database.index", media_type="application/octet-stream")

    # Set the content disposition header to trigger a download
    response.headers["Content-Disposition"] = "attachment; filename=database.index"

    return response


@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    # Write the uploaded file to disk, since faiss.read_index expects a path
    with open(file.filename, "wb") as f:
        f.write(file.file.read())

    # Load the FAISS index from the saved file
    global index
    index = faiss.read_index(file.filename)

    return {"message": f"Database uploaded with {index.ntotal} embeddings"}
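# Example client calls for /search and the admin endpoints (a sketch under the
# same assumptions as above: local server on port 8000, `requests` installed):
#
#   import requests
#
#   # Nearest-neighbour search over the embedded documents
#   resp = requests.post(
#       "http://localhost:8000/search",
#       json={"text": "vector search", "n": 3},
#   )
#   print(resp.json()["documents"])
#
#   # Download the raw documents as JSON, then re-upload and re-embed them
#   docs = requests.get("http://localhost:8000/admin/documents/download")
#   files = {"file": ("documents.json", docs.content, "application/json")}
#   requests.post("http://localhost:8000/admin/documents/embed", files=files)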
# Earlier attempts, kept for reference (not registered as routes). Both rely
# on faiss.read_index_binary, which loads a binary index from a file path,
# not from a Python file object or raw bytes, so neither works as written.
def upload_database_1(file: UploadFile = File(...)):
    # Save the uploaded file to disk
    with open(file.filename, "wb") as f:
        f.write(file.file.read())

    # Re-open the file and try to load the FAISS index from the file object
    with open(file.filename, "rb") as f:
        global index
        index = faiss.read_index_binary(f)  # broken: expects a path, not a file object

    # Broken: reconstruct_n returns the stored vectors, not the original documents
    global documents
    documents = index.reconstruct_n(0, index.ntotal).tolist()

    return {"message": f"Database uploaded with {len(documents)} documents"}


def upload_database_0(file: UploadFile = File(...)):
    # Read the contents of the uploaded file
    contents = file.file.read()

    # Broken: read_index_binary expects a path, not raw bytes
    global index
    index = faiss.read_index_binary(contents)

    # Clear the existing documents list and add the new documents
    # global documents
    # documents = index.reconstruct_n(0, index.ntotal).tolist()

    return {"message": f"Database uploaded with {index.ntotal} embeddings"}
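# Minimal way to run the app directly (an assumption; the original may be
# started with `uvicorn <module>:app --reload` instead). Requires uvicorn.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)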