File size: 4,634 Bytes
c24a176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
802ae2a
 
 
c24a176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from os import getenv
from pathlib import Path

import gradio as gr
from PIL import Image
from rich.traceback import install as traceback_install

from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap

# UI strings
TITLE = "WD Tagger Heatmap"
DESCRIPTION = "WD Tagger v3 Heatmap Generator."

# Optional Hugging Face token from the environment (None when unset).
# NOTE(review): not referenced in the visible part of this file.
HF_TOKEN = getenv("HF_TOKEN")

# Repository id handed to the model and label loaders.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Directory containing this file; when __file__ is not defined (e.g. an
# interactive session) fall back to the current working directory.
if "__file__" in globals():
    WORK_DIR = Path(__file__).parent.resolve()
else:
    WORK_DIR = Path().resolve()

# File suffixes (lower-case, with the leading dot) accepted as example images.
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".webp",
    ".bmp",
    ".tiff",
    ".tif",
]

# Install rich's traceback handler (rich.traceback.install, imported above as
# traceback_install) with local variables included in the output. The return
# value is not used, hence the throwaway `_`.
# NOTE(review): the exact effect of locals_max_length=0 should be confirmed
# against the rich documentation.
_ = traceback_install(show_locals=True, locals_max_length=0)

# Collect the example images shipped next to this file, as sorted path strings
# relative to WORK_DIR. Only regular files whose (lower-cased) suffix is listed
# in IMAGE_EXTENSIONS are kept.
# If the "examples" directory is missing (or is not a directory) the list is
# simply empty, instead of raising at import time and preventing the app from
# starting.
_examples_dir = WORK_DIR.joinpath("examples")
example_images = sorted(
    str(x.relative_to(WORK_DIR))
    for x in (_examples_dir.iterdir() if _examples_dir.is_dir() else ())
    if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
)


def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Run the tagger on one image and return the values for the UI outputs.

    Args:
        image: PIL image coming from the ``gr.Image`` input. May be ``None``
            when the user submits without providing an image.
        threshold: Value forwarded unchanged to ``process_heatmap``.

    Returns:
        A 7-tuple, in the order expected by the ``submit.click`` outputs:
        a list of ``(heatmap image, label)`` pairs, the heatmap grid, and the
        ``caption``, ``booru``, ``rating``, ``character`` and ``general``
        attributes of the resulting ``ImageLabels``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submitting with an empty image input passes None; fail with a readable
    # message in the UI instead of an obscure error further down the pipeline.
    if image is None:
        raise gr.Error("Please provide an input image before submitting.")

    # load the model and its preprocessing transform for the configured repo
    model, transform = load_model_and_transform(MODEL_REPO)
    # load labels
    labels: LabelData = load_labels_hf(MODEL_REPO)
    # preprocess image to the fixed 448x448 size passed to preprocess_image
    image = preprocess_image(image, (448, 448))
    # NOTE(review): unsqueeze(0) presumably adds a leading batch dimension to
    # the transformed tensor; confirm against tagger.model.
    image = transform(image).unsqueeze(0)

    # get the model output
    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, image, labels, threshold)

    # (image, label) pairs for the gallery output
    heatmap_images = [(x.image, x.label) for x in heatmaps]

    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )


# Extra CSS passed to gr.Blocks(css=...) below.
# NOTE(review): the only elem_id assigned in the visible layout is "threshold".
# Nothing in the visible code creates elements with id "use_mcut" / "char_mcut"
# or adds the "dimmed" class, so those rules may be leftovers; confirm before
# removing them.
css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#threshold.dimmed {
    filter: brightness(75%);
}
"""

# Build the UI: input controls in the left column, tabbed results in the right
# column, and a row of clickable examples underneath.
with gr.Blocks(
    theme="NoCrypt/miku",
    analytics_enabled=False,
    title=TITLE,
    css=css,
) as demo:
    with gr.Row(equal_height=False):
        # left column: image input, threshold slider, action buttons
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group(), gr.Row():
                threshold = gr.Slider(
                    label="Tag Threshold",
                    elem_id="threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.35,
                    scale=5,
                )
            with gr.Row():
                # the components to clear are registered further down, once
                # all of them exist
                clear = gr.ClearButton(components=[], variant="secondary", size="lg")
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # right column: one tab per kind of result
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    # components wired to predict(); the output order mirrors its return tuple
    input_components = [img_input, threshold]
    output_components = [
        heatmap_gallery,
        heatmap_grid,
        caption,
        tags,
        rating,
        character,
        general,
    ]

    with gr.Row():
        # one example row per bundled image, each paired with a 0.35 threshold
        examples = gr.Examples(
            examples=[[path, 0.35] for path in example_images],
            inputs=input_components,
        )

    # the clear button resets the image input and every output
    clear.add([img_input, *output_components])

    submit.click(
        predict,
        inputs=input_components,
        outputs=output_components,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=10)
    # When SPACE_ID is set (presumably by the hosting environment), launch with
    # the defaults; otherwise serve on all interfaces at port 7871 in debug mode.
    if getenv("SPACE_ID") is None:
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7871,
            "debug": True,
        }
    else:
        launch_options = {}
    demo.launch(**launch_options)