File size: 3,242 Bytes
e26d32e
5914320
daedc24
e9edc55
 
a0edacc
3ab82e8
 
 
e9edc55
3eec3b2
00a8910
3ab82e8
 
daedc24
3eec3b2
3f61915
5914320
a0edacc
 
 
f238fcb
 
 
a0edacc
5914320
 
 
3ab82e8
744d14e
eb810c1
a0edacc
744d14e
 
 
daedc24
744d14e
 
 
5914320
daedc24
eb810c1
f238fcb
 
 
e9edc55
744d14e
 
 
e9edc55
744d14e
 
 
 
 
 
 
daedc24
744d14e
daedc24
 
744d14e
daedc24
744d14e
 
 
daedc24
 
 
744d14e
daedc24
744d14e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json

import faiss
import numpy as np
from fastapi import FastAPI, File, HTTPException, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

app = FastAPI()
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# Size the index from the model itself so the two cannot drift apart if the
# model name above is changed (paraphrase-MiniLM-L6-v2 yields 384-dim vectors).
index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())

# Templates (e.g. index.html served at "/") are loaded from the directory ".",
# i.e. relative to the process's working directory.
templates = Jinja2Templates(directory=".")

class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    texts: list[str]  # strings to encode and append to the FAISS index

class SearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    text: str  # query string to embed and look up
    n: int = 5  # number of nearest neighbours requested
    
@app.get("/")
def read_root(request: Request):
    """Serve the landing page rendered from the ``index.html`` template."""
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed ``request.texts`` and append the vectors to the FAISS index.

    Only the embedding vectors are stored; the texts themselves are not kept.
    An empty ``texts`` list is accepted and leaves the index unchanged.

    Returns:
        A dict with a human-readable ``message`` reporting how many strings
        were added and the resulting index size.
    """
    new_documents = request.texts
    if new_documents:
        # faiss expects a 2-D float32 array; make the dtype explicit rather
        # than relying on the encoder's default output type. The guard above
        # avoids handing faiss a degenerate array for an empty batch.
        new_embeddings = np.asarray(model.encode(new_documents), dtype="float32")
        index.add(new_embeddings)
    new_size = index.ntotal
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return the stored vectors nearest to ``request.text``.

    The query is embedded with the module-level model and searched in the
    module-level FAISS index. At most ``request.n`` hits are returned; fewer
    when the index holds fewer vectors. An empty index, or ``n <= 0``, yields
    empty lists.

    The index stores only embeddings, so ``documents`` holds the reconstructed
    embedding vector of each hit, not the original text.

    Returns:
        Dict with parallel lists ``distances``, ``indices`` and ``documents``.
    """
    k = min(request.n, index.ntotal)
    if k <= 0:
        # faiss requires k > 0, and there is nothing to find anyway.
        return {"distances": [], "indices": [], "documents": []}

    embedding = np.asarray(model.encode([request.text]), dtype="float32")
    distances, indices = index.search(embedding, k)

    # faiss marks missing results with id -1; drop them while keeping the
    # three output lists aligned.
    hits = [
        (float(dist), int(idx))
        for dist, idx in zip(distances[0], indices[0])
        if idx >= 0
    ]
    # reconstruct_n takes a contiguous (start, count) range, not a list of
    # ids, so each hit is reconstructed individually.
    return {
        "distances": [dist for dist, _ in hits],
        "indices": [idx for _, idx in hits],
        "documents": [index.reconstruct(idx).tolist() for _, idx in hits],
    }

#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many vectors the FAISS index currently holds."""
    vector_count = index.ntotal
    return dict(length=vector_count)

@app.post("/admin/database/reset")
def reset_database():
    """Remove every vector from the FAISS index; the same index object is reused."""
    index.reset()
    confirmation = "Database reset"
    return {"message": confirmation}

@app.get("/admin/documents/download")
def download_documents():
    """Download every vector in the FAISS index as a JSON file attachment.

    The index holds only embeddings, so each "document" is a reconstructed
    vector (a list of floats), not original text. An empty index yields ``[]``.
    """
    total = index.ntotal
    # reconstruct_n(0, total) copies the whole index into a 2-D numpy array;
    # skip the faiss call entirely when there is nothing to reconstruct.
    documents = index.reconstruct_n(0, total).tolist() if total else []
    documents_json = json.dumps(documents)

    # "attachment" in Content-Disposition makes browsers save the body as a
    # file named documents.json instead of displaying it.
    return Response(
        content=documents_json,
        media_type="application/json",
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )

@app.get("/admin/database/download")
def download_database():
    """Download the FAISS index, serialized, as a ``database.index`` attachment.

    The bytes are in the same format ``faiss.write_index`` produces, so the
    upload endpoint and ``faiss.read_index`` can load them.
    """
    # Serialize in memory instead of writing a fixed file in the working
    # directory: sync endpoints run in a threadpool, and a concurrent request
    # could truncate/rewrite that shared file while it is still being streamed.
    payload = faiss.serialize_index(index).tobytes()
    return Response(
        content=payload,
        media_type="application/octet-stream",
        headers={"Content-Disposition": "attachment; filename=database.index"},
    )

@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory FAISS index with an uploaded serialized index.

    The upload must be a float FAISS index in ``faiss.write_index`` /
    ``faiss.serialize_index`` format (what ``/admin/database/download``
    returns) whose dimensionality matches the current index.

    Raises:
        HTTPException: 400 if the file cannot be parsed as a FAISS index or its
            dimensionality differs from the current index.
    """
    # Rebind the module-level index; without this the assignment below would
    # create a local and every other endpoint would keep the old index.
    global index

    contents = file.file.read()
    try:
        # read_index_binary expects a *filename* of a binary (IndexBinary)
        # index; deserialize_index parses in-memory bytes of a float index.
        new_index = faiss.deserialize_index(np.frombuffer(contents, dtype=np.uint8))
    except RuntimeError as err:  # faiss surfaces C++ errors as RuntimeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err

    if new_index.d != index.d:
        # Vectors from the embedding model could not be added or searched.
        raise HTTPException(
            status_code=400,
            detail=f"Uploaded index has dimension {new_index.d}, expected {index.d}",
        )

    index = new_index
    return {"message": "Database uploaded"}