# Prompt_Squirrel / app.py (uploaded by FoodDesert, commit d96b2be)
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from joblib import load
import h5py
from io import BytesIO
faq_content = """
# Frequently Asked Questions (FAQs)
Technically, I am writing this before anyone but me has used the tool, so no one has asked any questions yet. If they did, these are the questions I think they might ask:
## Does input order matter?
No. Your tags are treated as an unordered collection of terms.
## Should I use underscores in the input tags?
It doesn't matter. The application handles tags either way.
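Each comma-separated tag is normalized before lookup: underscores become spaces and surrounding whitespace is trimmed. A small sketch of that step (the helper name `normalize_tag` is only for illustration):

```python
def normalize_tag(tag):
    # Same normalization the app applies to each comma-separated input tag:
    # underscores become spaces, then surrounding whitespace is trimmed.
    return tag.replace("_", " ").strip()

query = "red_fox, outside ,detailed_background"
tags = [normalize_tag(t) for t in query.split(",")]
print(tags)  # ['red fox', 'outside', 'detailed background']
```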
## Why are some valid tags marked as "unseen", and why don't some artists ever get returned?
Tags and artists that did not occur frequently enough in the sample the application was built from are excluded from consideration.
When an artist or tag is that rare, there is not enough data to make reliable predictions about it.
## Are there any special tags?
Yes. We normalized each image's favorite count to a range of 0-9, with 0 being the lowest favorite count and 9 the highest.
You can include any of these special tags in your list to bias the output toward artists with higher- or lower-scoring images:
"score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9".
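This FAQ does not say exactly how the favorite counts were mapped onto 0-9. As one plausible sketch (an assumption, not necessarily what was done here), images could be ranked by favorite count and the ranking split into ten equal-sized buckets; the function name `fav_counts_to_score_tags` is hypothetical:

```python
import numpy as np

def fav_counts_to_score_tags(fav_counts):
    # Hypothetical normalization: rank images by favorite count and split the
    # ranking into ten equal-sized buckets (0 = lowest, 9 = highest).
    counts = np.asarray(fav_counts)
    ranks = counts.argsort(kind="stable").argsort(kind="stable")  # ranks 0 .. n-1
    buckets = ranks * 10 // len(counts)  # integers 0 .. 9
    return [f"score:{int(b)}" for b in buckets]

tags = fav_counts_to_score_tags(list(range(20)))
print(tags[:3], tags[-3:])
```

A rank-based scheme gives each score tag roughly a tenth of the images; plain min-max scaling of the raw counts would instead push most images into the lowest buckets if favorite counts are heavily skewed.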
## Are there any other special tricks?
Yes. If you want to bias the artist output more strongly toward a specific tag, list it multiple times.
For example, the query "red fox, red fox, red fox, score:7" will return artists more strongly associated with the tag "red fox"
than the query "red fox, score:7" does.
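To see why repetition works, here is a toy sketch using scikit-learn's `TfidfVectorizer` with a tokenizer that treats each comma-separated tag as one term. The two artist pseudo-documents and the `comma_tokenizer` helper are assumptions for illustration, not the app's actual data or configuration. Repeating "red fox" raises its term frequency in the query vector, which tilts the cosine similarity toward the artist who uses that tag most:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def comma_tokenizer(text):
    # Assumption: each comma-separated tag is a single term.
    return [t.strip() for t in text.split(",") if t.strip()]

# Hypothetical artist pseudo-documents.
artist_docs = [
    "red fox, forest, snow, red fox, red fox",  # artist A: mostly foxes
    "red fox, city, night, score:7, score:7",   # artist B: one fox, high scores
]
vectorizer = TfidfVectorizer(tokenizer=comma_tokenizer, token_pattern=None)
X_artist = vectorizer.fit_transform(artist_docs)

def similarities(query):
    # Cosine similarity between the query's TF-IDF vector and each artist's vector.
    return cosine_similarity(vectorizer.transform([query]), X_artist)[0]

plain = similarities("red fox, score:7")
boosted = similarities("red fox, red fox, red fox, score:7")
print("plain:  ", plain.round(3))
print("boosted:", boosted.round(3))
```

With these toy documents, the plain query ranks artist B first, while the boosted query ranks artist A first.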
## What calculation is this thing actually performing?
Each artist is represented by a "pseudo-document" composed of all the tags from their uploaded images, treating these tags much like words in a text document.
Likewise, when you enter a set of tags, the system builds a pseudo-document for your query out of those tags.
Both are converted into weighted vectors using TF-IDF (Term Frequency-Inverse Document Frequency), a standard method in information retrieval that weights each tag by how often it appears in a pseudo-document, discounted by how many artists' pseudo-documents contain it.
The system then uses cosine similarity to compare your query vector against each artist's vector, essentially finding which artists' tags are most "similar" to yours.
This identifies artists whose work is closely aligned with the themes or elements you're interested in.
You can read more about TF-IDF on its [Wikipedia page](https://en.wikipedia.org/wiki/Tf%E2%80%93idf).
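The whole calculation can be sketched end to end with scikit-learn. Everything here is illustrative: the artist names, their image tags, and the `comma_tokenizer` helper are assumptions, and the real vectorizer is loaded from `complete_artist_data.hdf5` rather than fitted like this. The ranking step mirrors the `np.argsort` logic in `find_similar_artists`:

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical training data: each artist maps to the tag lists of their images.
artist_images = {
    "artist_a": [["fox", "forest"], ["fox", "snow", "score:8"]],
    "artist_b": [["dragon", "castle"], ["dragon", "night", "score:3"]],
    "artist_c": [["fox", "city", "night"], ["wolf", "city", "score:5"]],
}

def comma_tokenizer(text):
    # Assumption: one term per comma-separated tag.
    return [t.strip() for t in text.split(",") if t.strip()]

# One pseudo-document per artist: all tags from all of their images.
artist_names = list(artist_images)
pseudo_docs = [", ".join(tag for image in artist_images[name] for tag in image)
               for name in artist_names]

vectorizer = TfidfVectorizer(tokenizer=comma_tokenizer, token_pattern=None)
X_artist = vectorizer.fit_transform(pseudo_docs)

def rank_artists(query, top_n=2):
    # The query becomes its own pseudo-document in the same TF-IDF space.
    X_query = vectorizer.transform([query])
    similarities = cosine_similarity(X_query, X_artist)[0]
    top = np.argsort(similarities)[-top_n:][::-1]  # highest similarity first
    return [(artist_names[i], float(similarities[i])) for i in top]

print(rank_artists("fox, night"))
```

Because each query is reduced to tag counts, "night, fox" produces the same ranking as "fox, night".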
"""
# Load the model and data once at startup
with h5py.File('complete_artist_data.hdf5', 'r') as f:
    # Deserialize the vectorizer (stored as raw joblib bytes)
    vectorizer_bytes = f['vectorizer'][()].tobytes()
    vectorizer_buffer = BytesIO(vectorizer_bytes)
    vectorizer = load(vectorizer_buffer)
    # Load X_artist (one row per artist, in the vectorizer's feature space)
    X_artist = f['X_artist'][:]
    # Load artist names and decode to strings
    artist_names = [name.decode() for name in f['artist_names'][:]]
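def find_similar_artists(new_tags_string, top_n):
    # Normalize the input: split on commas, turn underscores into spaces, trim whitespace,
    # and drop empty entries (e.g. from a trailing comma).
    new_image_tags = [tag.replace('_', ' ').strip() for tag in new_tags_string.split(",")]
    new_image_tags = [tag for tag in new_image_tags if tag]
    # Tags missing from the vectorizer's vocabulary cannot contribute to the similarity.
    unseen_tags = set(new_image_tags) - set(vectorizer.vocabulary_.keys())
    unseen_tags_str = f'Unseen Tags: {", ".join(sorted(unseen_tags))}' if unseen_tags else 'No unseen tags.'
    # Treat the query as its own pseudo-document and score it against every artist.
    X_new_image = vectorizer.transform([','.join(new_image_tags)])
    similarities = cosine_similarity(X_new_image, X_artist)[0]
    # The slider value may arrive as a float; slicing requires an int.
    top_n = int(top_n)
    top_artist_indices = np.argsort(similarities)[-top_n:][::-1]
    top_artists = [(artist_names[i], similarities[i]) for i in top_artist_indices]
    # artist[3:] drops the first three characters of each stored name (presumably a "by " prefix).
    top_artists_str = "\n".join([f"{rank+1}. {artist[3:]} ({score:.4f})" for rank, (artist, score) in enumerate(top_artists)])
    dynamic_prompts_formatted_artists = "{" + "|".join([artist for artist, _ in top_artists]) + "}"
    return unseen_tags_str, top_artists_str, dynamic_prompts_formatted_artists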
iface = gr.Interface(
    fn=find_similar_artists,
    inputs=[
        gr.Textbox(label="Enter image tags", placeholder="e.g. fox, outside, detailed background, ..."),
        gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of artists")
    ],
    outputs=[
        gr.Textbox(label="Unseen Tags", info="These tags are not used in the artist calculation. Even valid e6 tags may be \"unseen\" if they have insufficient data."),
        gr.Textbox(label="Top Artists", info="These are the artists most strongly associated with your tags. The number in parentheses is a similarity score between 0 and 1, with higher numbers indicating greater similarity."),
        gr.Textbox(label="Dynamic Prompts Format", info="For use with the Automatic1111 web UI (https://github.com/AUTOMATIC1111/stable-diffusion-webui) with the Dynamic Prompts extension enabled (https://github.com/adieyal/sd-dynamic-prompts), if you want to try each artist individually.")
    ],
    title="Tagset Completer",
    description="Enter a list of comma-separated e6 tags.",
    article=faq_content
)
iface.launch()