# api.py
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

from datasets import load_dataset

from prompt_search_engine import PromptSearchEngine
from vectorizer import Vectorizer


# Define the request and response models
class QueryRequest(BaseModel):
    query: str
    n: int = 5  # default number of results to return


class SearchResult(BaseModel):
    # Mirrors the payload built in the /search endpoint below
    score: float
    prompt: str


class QueryResponse(BaseModel):
    results: List[SearchResult]


# Initialize the FastAPI app
app = FastAPI()

# Global variable holding the search engine instance
search_engine = None


# Load prompts and initialize the search engine when the app starts
@app.on_event("startup")
def startup_event():
    global search_engine

    # Load the prompts dataset
    dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    prompts = dataset["train"]["Prompt"]

    # For testing, limit the number of prompts; adjust as needed
    prompts = prompts[:1000]

    # Initialize the vectorizer with the default model
    vectorizer = Vectorizer(model="all-MiniLM-L6-v2")

    # Initialize the search engine
    search_engine = PromptSearchEngine(prompts, vectorizer)


# Define the /search endpoint
@app.post("/search", response_model=QueryResponse)
def search_prompts(request: QueryRequest):
    if search_engine is None:
        return {"results": []}

    # Get the top-n most similar prompts as (score, prompt) pairs
    similar_prompts = search_engine.most_similar(query=request.query, n=request.n)

    # Prepare the response
    results = [
        {"score": float(score), "prompt": prompt}
        for score, prompt in similar_prompts
    ]
    return {"results": results}
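

# A minimal sketch of how this API might be run and queried locally. It assumes
# uvicorn is installed and that port 8000 is free; the host, port, and example
# query below are illustrative, not part of the original service definition.
#
# Start the server from the project root:
#     uvicorn api:app --reload
#
# Example request (any HTTP client works; curl shown here):
#     curl -X POST http://localhost:8000/search \
#          -H "Content-Type: application/json" \
#          -d '{"query": "a castle at sunset, highly detailed", "n": 3}'
#
# Expected response shape (scores and prompts depend on the loaded dataset):
#     {"results": [{"score": <float>, "prompt": "<matched prompt>"}, ...]}

if __name__ == "__main__":
    # Convenience entry point so the file can also be started with
    # `python api.py` (assumes uvicorn is available in the environment).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)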