# Hosting-page residue captured together with this file (not program code):
# Anja97's picture
# Initial commit with cleaned project files
# adad4ac
# raw
# history blame
# 1.56 kB
# api.py
from typing import List, Tuple

from datasets import load_dataset
from fastapi import FastAPI
from pydantic import BaseModel, Field

from prompt_search_engine import PromptSearchEngine
from vectorizer import Vectorizer
# Request and response models
class QueryRequest(BaseModel):
    """Request body accepted by the ``/search`` endpoint.

    Attributes are documented inline. ``n`` is validated so that zero or
    negative values are rejected at the API boundary instead of being
    forwarded to the search engine, where their meaning is undefined.
    """

    # Free-text query to compare against the indexed prompts.
    query: str
    # Number of results to return; defaults to 5 and must be at least 1.
    n: int = Field(default=5, ge=1)
class QueryResponse(BaseModel):
    """Response schema holding a list of ``(float, str)`` result pairs.

    NOTE(review): nothing in this file references this model, and the
    ``/search`` handler returns dicts with ``"score"`` and ``"prompt"`` keys
    rather than tuples -- confirm which shape clients rely on before wiring
    this model in as a ``response_model``.
    """

    # Each entry pairs a float with a string, in that order.
    results: List[Tuple[float, str]]
# Initialize FastAPI app; the startup hook and the /search route in this
# module are registered on this instance via its decorators.
app = FastAPI()
# Global variable to store the search engine. It is None at import time and
# is assigned a PromptSearchEngine by the startup hook; the /search handler
# checks for None before using it.
search_engine = None
# Build the prompt index once, when the application starts
@app.on_event("startup")
def startup_event():
    """Load the prompt dataset and populate the module-level search engine.

    Fetches the "Gustavosta/Stable-Diffusion-Prompts" dataset, reads the
    "Prompt" column of its "train" split, keeps the first 1000 entries, and
    stores a ``PromptSearchEngine`` built from them (with a ``Vectorizer``
    using the "all-MiniLM-L6-v2" model) in the global ``search_engine``.
    """
    global search_engine

    train_split = load_dataset("Gustavosta/Stable-Diffusion-Prompts")["train"]
    # Only a prefix of the column is indexed (testing limit; adjust as needed).
    prompt_subset = train_split["Prompt"][:1000]

    search_engine = PromptSearchEngine(
        prompt_subset,
        Vectorizer(model="all-MiniLM-L6-v2"),
    )
# POST /search: return the prompts most similar to the query
@app.post("/search")
def search_prompts(request: QueryRequest):
    """Return the ``request.n`` prompts most similar to ``request.query``.

    Args:
        request: Parsed request body carrying the query text and the number
            of results wanted.

    Returns:
        A dict with a single ``"results"`` key. Its value is a list of dicts,
        each with a ``"score"`` (converted to ``float``) and a ``"prompt"``,
        in the order produced by the engine's ``most_similar``. The list is
        empty when the search engine has not been initialised.
    """
    engine = search_engine  # read-only access, so no ``global`` declaration
    if engine is None:
        return {"results": []}

    matches = engine.most_similar(query=request.query, n=request.n)
    return {
        "results": [
            {"score": float(similarity), "prompt": text}
            for similarity, text in matches
        ]
    }