File size: 3,242 Bytes
e26d32e 5914320 daedc24 e9edc55 a0edacc 3ab82e8 e9edc55 3eec3b2 00a8910 3ab82e8 daedc24 3eec3b2 3f61915 5914320 a0edacc f238fcb a0edacc 5914320 3ab82e8 744d14e eb810c1 a0edacc 744d14e daedc24 744d14e 5914320 daedc24 eb810c1 f238fcb e9edc55 744d14e e9edc55 744d14e daedc24 744d14e daedc24 744d14e daedc24 744d14e daedc24 744d14e daedc24 744d14e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json

import faiss
import numpy as np
from fastapi import FastAPI, Request, Query
from fastapi import File, UploadFile
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
app = FastAPI()
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# Size the index from the model itself (384 for this MiniLM model) so that changing
# the model name above cannot leave the index with a mismatched dimensionality.
index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
# index.html is looked up relative to the current working directory.
templates = Jinja2Templates(directory=".")
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Texts to encode and append to the FAISS index.
    texts: list[str]
class SearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    # Query text whose embedding is compared against the stored vectors.
    text: str
    # Requested number of nearest neighbours.
    n: int = 5
@app.get("/")
def read_root(request: Request):
    """Render the ``index.html`` front-end page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Encode ``request.texts`` and append the resulting vectors to the FAISS index.

    Only the vectors are stored; the texts themselves are not kept.

    Returns:
        dict: ``{"message": ...}`` reporting how many strings were received and
        the index size after insertion.
    """
    new_documents = request.texts
    # Skip the model/FAISS round-trip for an empty request. There is nothing to
    # add, and an empty encode result does not have the 2-D (0, d) shape that
    # index.add requires, so passing it through would fail.
    if new_documents:
        # FAISS works on float32 matrices; asarray avoids a copy when the model
        # already returns float32.
        new_embeddings = np.asarray(model.encode(new_documents), dtype="float32")
        index.add(new_embeddings)
    new_size = index.ntotal
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
@app.post("/search")
def search_string(request: SearchRequest):
    """Find the stored vectors closest (L2) to the embedding of ``request.text``.

    At most ``request.n`` hits are returned. Fewer are returned when the index
    holds fewer vectors, and none when ``n <= 0`` or the index is empty.

    Returns:
        dict: parallel lists ``distances``, ``indices`` (ids in the index) and
        ``documents`` (the reconstructed vector for each hit).
    """
    # Never ask FAISS for more neighbours than the index holds: FAISS pads
    # missing results with id -1, which cannot be reconstructed.
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    query = np.asarray(model.encode([request.text]), dtype="float32")
    distances, indices = index.search(query, k)
    hit_ids = indices[0]
    # reconstruct_n(n0, ni) only copies a contiguous id range starting at n0,
    # so the (non-contiguous) hits are rebuilt one by one.
    found_documents = [index.reconstruct(int(i)).tolist() for i in hit_ids]
    return {
        "distances": distances[0].tolist(),
        "indices": hit_ids.tolist(),
        "documents": found_documents,
    }
#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many vectors the FAISS index currently holds."""
    total = index.ntotal
    return {"length": total}
@app.post("/admin/database/reset")
def reset_database():
    """Remove every vector from the FAISS index."""
    index.reset()
    payload = {"message": "Database reset"}
    return payload
@app.get("/admin/documents/download")
def download_documents():
    """Serve every stored vector as a downloadable ``documents.json`` file.

    Nothing in this module stores the original texts, so each "document" is the
    list of floats reconstructed from the index.

    Returns:
        Response: JSON array of vectors, sent as an attachment.
    """
    total = index.ntotal
    # An empty index has nothing to reconstruct; emit an empty JSON array.
    vectors = index.reconstruct_n(0, total).tolist() if total else []
    return Response(
        content=json.dumps(vectors),
        media_type="application/json",
        # Content-Disposition makes browsers save the body instead of showing it.
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )
@app.get("/admin/database/download")
def download_database():
    """Serve the current FAISS index as a downloadable ``database.index`` file.

    The index is serialized in memory instead of being written to a fixed
    ``database.index`` path in the working directory. With a shared on-disk file,
    a concurrent request could overwrite it while another response was still
    streaming it. The bytes are the same format ``faiss.write_index`` produces.

    Returns:
        Response: the serialized index, sent as an attachment.
    """
    serialized = faiss.serialize_index(index)  # numpy uint8 array
    return Response(
        content=serialized.tobytes(),
        media_type="application/octet-stream",
        headers={"Content-Disposition": "attachment; filename=database.index"},
    )
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory FAISS index with an uploaded, serialized index.

    The upload must be a float index serialized by ``faiss.write_index`` /
    ``faiss.serialize_index`` (e.g. a file from ``/admin/database/download``).
    Its dimensionality must match the current index.

    Raises:
        HTTPException: 400 if the file is empty, cannot be parsed as a FAISS
            index, or has the wrong dimensionality.
    """
    # Rebind the module-level index; without this the new index would only be
    # a local variable and the service would keep using the old one.
    global index

    contents = file.file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Copy into an owned, writable uint8 buffer before handing it to faiss.
    buffer = np.frombuffer(contents, dtype=np.uint8).copy()
    try:
        new_index = faiss.deserialize_index(buffer)
    except RuntimeError as err:  # faiss surfaces parse errors as RuntimeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err

    # Vectors of a different size could not be searched with this model's
    # embeddings, so reject the upload instead of breaking later searches.
    if new_index.d != index.d:
        raise HTTPException(
            status_code=400,
            detail=f"Uploaded index has dimension {new_index.d}, expected {index.d}",
        )

    # NOTE: the previous version also tried to refresh a `documents` list, but
    # no such list exists in this module; only vectors are kept, in the index.
    index = new_index
    return {"message": "Database uploaded"}
|