Spaces:
Running
Running
File size: 4,634 Bytes
c24a176 802ae2a c24a176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from os import getenv
from pathlib import Path
import gradio as gr
from PIL import Image
from rich.traceback import install as traceback_install
from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap
TITLE = "WD Tagger Heatmap"
DESCRIPTION = "WD Tagger v3 Heatmap Generator."

# Hugging Face token taken from the environment; None when the variable is unset.
HF_TOKEN = getenv("HF_TOKEN")

# Hub repository passed to the model and label loaders in `predict`.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Directory containing this script. When `__file__` is not defined (e.g. when the
# code is pasted into an IPython session) fall back to the current working directory.
try:
    WORK_DIR = Path(__file__).parent.resolve()
except NameError:
    WORK_DIR = Path().resolve()

# Lower-case file suffixes accepted as example images.
IMAGE_EXTENSIONS = ".jpg .jpeg .png .gif .webp .bmp .tiff .tif".split()
# Install rich's traceback handler (show_locals=True, locals_max_length=0).
_ = traceback_install(show_locals=True, locals_max_length=0)

# Collect the bundled example images as paths relative to WORK_DIR, sorted, for the
# Examples row of the UI. A missing `examples` directory yields an empty list rather
# than raising FileNotFoundError and crashing the app at import time.
_examples_dir = WORK_DIR / "examples"
if _examples_dir.is_dir():
    example_images = sorted(
        str(path.relative_to(WORK_DIR))
        for path in _examples_dir.iterdir()
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
    )
else:
    example_images = []
def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Tag *image* with the WD tagger and build per-tag heatmaps.

    Args:
        image: Input PIL image from the Gradio Image component. Gradio passes
            ``None`` when the user submits without providing an image.
        threshold: Tag score threshold forwarded to ``process_heatmap``. Note that
            the UI slider defaults to 0.35, while this function defaults to 0.5 when
            called without it (e.g. via the API).

    Returns:
        A 7-tuple, in the order of the ``submit.click`` outputs:
        ``(gallery_items, heatmap_grid, caption, booru_tags, rating, character,
        general)``. ``gallery_items`` is a list of ``(image, label)`` pairs, one per
        heatmap. The remaining items are taken directly from ``process_heatmap``'s
        grid and ``ImageLabels`` result.

    Raises:
        gr.Error: If no image was provided, so the user sees a clear message
            instead of an opaque preprocessing failure.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Model/transform and labels are looked up by repo id on every call.
    # NOTE(review): presumably these loaders cache internally; confirm in tagger.model/common.
    model, transform = load_model_and_transform(MODEL_REPO)
    labels: LabelData = load_labels_hf(MODEL_REPO)

    # Bring the image to 448x448 via the project helper, apply the model transform,
    # then add a leading batch dimension of size 1.
    prepared = preprocess_image(image, (448, 448))
    batch = transform(prepared).unsqueeze(0)

    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, batch, labels, threshold)

    # gr.Gallery accepts (image, caption) pairs.
    heatmap_images = [(heatmap.image, heatmap.label) for heatmap in heatmaps]
    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
# Custom CSS passed to gr.Blocks.
# NOTE(review): no component in this file uses elem_id "use_mcut" or "char_mcut",
# and nothing here adds the "dimmed" class to #threshold. These rules look like
# leftovers from another app; confirm before removing them.
css = (
    "\n"
    "#use_mcut, #char_mcut {\n"
    "padding-top: var(--scale-3);\n"
    "}\n"
    "#threshold.dimmed {\n"
    "filter: brightness(75%);\n"
    "}\n"
)
# Default tag threshold shared by the slider and the example rows so the two stay in sync.
_DEFAULT_THRESHOLD = 0.35

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: input image, threshold slider and action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=_DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
            with gr.Row():
                # Created with no targets; the components to clear are registered
                # below, once the output components exist.
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")
        # Right column: per-tag heatmaps, the combined grid, and the tag outputs.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    # Only build the Examples row when example images were found. Each row
    # pre-fills the image input and the threshold slider.
    if example_images:
        with gr.Row():
            examples = gr.Examples(
                examples=[[img_path, _DEFAULT_THRESHOLD] for img_path in example_images],
                inputs=[img_input, threshold],
            )

    # Tell the clear button which components to reset (the input plus every output).
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])

    # Output order must match the tuple returned by `predict`.
    submit.click(
        predict,
        inputs=[img_input, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )
if __name__ == "__main__":
    # Enable request queuing, with at most 10 pending requests.
    demo.queue(max_size=10)

    # When SPACE_ID is set (presumably by the Hugging Face Spaces runtime), launch
    # with defaults. Otherwise bind to all interfaces on port 7871 in debug mode
    # for local runs.
    if getenv("SPACE_ID", None) is None:
        launch_options = {"server_name": "0.0.0.0", "server_port": 7871, "debug": True}
    else:
        launch_options = {}
    demo.launch(**launch_options)
|