Spaces:
Runtime error
Runtime error
# api.py | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from typing import List, Tuple | |
from prompt_search_engine import PromptSearchEngine | |
from vectorizer import Vectorizer | |
from datasets import load_dataset | |
# Request/response schemas for the /search endpoint.
class QueryRequest(BaseModel):
    """Body of a /search request."""

    query: str  # free-text query to match against the indexed prompts
    n: int = 5  # how many matches to ask the search engine for (default 5)


class SearchResult(BaseModel):
    """A single match: its similarity score and the matched prompt text."""

    score: float
    prompt: str


class QueryResponse(BaseModel):
    """Response body of /search.

    Each entry is an object with ``score`` and ``prompt`` keys. That is the
    shape ``search_prompts`` builds (a list of ``{"score": ..., "prompt": ...}``
    dicts). A ``List[Tuple[float, str]]`` field would reject those dicts if
    this model were used as the endpoint's ``response_model``.
    """

    results: List[SearchResult]
# FastAPI application object for this module.
app: FastAPI = FastAPI()

# Shared search engine used by the request handler. It starts out as None and
# is assigned by startup_event(). The string annotation is never evaluated
# at runtime.
search_engine: "PromptSearchEngine | None" = None
# For testing, only the first MAX_PROMPTS prompts of the dataset are indexed;
# adjust as needed.
MAX_PROMPTS = 1000
# Model name handed to the Vectorizer.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"


# NOTE: newer FastAPI releases prefer a lifespan handler over on_event;
# on_event("startup") is still supported.
@app.on_event("startup")
def startup_event():
    """Load the prompt dataset and build the global search engine.

    Registered as a FastAPI startup handler so it runs when the application
    starts. Without the registration nothing calls it, ``search_engine``
    stays ``None``, and every search returns an empty result list.
    """
    global search_engine

    dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    # Slicing the split by rows returns a dict of column -> list. Only the
    # first MAX_PROMPTS "Prompt" values are materialized, not the full column.
    prompts = dataset["train"][:MAX_PROMPTS]["Prompt"]

    vectorizer = Vectorizer(model=EMBEDDING_MODEL)
    # Build the engine completely before publishing it through the global.
    search_engine = PromptSearchEngine(prompts, vectorizer)
# /search endpoint: the query arrives as a JSON body (QueryRequest), hence POST.
@app.post("/search")
def search_prompts(request: QueryRequest):
    """Return the prompts most similar to ``request.query``.

    Responds with ``{"results": [{"score": float, "prompt": str}, ...]}``,
    built from the ``(score, prompt)`` pairs returned by
    ``search_engine.most_similar(query, n=request.n)``. Returns
    ``{"results": []}`` if the search engine has not been initialized.
    """
    # The module global is only read here, so no ``global`` statement is needed.
    if search_engine is None:
        return {"results": []}

    matches = search_engine.most_similar(query=request.query, n=request.n)
    # Scores are cast to built-in float; presumably they can be NumPy scalars,
    # which are not all JSON-serializable as-is.
    results = [
        {"score": float(score), "prompt": prompt}
        for score, prompt in matches
    ]
    return {"results": results}