Spaces:
Running
Running
from os import getenv | |
from pathlib import Path | |
import gradio as gr | |
from PIL import Image | |
from rich.traceback import install as traceback_install | |
from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image | |
from tagger.model import load_model_and_transform, process_heatmap | |
TITLE = "WD Tagger Heatmap"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""

# Hugging Face access token, read from the environment if set.
# NOTE(review): not referenced anywhere else in this file — confirm it is still needed.
HF_TOKEN = getenv("HF_TOKEN", None)

# tagger model repository, passed to the model and label loaders
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# directory containing this script, or the current working directory when
# running interactively (e.g. IPython), where __file__ is undefined
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()

# lower-case file suffixes accepted as example images
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# install rich's traceback handler for nicer error output
_ = traceback_install(show_locals=True, locals_max_length=0)


def _find_example_images(root: Path) -> list[str]:
    """Return sorted paths, relative to *root*, of image files in ``root/examples``.

    Only regular files whose suffix (compared case-insensitively) is listed in
    ``IMAGE_EXTENSIONS`` are kept. If ``root/examples`` does not exist, return an
    empty list instead of letting ``iterdir`` raise ``FileNotFoundError`` at
    import time.
    """
    examples_dir = root / "examples"
    if not examples_dir.is_dir():
        return []
    return sorted(
        str(path.relative_to(root))
        for path in examples_dir.iterdir()
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )


# example images offered in the UI (paths relative to WORK_DIR)
example_images = _find_example_images(WORK_DIR)
def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Tag one image with the WD tagger and build per-tag heatmaps.

    Args:
        image: Input picture from the UI (``gr.Image(type="pil")``). It is
            ``None`` when the user presses Submit without providing an image.
        threshold: Tag score threshold forwarded to ``process_heatmap``.
            The UI slider supplies its own value (default 0.35); this 0.5
            default only applies when the function is called directly.

    Returns:
        A 7-tuple in the order the submit handler's outputs expect:
        a list of ``(heatmap image, label)`` pairs for the gallery, the grid
        image, then ``caption``, ``booru``, ``rating``, ``character`` and
        ``general`` from the ``ImageLabels`` returned by ``process_heatmap``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Fail fast with a user-visible message rather than an opaque error deep
        # inside preprocessing, and before loading the model.
        raise gr.Error("Please upload an image before submitting.")

    # model + matching input transform, and the tag label set, for MODEL_REPO
    # (NOTE(review): presumably cached by the loaders — confirm in tagger.model/common)
    model, transform = load_model_and_transform(MODEL_REPO)
    labels: LabelData = load_labels_hf(MODEL_REPO)

    # bring the image to the model's 448x448 input size, then prepend a batch axis
    prepared = preprocess_image(image, (448, 448))
    batch = transform(prepared).unsqueeze(0)

    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, grid_image, image_labels = process_heatmap(model, batch, labels, threshold)

    # gr.Gallery accepts a list of (image, caption) pairs
    gallery_items = [(hm.image, hm.label) for hm in heatmaps]
    return (
        gallery_items,
        grid_image,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): nothing in this file adds the "dimmed" class to the threshold
# slider (elem_id="threshold"); this rule only applies if other code toggles
# that class — confirm it is still wanted.
css = """
#threshold.dimmed {
    filter: brightness(75%);
}
"""
# Default tag threshold shared by the slider and the example rows so the two
# cannot drift apart.
DEFAULT_THRESHOLD = 0.35

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # left column: input image, threshold slider, action buttons
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
            with gr.Row():
                # components to clear are attached below, once the outputs exist
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # right column: results, split across tabs
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    with gr.Row():
        example_rows = [[imgpath, DEFAULT_THRESHOLD] for imgpath in example_images]
        examples = gr.Examples(
            examples=example_rows,
            inputs=[img_input, threshold],
        )

    # Output components in the exact order predict() returns its 7-tuple.
    # Shared by the clear button and the submit handler so they stay in sync.
    predict_outputs = [heatmap_gallery, heatmap_grid, caption, tags, rating, character, general]

    # tell clear button which components to clear: the input plus every output
    clear.add([img_input, *predict_outputs])

    submit.click(
        predict,
        inputs=[img_input, threshold],
        outputs=predict_outputs,
        api_name="predict",
    )
if __name__ == "__main__":
    demo.queue(max_size=10)
    # SPACE_ID is presumably set by the Hugging Face Spaces runtime; there the
    # app launches with default settings. Anywhere else it listens on a fixed
    # port on all interfaces ("0.0.0.0", i.e. reachable from other hosts on the
    # network) with debug enabled.
    on_space = getenv("SPACE_ID", None) is not None
    launch_options = (
        {}
        if on_space
        else {"server_name": "0.0.0.0", "server_port": 7871, "debug": True}
    )
    demo.launch(**launch_options)