Anja97 committed on
Commit adad4ac · 1 Parent(s): 3da5ff8

Initial commit with cleaned project files

Files changed (7)
  1. Dockerfile +43 -0
  2. README.md +54 -0
  3. api.py +49 -0
  4. prompt_search_engine.py +16 -0
  5. requirements.txt +5 -0
  6. similarity.py +6 -0
  7. vectorizer.py +10 -0
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Install git and git-lfs
+ RUN apt-get update && apt-get install -y git git-lfs && git lfs install
+
+ # Create a non-root user 'appuser'
+ RUN useradd -ms /bin/bash appuser
+
+ # Set the working directory
+ WORKDIR /home/appuser/app
+
+ # Copy requirements file
+ COPY requirements.txt .
+
+ # Install required packages
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Set environment variables for cache directories
+ ENV HF_HOME=/home/appuser/app/.cache
+ ENV HF_DATASETS_CACHE=/home/appuser/app/.cache
+
+ # Create the cache directory
+ RUN mkdir -p /home/appuser/app/.cache
+
+ # Change ownership of the application files
+ RUN chown -R appuser:appuser /home/appuser/app
+
+ # Switch to non-root user
+ USER appuser
+
+ # Pre-download models and datasets
+ RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
+ RUN python -c "from datasets import load_dataset; load_dataset('Gustavosta/Stable-Diffusion-Prompts')"
+
+ # Expose port 7860
+ EXPOSE 7860
+
+ # Command to run the API
+ CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -9,3 +9,57 @@ short_description: Improve image quality with better prompts!
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Prompt Search Engine
+
+ ## Overview
+
+ This project implements a prompt search engine for Stable Diffusion models. Given an input prompt, the engine returns the top `n` most similar prompts from a corpus of existing prompts, which helps produce higher-quality images by suggesting more effective prompts.
+
+ The search engine consists of two main components:
+
+ - **Prompt Vectorizer**: Converts prompts into numerical vectors using a pre-trained embedding model.
+ - **Similarity Scorer**: Measures the similarity between the input prompt and existing prompts using cosine similarity.
+
+ ## Setup Instructions
+
+ ### Requirements
+
+ - Python >= 3.9
+ - pip
+
+ ### Installation
+
+ 1. **Clone the repository**
+
+ ```bash
+ git clone <repository-url>
+ cd <repository-directory>
+ ```
+
+ 2. **Create a virtual environment (optional)**
+
+ ```bash
+ python -m venv venv
+ source venv/bin/activate
+ ```
+
+ 3. **Install dependencies**
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Running the `run.py` script
+ The `run.py` script lets you query the prompt search engine from the command line.
+ ### Usage
+ ```bash
+ python run.py --query "Your query prompt here" --n 5 --model "all-MiniLM-L6-v2"
+ ```
+ ### Arguments
+ - `--query`: The query prompt (required).
+ - `--n`: The number of similar prompts to return (default: 5).
+ - `--model`: The name of the SBERT model to use (default: "all-MiniLM-L6-v2").
+
+ ### Example
+ `python run.py --query "A cat wearing glasses, sitting at a computer" --n 7`
+
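For context, the sketch below shows how the two components described in the Overview fit together when used directly from Python, mirroring what api.py does at startup. It is an illustrative example built only from the classes in vectorizer.py and prompt_search_engine.py in this commit; the toy corpus is made up.

```python
# Minimal sketch (illustrative): wire the Vectorizer and PromptSearchEngine
# from this commit together on a small hand-written corpus.
from vectorizer import Vectorizer
from prompt_search_engine import PromptSearchEngine

prompts = [
    "a cat wearing glasses, sitting at a computer",
    "a futuristic city at sunset, digital art",
    "portrait of an astronaut, studio lighting",
]

vectorizer = Vectorizer(model="all-MiniLM-L6-v2")  # same default model as the README
engine = PromptSearchEngine(prompts, vectorizer)

# most_similar returns (similarity, prompt) pairs, highest score first.
for score, prompt in engine.most_similar("a cat at a keyboard", n=2):
    print(f"{score:.3f}  {prompt}")
```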
api.py ADDED
@@ -0,0 +1,49 @@
+ # api.py
+
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from typing import List, Tuple
+
+ from prompt_search_engine import PromptSearchEngine
+ from vectorizer import Vectorizer
+ from datasets import load_dataset
+
+ # Define the request and response models
+ class QueryRequest(BaseModel):
+     query: str
+     n: int = 5  # default value
+
+ class QueryResponse(BaseModel):
+     results: List[Tuple[float, str]]
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Global variable to store the search engine
+ search_engine = None
+
+ # Load prompts and initialize the search engine when the app starts
+ @app.on_event("startup")
+ def startup_event():
+     global search_engine
+     # Load the prompts
+     dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
+     prompts = dataset["train"]["Prompt"]
+     # For testing, limit the number of prompts
+     prompts = prompts[:1000]  # Adjust the number as needed
+     # Initialize vectorizer with the default model
+     vectorizer = Vectorizer(model="all-MiniLM-L6-v2")
+     # Initialize the search engine
+     search_engine = PromptSearchEngine(prompts, vectorizer)
+
+ # Define the /search endpoint
+ @app.post("/search")
+ def search_prompts(request: QueryRequest):
+     global search_engine
+     if search_engine is None:
+         return {"results": []}
+     # Get the top-n most similar prompts
+     similar_prompts = search_engine.most_similar(query=request.query, n=request.n)
+     # Prepare the response
+     results = [{"score": float(score), "prompt": prompt} for score, prompt in similar_prompts]
+     return {"results": results}
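As a quick sanity check of the endpoint above, the snippet below posts a query to /search and prints the ranked results. It assumes the app is running locally on port 7860 (as in the Dockerfile's CMD) and uses the `requests` package, which is not listed in requirements.txt and would need to be installed separately.

```python
# Illustrative client for the /search endpoint (assumes a local server on
# port 7860; `requests` is an extra dependency, not in requirements.txt).
import requests

resp = requests.post(
    "http://localhost:7860/search",
    json={"query": "A cat wearing glasses, sitting at a computer", "n": 3},
)
resp.raise_for_status()

# The endpoint returns {"results": [{"score": ..., "prompt": ...}, ...]}.
for item in resp.json()["results"]:
    print(f"{item['score']:.3f}  {item['prompt']}")
```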
prompt_search_engine.py ADDED
@@ -0,0 +1,16 @@
+ from typing import List, Tuple, Sequence
+ import numpy as np
+ from vectorizer import Vectorizer
+ from similarity import cosine_similarity
+
+ class PromptSearchEngine:
+     def __init__(self, prompts: Sequence[str], vectorizer: Vectorizer) -> None:
+         self.prompts = prompts
+         self.vectorizer = vectorizer
+         self.corpus_vectors = vectorizer.transform(prompts)  # embed the whole corpus once up front
+
+     def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
+         query_vector = self.vectorizer.transform([query])[0]
+         similarities = cosine_similarity(query_vector, self.corpus_vectors)
+         top_indices = similarities.argsort()[-n:][::-1]  # indices of the n highest scores, best first
+         return [(similarities[i], self.prompts[i]) for i in top_indices]
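The one-line ranking in most_similar can look terse; the toy check below (illustrative, plain NumPy) shows why taking the last `n` indices of argsort and reversing them yields the highest-scoring prompts first.

```python
# argsort orders indices by ascending similarity, so the last n entries are
# the best matches; reversing puts the highest score first.
import numpy as np

similarities = np.array([0.12, 0.87, 0.45, 0.91, 0.30])
n = 2
top_indices = similarities.argsort()[-n:][::-1]
print(top_indices)                # [3 1]
print(similarities[top_indices])  # [0.91 0.87]
```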
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ sentence-transformers
+ numpy
+ datasets
+ fastapi
+ uvicorn
similarity.py ADDED
@@ -0,0 +1,6 @@
+ import numpy as np
+
+ def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
+     query_norm = query_vector / np.linalg.norm(query_vector)  # scale the query to unit length
+     corpus_norm = corpus_vectors / np.linalg.norm(corpus_vectors, axis=1, keepdims=True)  # row-wise unit vectors
+     return np.dot(corpus_norm, query_norm)  # dot products of unit vectors = cosine similarities
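A small hand-checkable example of the function above (illustrative only): a corpus vector pointing in the same direction as the query scores 1.0, an orthogonal one scores 0.0.

```python
# Sanity check for cosine_similarity with vectors that are easy to verify
# by hand: parallel -> 1.0, orthogonal -> 0.0.
import numpy as np
from similarity import cosine_similarity

query = np.array([1.0, 0.0])
corpus = np.array([
    [2.0, 0.0],  # same direction as the query
    [0.0, 3.0],  # orthogonal to the query
])
print(cosine_similarity(query, corpus))  # [1. 0.]
```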
vectorizer.py ADDED
@@ -0,0 +1,10 @@
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+ from typing import Sequence
+
+ class Vectorizer:
+     def __init__(self, model: str) -> None:
+         self.model = SentenceTransformer(model)  # load the pre-trained SBERT model by name
+
+     def transform(self, prompts: Sequence[str]) -> np.ndarray:
+         return self.model.encode(prompts)  # one embedding vector per prompt