Spaces:
Runtime error
Runtime error
Initial commit with cleaned project files
Browse files- Dockerfile +43 -0
- README.md +54 -0
- api.py +49 -0
- prompt_search_engine.py +16 -0
- requirements.txt +5 -0
- similarity.py +6 -0
- vectorizer.py +10 -0
Dockerfile
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Python runtime as a parent image
|
2 |
+
FROM python:3.9-slim
|
3 |
+
|
4 |
+
# Install git and git-lfs
|
5 |
+
RUN apt-get update && apt-get install -y git git-lfs && git lfs install
|
6 |
+
|
7 |
+
# Create a non-root user 'appuser'
|
8 |
+
RUN useradd -ms /bin/bash appuser
|
9 |
+
|
10 |
+
# Set the working directory
|
11 |
+
WORKDIR /home/appuser/app
|
12 |
+
|
13 |
+
# Copy requirements file
|
14 |
+
COPY requirements.txt .
|
15 |
+
|
16 |
+
# Install required packages
|
17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
18 |
+
|
19 |
+
# Copy application code
|
20 |
+
COPY . .
|
21 |
+
|
22 |
+
# Set environment variables for cache directories
|
23 |
+
ENV HF_HOME=/home/appuser/app/.cache
|
24 |
+
ENV HF_DATASETS_CACHE=/home/appuser/app/.cache
|
25 |
+
|
26 |
+
# Create the cache directory
|
27 |
+
RUN mkdir -p /home/appuser/app/.cache
|
28 |
+
|
29 |
+
# Change ownership of the application files
|
30 |
+
RUN chown -R appuser:appuser /home/appuser/app
|
31 |
+
|
32 |
+
# Switch to non-root user
|
33 |
+
USER appuser
|
34 |
+
|
35 |
+
# Pre-download models and datasets
|
36 |
+
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
|
37 |
+
RUN python -c "from datasets import load_dataset; load_dataset('Gustavosta/Stable-Diffusion-Prompts')"
|
38 |
+
|
39 |
+
# Expose port 7860
|
40 |
+
EXPOSE 7860
|
41 |
+
|
42 |
+
# Command to run the API
|
43 |
+
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -9,3 +9,57 @@ short_description: Improve image quality with better prompts!
|
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
12 |
+
|
13 |
+
# Prompt Search Engine
|
14 |
+
|
15 |
+
## Overview
|
16 |
+
|
17 |
+
This project implements a prompt search engine for Stable Diffusion models. The search engine allows users to input a prompt and returns the top `n` most similar prompts from a corpus of existing prompts. This helps in generating higher quality images by providing more effective prompts.
|
18 |
+
|
19 |
+
The search engine consists of two main components:
|
20 |
+
|
21 |
+
- **Prompt Vectorizer**: Converts prompts into numerical vectors using a pre-trained embedding model.
|
22 |
+
- **Similarity Scorer**: Measures the similarity between the input prompt and existing prompts using cosine similarity.
|
23 |
+
|
24 |
+
## Setup Instructions
|
25 |
+
|
26 |
+
### Requirements
|
27 |
+
|
28 |
+
- Python >= 3.9
|
29 |
+
- pip
|
30 |
+
|
31 |
+
### Installation
|
32 |
+
|
33 |
+
1. **Clone the repository**
|
34 |
+
|
35 |
+
```bash
|
36 |
+
git clone <repository-url>
|
37 |
+
cd <repository-directory>
|
38 |
+
```
|
39 |
+
|
40 |
+
2. **Create a virtual environment (optional)**
|
41 |
+
|
42 |
+
```bash
|
43 |
+
python -m venv venv
|
44 |
+
source venv/bin/activate
|
45 |
+
```
|
46 |
+
|
47 |
+
3. **Install dependencies**
|
48 |
+
```bash
|
49 |
+
pip install -r requirements.txt
|
50 |
+
```
|
51 |
+
|
52 |
+
## Running the `run.py` script
|
53 |
+
The `run.py` script allows you to run the prompt search engine from the command line.
|
54 |
+
### Usage
|
55 |
+
```bash
|
56 |
+
python run.py --query "Your query prompt here" --n 5 --model "all-MiniLM-L6-v2"
|
57 |
+
```
|
58 |
+
### Arguments
|
59 |
+
- `--query`: The query prompt (required).
|
60 |
+
- `--n`: The number of similar prompts to return (default 5).
|
61 |
+
- `--model`: The name of the SBERT model to use (default "all-MiniLM-L6-v2").
|
62 |
+
|
63 |
+
### Example
|
64 |
+
`python run.py --query "A cat wearing glasses, sitting at a computer" --n 7`
|
65 |
+
|
api.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# api.py
|
2 |
+
|
3 |
+
from fastapi import FastAPI
|
4 |
+
from pydantic import BaseModel
|
5 |
+
from typing import List, Tuple
|
6 |
+
|
7 |
+
from prompt_search_engine import PromptSearchEngine
|
8 |
+
from vectorizer import Vectorizer
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
# Define the request and response models
|
12 |
+
class QueryRequest(BaseModel):
|
13 |
+
query: str
|
14 |
+
n: int = 5 # default value
|
15 |
+
|
16 |
+
class QueryResponse(BaseModel):
|
17 |
+
results: List[Tuple[float, str]]
|
18 |
+
|
19 |
+
# Initialize FastAPI app
|
20 |
+
app = FastAPI()
|
21 |
+
|
22 |
+
# Global variable to store the search engine
|
23 |
+
search_engine = None
|
24 |
+
|
25 |
+
# Load prompts and initialize the search engine when the app starts
|
26 |
+
@app.on_event("startup")
|
27 |
+
def startup_event():
|
28 |
+
global search_engine
|
29 |
+
# Load the prompts
|
30 |
+
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
|
31 |
+
prompts = dataset["train"]["Prompt"]
|
32 |
+
# For testing, limit the number of prompts
|
33 |
+
prompts = prompts[:1000] # Adjust the number as needed
|
34 |
+
# Initialize vectorizer with the default model
|
35 |
+
vectorizer = Vectorizer(model="all-MiniLM-L6-v2")
|
36 |
+
# Initialize the search engine
|
37 |
+
search_engine = PromptSearchEngine(prompts, vectorizer)
|
38 |
+
|
39 |
+
# Define the /search endpoint
|
40 |
+
@app.post("/search")
|
41 |
+
def search_prompts(request: QueryRequest):
|
42 |
+
global search_engine
|
43 |
+
if search_engine is None:
|
44 |
+
return {"results": []}
|
45 |
+
# Get the top-n most similar prompts
|
46 |
+
similar_prompts = search_engine.most_similar(query=request.query, n=request.n)
|
47 |
+
# Prepare the response
|
48 |
+
results = [{"score": float(score), "prompt": prompt} for score, prompt in similar_prompts]
|
49 |
+
return {"results": results}
|
prompt_search_engine.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Sequence
|
2 |
+
import numpy as np
|
3 |
+
from vectorizer import Vectorizer
|
4 |
+
from similarity import cosine_similarity
|
5 |
+
|
6 |
+
class PromptSearchEngine:
|
7 |
+
def __init__(self, prompts: Sequence[str], vectorizer: Vectorizer) -> None:
|
8 |
+
self.prompts = prompts
|
9 |
+
self.vectorizer = vectorizer
|
10 |
+
self.corpus_vectors = vectorizer.transform(prompts)
|
11 |
+
|
12 |
+
def most_similar(self, query, n = 5) -> List[Tuple[float, str]]:
|
13 |
+
query_vector = self.vectorizer.transform([query])[0]
|
14 |
+
similarities = cosine_similarity(query_vector, self.corpus_vectors)
|
15 |
+
top_indices = similarities.argsort()[-n:][::-1]
|
16 |
+
return [(similarities[i], self.prompts[i]) for i in top_indices]
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sentence-transformers
|
2 |
+
numpy
|
3 |
+
datasets
|
4 |
+
fastapi
|
5 |
+
uvicorn
|
similarity.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
|
4 |
+
query_norm = query_vector / np.linalg.norm(query_vector)
|
5 |
+
corpus_norm = corpus_vectors / np.linalg.norm(corpus_vectors, axis=1, keepdims=True)
|
6 |
+
return np.dot(corpus_norm, query_norm)
|
vectorizer.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
import numpy as np
|
3 |
+
from typing import Sequence
|
4 |
+
|
5 |
+
class Vectorizer:
|
6 |
+
def __init__(self, model) -> None:
|
7 |
+
self.model = SentenceTransformer(model)
|
8 |
+
|
9 |
+
def transform(self, prompts: Sequence[str]) -> np.ndarray:
|
10 |
+
return self.model.encode(prompts)
|