Spaces:
Runtime error
Runtime error
File size: 1,556 Bytes
adad4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# api.py
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Tuple
from prompt_search_engine import PromptSearchEngine
from vectorizer import Vectorizer
from datasets import load_dataset
# Define the request and response models
class QueryRequest(BaseModel):
    # JSON body accepted by POST /search.
    # Free-text query to compare against the indexed prompts.
    query: str
    # Number of most-similar prompts to return.
    n: int = 5 # default value
class QueryResponse(BaseModel):
    # Intended response model for POST /search: (score, prompt) pairs.
    # NOTE(review): this model is not passed as `response_model` to the
    # /search route, and that route actually returns a list of
    # {"score": ..., "prompt": ...} objects rather than tuples. Confirm
    # whether anything else relies on this shape before wiring it up.
    results: List[Tuple[float, str]]
# Initialize FastAPI app
app = FastAPI()
# Global variable to store the search engine.
# It stays None until the startup handler has built it; the /search
# handler checks for None before using it.
search_engine: "PromptSearchEngine | None" = None
# Load prompts and initialize the search engine when the app starts
@app.on_event("startup")
def startup_event():
    """Build the module-level ``search_engine`` once at application startup.

    Loads the "Gustavosta/Stable-Diffusion-Prompts" dataset, keeps the first
    ``PROMPT_LIMIT`` entries of the "Prompt" column of its "train" split,
    wraps them with a ``Vectorizer`` and stores the resulting
    ``PromptSearchEngine`` in the ``search_engine`` global.

    Configuration (environment variables; the defaults equal the values that
    were previously hard-coded):
        PROMPT_LIMIT     -- how many prompts to index (default 1000).
        VECTORIZER_MODEL -- model name passed to ``Vectorizer``
                            (default "all-MiniLM-L6-v2").

    Raises:
        ValueError: if PROMPT_LIMIT is not an integer or is less than 1.
    """
    import os

    global search_engine

    # Read configuration first so a bad value fails fast, before the
    # (potentially slow) dataset download starts.
    limit = int(os.environ.get("PROMPT_LIMIT", "1000"))
    if limit < 1:
        raise ValueError(f"PROMPT_LIMIT must be a positive integer, got {limit}")
    model_name = os.environ.get("VECTORIZER_MODEL", "all-MiniLM-L6-v2")

    # Load the prompts
    dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
    # Slicing past the end just yields every available prompt. list() makes
    # sure the engine receives a plain Python list.
    prompts = list(dataset["train"]["Prompt"][:limit])

    # Initialize the vectorizer and the search engine
    vectorizer = Vectorizer(model=model_name)
    # Publish the engine only after it has been fully constructed.
    search_engine = PromptSearchEngine(prompts, vectorizer)
# Define the /search endpoint
@app.post("/search")
def search_prompts(request: QueryRequest):
    """Return the ``request.n`` prompts most similar to ``request.query``.

    Response body: ``{"results": [{"score": <float>, "prompt": <prompt>}, ...]}``
    in the order produced by ``PromptSearchEngine.most_similar`` (presumably
    most-similar first; confirm). Each score is passed through ``float()``
    so values such as NumPy scalars serialize cleanly to JSON.

    ``{"results": []}`` is returned when the search engine has not been
    built yet, or when ``request.n`` is not positive.
    """
    # The global is only read here, so no ``global`` statement is needed.
    # Take a single snapshot so the None check and the call below use the
    # same object.
    engine = search_engine
    if engine is None:
        return {"results": []}

    # most_similar's behavior for n <= 0 is not defined in this file (a
    # slicing-based top-n such as ``[-n:]`` would return *every* prompt for
    # n == 0), so treat a non-positive n as "no results".
    if request.n <= 0:
        return {"results": []}

    # Get the top-n most similar prompts
    similar_prompts = engine.most_similar(query=request.query, n=request.n)
    results = [
        {"score": float(score), "prompt": prompt}
        for score, prompt in similar_prompts
    ]
    return {"results": results}
|