update for v3
Files changed:
- README.md (+11, −5)
- app.py (+169, −63)
- data/selected_tags.csv (+0, −0)
- tagger/common.py (+56, −4)
README.md
CHANGED
@@ -9,12 +9,18 @@ app_file: app.py
 pinned: false
 short_description: A WD Tagger Space for pi-chan to use
 preload_from_hub:
-  - SmilingWolf/wd-
-  - SmilingWolf/wd-
-  - SmilingWolf/wd-
-  - SmilingWolf/wd-v1-4-
-  - SmilingWolf/wd-v1-4-
+  - SmilingWolf/wd-vit-tagger-v3 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-swinv2-tagger-v3 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-convnext-tagger-v3 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-v1-4-moat-tagger-v2 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-v1-4-swinv2-tagger-v2 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-v1-4-convnext-tagger-v2 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-v1-4-convnextv2-tagger-v2 model.onnx,selected_tags.csv
+  - SmilingWolf/wd-v1-4-vit-tagger-v2 model.onnx,selected_tags.csv
 models:
+  - SmilingWolf/wd-vit-tagger-v3
+  - SmilingWolf/wd-swinv2-tagger-v3
+  - SmilingWolf/wd-convnext-tagger-v3
   - SmilingWolf/wd-v1-4-moat-tagger-v2
   - SmilingWolf/wd-v1-4-swinv2-tagger-v2
   - SmilingWolf/wd-v1-4-convnext-tagger-v2
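Each preload_from_hub entry has the form "<repo_id> <file>,<file>": the listed files are downloaded into the Hub cache when the Space is built, so the app does not fetch them at request time. A rough, hypothetical equivalent of one entry using huggingface_hub (repo id and filenames taken from the list above):

from huggingface_hub import hf_hub_download

# prefetch the same files that one preload_from_hub entry declares
repo_id = "SmilingWolf/wd-vit-tagger-v3"
for filename in ("model.onnx", "selected_tags.csv"):
    # downloads into the local Hub cache and prints the cached path
    print(hf_hub_download(repo_id=repo_id, filename=filename))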
app.py
CHANGED
@@ -7,25 +7,41 @@ import numpy as np
 import onnxruntime as rt
 from PIL import Image
 
-from tagger.common import LabelData,
+from tagger.common import LabelData, load_labels_hf, preprocess_image
 from tagger.model import create_session
 
+TITLE = "WaifuDiffusion Tagger"
+DESCRIPTION = """
+Tag images with the WaifuDiffusion Tagger models!
+
+Primarily used as a backend for a Discord bot.
+"""
 HF_TOKEN = getenv("HF_TOKEN", None)
-WORK_DIR = Path.cwd().resolve()
 
 MODEL_VARIANTS: dict[str, str] = {
-    "
+    "v3": {
+        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
+        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
+        "ViT": "SmilingWolf/wd-vit-tagger-v3",
+    },
+    "v2": {
+        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
+        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
+        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
+        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
+        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
+    },
 }
+# prepopulate cache keys in model cache
+cache_keys = ["-".join([x, y]) for x in MODEL_VARIANTS.keys() for y in MODEL_VARIANTS[x].keys()]
+loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in cache_keys}
 
+# get the repo root (or the current working directory if running in ipython)
+WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()
 # allowed extensions
 IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
 
-#
-IMAGE_SIZE = 448
+# get the example images
 example_images = sorted(
     [
         str(x.relative_to(WORK_DIR))
@@ -33,34 +49,51 @@ example_images = sorted(
         if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
     ]
 )
-loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k, _ in MODEL_VARIANTS.items()}
 
 
-def load_model(variant: str) -> rt.InferenceSession:
+def load_model(version: str, variant: str) -> rt.InferenceSession:
     global loaded_models
 
     # resolve the repo name
-    model_repo = MODEL_VARIANTS.get(variant, None)
+    model_repo = MODEL_VARIANTS.get(version, {}).get(variant, None)
     if model_repo is None:
-        raise ValueError(f"Unknown model variant: {variant}")
+        raise ValueError(f"Unknown model variant: {version}-{variant}")
 
-    # save model to cache
-    loaded_models[
+    cache_key = f"{version}-{variant}"
+    if loaded_models.get(cache_key, None) is None:
+        # save model to cache
+        loaded_models[cache_key] = create_session(model_repo, token=HF_TOKEN)
+
+    return loaded_models[cache_key]
+
 
+def mcut_threshold(probs: np.ndarray) -> float:
+    """
+    Maximum Cut Thresholding (MCut)
+    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+    for Multi-label Classification. In 11th International Symposium, IDA 2012
+    (pp. 172-183).
+    """
+    probs = probs[probs.argsort()[::-1]]
+    diffs = probs[:-1] - probs[1:]
+    idx = diffs.argmax()
+    thresh = (probs[idx] + probs[idx + 1]) / 2
+    return float(thresh)
 
 
 def predict(
     image: Image.Image,
+    version: str,
     variant: str,
+    gen_threshold: float = 0.35,
+    gen_use_mcut: bool = False,
+    char_threshold: float = 0.85,
+    char_use_mcut: bool = False,
 ):
-    #
-    model: rt.InferenceSession = load_model(variant)
+    # join variant for cache key
+    model: rt.InferenceSession = load_model(version, variant)
     # load labels
-    labels: LabelData =
+    labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])
 
     # get input size and name
     _, h, w, _ = model.get_inputs()[0].shape
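The new mcut_threshold helper picks the midpoint of the largest gap between the sorted probabilities, following Largeron et al. (2012). A self-contained toy check with made-up confidence values:

import numpy as np

# made-up tag confidences: three strong tags, three weak ones
probs = np.array([0.95, 0.91, 0.88, 0.30, 0.22, 0.05])
sorted_probs = probs[probs.argsort()[::-1]]   # sort descending
diffs = sorted_probs[:-1] - sorted_probs[1:]  # gaps between neighbours
idx = diffs.argmax()                          # largest gap: 0.88 -> 0.30
thresh = (sorted_probs[idx] + sorted_probs[idx + 1]) / 2
print(thresh)  # 0.59, so only the three strong tags survive the cut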
@@ -85,13 +118,21 @@ def predict(
     rating_labels = dict([probs[i] for i in labels.rating])
 
     # General labels, pick any where prediction confidence > threshold
+    if gen_use_mcut:
+        gen_array = np.array([probs[i][1] for i in labels.general])
+        gen_threshold = mcut_threshold(gen_array)
+
     gen_labels = [probs[i] for i in labels.general]
-    gen_labels = dict([x for x in gen_labels if x[1] >
+    gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
     gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
 
     # Character labels, pick any where prediction confidence > threshold
+    if char_use_mcut:
+        char_array = np.array([probs[i][1] for i in labels.character])
+        char_threshold = round(mcut_threshold(char_array), 2)
+
     char_labels = [probs[i] for i in labels.character]
-    char_labels = dict([x for x in char_labels if x[1] >
+    char_labels = dict([x for x in char_labels if x[1] > char_threshold])
     char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
 
     # Combine general and character labels, sort by confidence
@@ -102,64 +143,129 @@ def predict(
     caption = ", ".join(combined_names)
     booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")
 
-    return image, caption, booru, rating_labels, char_labels, gen_labels
+    return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
 
 
+css = """
+#gen_mcut, #char_mcut {
+    padding-top: var(--scale-3);
+}
+#gen_threshold.dimmed, #char_threshold.dimmed {
+    filter: brightness(75%);
+}
+"""
+
+with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
     with gr.Row(equal_height=False):
-        with gr.Column():
+        with gr.Column(min_width=720):
+            with gr.Group():
+                img_input = gr.Image(
+                    label="Input",
+                    type="pil",
+                    image_mode="RGB",
+                    sources=["upload", "clipboard"],
+                )
+                show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
+            with gr.Row():
+                version = gr.Radio(
+                    choices=list(MODEL_VARIANTS.keys()),
+                    label="Model Version",
+                    value="v3",
+                    min_width=160,
+                    scale=1,
+                )  # gen_threshold > div.wrap.hide
+                variant = gr.Radio(
+                    choices=list(MODEL_VARIANTS[version.value].keys()),
+                    label="Model Variant",
+                    value="ConvNeXT",
+                    min_width=560,
+                )
+            with gr.Group():
+                with gr.Row():
+                    gen_threshold = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.35,
+                        step=0.01,
+                        label="General Tag Threshold",
+                        scale=5,
+                        elem_id="gen_threshold",
+                    )
+                    gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
+                with gr.Row():
+                    char_threshold = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.85,
+                        step=0.01,
+                        label="Character Tag Threshold",
+                        scale=5,
+                        elem_id="char_threshold",
+                    )
+                    char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
             with gr.Row():
-                submit = gr.Button(value="Submit", variant="primary", size="lg")
                 clear = gr.ClearButton(
                     components=[],
                     variant="secondary",
                     size="lg",
                 )
-                ],
-                inputs=[img_input, variant, gen_thresh, char_thresh],
-            )
-        with gr.Column():
-            img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False)
+                submit = gr.Button(value="Submit", variant="primary", size="lg")
+
+        with gr.Column(min_width=720):
+            img_output = gr.Image(
+                label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
+            )
             with gr.Group():
+                caption = gr.Textbox(label="Caption", show_copy_button=True)
+                tags = gr.Textbox(label="Tags", show_copy_button=True)
+            with gr.Group():
+                rating = gr.Label(label="Rating")
+            with gr.Group():
+                char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
+                character = gr.Label(label="Character")
+            with gr.Group():
+                gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
+                general = gr.Label(label="General")
+
+    with gr.Row():
+        examples = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]
+
+        examples = gr.Examples(
+            examples=examples,
+            inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
+        )
 
     # tell clear button which components to clear
-    clear.add([img_input, img_output,
+    clear.add([img_input, img_output, caption, rating, character, general])
+
+    def on_select_variant(evt: gr.SelectData, variant: str):
+        if evt.selected:
+            choices = list(MODEL_VARIANTS[variant])
+            return gr.update(choices=choices, value=choices[0])
+        return gr.update()
+
+    version.select(on_select_variant, inputs=[version], outputs=[variant])
 
     # show/hide processed image
-    def
-        return gr.update(visible=
+    def on_change_show(val: gr.Checkbox):
+        return gr.update(visible=val)
+
+    show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])
+
+    # handle mcut thresholding (auto-calculate threshold from probs, disable slider)
+    def on_change_mcut(val: gr.Checkbox):
+        return (
+            gr.update(interactive=not val, elem_classes=["dimmed"] if val else []),
+            gr.update(visible=val),
+        )
 
+    gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
+    char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])
 
     submit.click(
         predict,
-        inputs=[img_input, variant,
-        outputs=[img_output,
+        inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
+        outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
        api_name="predict",
     )
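Because submit.click registers api_name="predict", the endpoint is also callable outside the UI. A hedged sketch with gradio_client, assuming a recent client version; the Space id and image path are placeholders, and the positional arguments mirror the inputs list above:

from gradio_client import Client, handle_file

client = Client("user/wd-tagger-space")  # placeholder Space id
result = client.predict(
    handle_file("example.jpg"),  # img_input
    "v3",                        # version
    "ConvNeXT",                  # variant
    0.35,                        # gen_threshold
    False,                       # gen_mcut
    0.85,                        # char_threshold
    False,                       # char_mcut
    api_name="/predict",
)
print(result[1])  # the caption string, per the outputs list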
data/selected_tags.csv
DELETED
The diff for this file is too large to render; see the raw diff.
tagger/common.py
CHANGED
@@ -3,10 +3,12 @@ from dataclasses import asdict, dataclass
 from functools import lru_cache
 from os import PathLike
 from pathlib import Path
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import HfHubHTTPError
 from PIL import Image
 
 
@@ -36,10 +38,36 @@ class ImageLabels(DictJsonMixin):
 
 
 @lru_cache(maxsize=5)
-def load_labels(
+def load_labels(version: str = "v3", data_dir: PathLike = "./data") -> LabelData:
+    data_dir = Path(data_dir).resolve()
+    csv_path = data_dir.joinpath(f"selected_tags_{version}.csv")
     if not csv_path.is_file():
-        raise FileNotFoundError("
+        raise FileNotFoundError(f"{csv_path.name} not found in {data_dir}")
+
+    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
+    tag_data = LabelData(
+        names=df["name"].tolist(),
+        rating=list(np.where(df["category"] == 9)[0]),
+        general=list(np.where(df["category"] == 0)[0]),
+        character=list(np.where(df["category"] == 4)[0]),
+    )
+
+    return tag_data
+
+
+@lru_cache(maxsize=5)
+def load_labels_hf(
+    repo_id: str,
+    revision: Optional[str] = None,
+    token: Optional[str] = None,
+) -> LabelData:
+    try:
+        csv_path = hf_hub_download(
+            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
+        )
+        csv_path = Path(csv_path).resolve()
+    except HfHubHTTPError as e:
+        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
 
     df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
     tag_data = LabelData(
@@ -101,3 +129,27 @@ def preprocess_image(
     image.thumbnail(size_px, Image.BICUBIC)
 
     return image
+
+
+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
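A minimal usage sketch for the new load_labels_hf, assuming the Space's tagger package is on the import path and the model repo is public (pass token= otherwise):

from tagger.common import load_labels_hf

# downloads selected_tags.csv from the model repo and parses it into LabelData
labels = load_labels_hf("SmilingWolf/wd-vit-tagger-v3")
print(len(labels.names))     # total tag count from selected_tags.csv
print(labels.rating[:4])     # indices of rating tags (category 9)
print(labels.character[:4])  # indices of character tags (category 4)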