File size: 1,556 Bytes
adad4ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# api.py

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Tuple

from prompt_search_engine import PromptSearchEngine
from vectorizer import Vectorizer
from datasets import load_dataset

# Define the request and response models
class QueryRequest(BaseModel):
    """JSON body accepted by the POST /search endpoint.

    Fields:
        query: Free-text prompt to compare against the indexed prompts.
        n: Number of most-similar prompts to return; 5 when omitted.
    """

    query: str
    n: int = 5

class QueryResponse(BaseModel):
    """Schema for a /search reply: a list of (score, prompt) pairs.

    NOTE(review): the /search handler in this module builds each result as an
    object with "score" and "prompt" keys rather than as a 2-tuple, and it does
    not declare this model as its response_model. Confirm which shape clients
    expect before wiring the two together.
    """

    results: List[Tuple[float, str]]

# The ASGI application; the route and startup decorators below attach to it.
app = FastAPI()

# Shared engine instance. It stays None until the startup hook assigns it.
search_engine: "PromptSearchEngine | None" = None

# Load prompts and initialize the search engine when the app starts
@app.on_event("startup")
def startup_event():
    """Build the module-level ``search_engine`` when the application starts.

    Loads the "Gustavosta/Stable-Diffusion-Prompts" dataset and takes the
    first 1000 values of the train split's "Prompt" column. It then passes
    them, together with a Vectorizer configured for "all-MiniLM-L6-v2", to
    PromptSearchEngine, and publishes the result in ``search_engine``.
    """
    global search_engine
    all_prompts = load_dataset("Gustavosta/Stable-Diffusion-Prompts")["train"]["Prompt"]
    # Capped to keep startup quick while testing; raise the limit as needed.
    corpus = all_prompts[:1000]
    search_engine = PromptSearchEngine(corpus, Vectorizer(model="all-MiniLM-L6-v2"))

# Define the /search endpoint
@app.post("/search")
def search_prompts(request: QueryRequest):
    """Return the ``request.n`` stored prompts most similar to ``request.query``.

    The reply has the shape ``{"results": [{"score": float, "prompt": str}, ...]}``.
    The result list is empty when either:
      * the search engine has not been built yet (``search_engine`` is still None), or
      * ``request.n`` is not positive.
    """
    # The global is only read here, so no ``global`` declaration is required.
    engine = search_engine
    if engine is None or request.n <= 0:
        # Nothing to search, or nothing requested. Skip the engine instead of
        # handing a non-positive count to most_similar, whose handling of
        # n <= 0 is not visible here.
        return {"results": []}
    # Get the top-n most similar prompts.
    similar_prompts = engine.most_similar(query=request.query, n=request.n)
    # Cast scores to built-in float so the JSON encoder can serialize them.
    # Presumably the engine may yield numpy scalar types (unverified).
    results = [
        {"score": float(score), "prompt": prompt}
        for score, prompt in similar_prompts
    ]
    return {"results": results}