CodeChris committed on
Commit
09d71ca
·
verified ·
1 Parent(s): 8d7c593

Add tag string format presets and comma-sep option.

Browse files
Files changed (1) hide show
  1. app.py +423 -382
app.py CHANGED
@@ -1,383 +1,424 @@
1
- import argparse
2
- import gradio as gr
3
- import huggingface_hub
4
- import numpy as np
5
- import onnxruntime as rt
6
- import pandas as pd
7
- from PIL import Image
8
-
9
- TITLE = "Image Tagger"
10
- DESCRIPTION = "Modified from: [SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger)"
11
-
12
- # Dataset v3 series of models:
13
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
14
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
15
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
16
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
17
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
18
-
19
- # Dataset v2 series of models:
20
- # MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
21
- # SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
22
- # CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
23
- # CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
24
- # VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
25
-
26
- # Files to download from the repos
27
- MODEL_FILENAME = "model.onnx"
28
- LABEL_FILENAME = "selected_tags.csv"
29
-
30
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
31
- kaomojis = [
32
- "0_0",
33
- "(o)_(o)",
34
- "+_+",
35
- "+_-",
36
- "._.",
37
- "<o>_<o>",
38
- "<|>_<|>",
39
- "=_=",
40
- ">_<",
41
- "3_3",
42
- "6_9",
43
- ">_o",
44
- "@_@",
45
- "^_^",
46
- "o_o",
47
- "u_u",
48
- "x_x",
49
- "|_|",
50
- "||_||",
51
- ]
52
-
53
-
54
- def parse_args() -> argparse.Namespace:
55
- parser = argparse.ArgumentParser()
56
- parser.add_argument("--score-slider-step", type=float, default=0.05)
57
- parser.add_argument("--score-general-threshold", type=float, default=0.35)
58
- parser.add_argument("--score-character-threshold", type=float, default=0.80)
59
- parser.add_argument("--sort-tag-string-by-confidence", action="store_true")
60
- parser.add_argument("--share", action="store_true")
61
- return parser.parse_args()
62
-
63
-
64
- def load_labels(dataframe) -> list[str]:
65
- name_series = dataframe["name"]
66
- name_series = name_series.map(
67
- lambda x: x.replace("_", " ") if x not in kaomojis else x
68
- )
69
- tag_names = name_series.tolist()
70
-
71
- rating_indexes = list(np.where(dataframe["category"] == 9)[0])
72
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
73
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
74
- return tag_names, rating_indexes, general_indexes, character_indexes
75
-
76
-
77
- def mcut_threshold(probs):
78
- """
79
- Maximum Cut Thresholding (MCut)
80
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
81
- for Multi-label Classification. In 11th International Symposium, IDA 2012
82
- (pp. 172-183).
83
- """
84
- sorted_probs = probs[probs.argsort()[::-1]]
85
- difs = sorted_probs[:-1] - sorted_probs[1:]
86
- t = difs.argmax()
87
- thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
88
- return thresh
89
-
90
-
91
- class Predictor:
92
- def __init__(self):
93
- self.model_target_size = None
94
- self.last_loaded_repo = None
95
-
96
- def download_model(self, model_repo):
97
- csv_path = huggingface_hub.hf_hub_download(
98
- model_repo,
99
- LABEL_FILENAME,
100
- )
101
- model_path = huggingface_hub.hf_hub_download(
102
- model_repo,
103
- MODEL_FILENAME,
104
- )
105
- return csv_path, model_path
106
-
107
- def load_model(self, model_repo):
108
- if model_repo == self.last_loaded_repo:
109
- return
110
-
111
- csv_path, model_path = self.download_model(model_repo)
112
-
113
- tags_df = pd.read_csv(csv_path)
114
- sep_tags = load_labels(tags_df)
115
-
116
- self.tag_names = sep_tags[0]
117
- self.rating_indexes = sep_tags[1]
118
- self.general_indexes = sep_tags[2]
119
- self.character_indexes = sep_tags[3]
120
-
121
- model = rt.InferenceSession(model_path)
122
- _, height, width, _ = model.get_inputs()[0].shape
123
- self.model_target_size = height
124
-
125
- self.last_loaded_repo = model_repo
126
- self.model = model
127
-
128
- def prepare_image(self, image):
129
- target_size = self.model_target_size
130
-
131
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
132
- canvas.alpha_composite(image)
133
- image = canvas.convert("RGB")
134
-
135
- # Pad image to square
136
- image_shape = image.size
137
- max_dim = max(image_shape)
138
- pad_left = (max_dim - image_shape[0]) // 2
139
- pad_top = (max_dim - image_shape[1]) // 2
140
-
141
- padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
142
- padded_image.paste(image, (pad_left, pad_top))
143
-
144
- # Resize
145
- if max_dim != target_size:
146
- padded_image = padded_image.resize(
147
- (target_size, target_size),
148
- Image.BICUBIC,
149
- )
150
-
151
- # Convert to numpy array
152
- image_array = np.asarray(padded_image, dtype=np.float32)
153
-
154
- # Convert PIL-native RGB to BGR
155
- image_array = image_array[:, :, ::-1]
156
-
157
- return np.expand_dims(image_array, axis=0)
158
-
159
- def tag_dict_to_sorted_string(self, dict_res: dict, sort_by_confidence, descending,
160
- remove_underlines, escape_parens):
161
- """Custom function: Sort tag dict by confidence/alphabetically"""
162
- if sort_by_confidence:
163
- _sorted_list = sorted(
164
- dict_res.items(),
165
- key=lambda x: x[1],
166
- reverse=descending
167
- )
168
- else:
169
- _sorted_list = sorted(
170
- dict_res.items(),
171
- reverse=descending
172
- )
173
- if remove_underlines:
174
- _sorted_string = ", ".join([x[0] for x in _sorted_list])
175
- else: # Add back underlines
176
- _sorted_string = ", ".join([x[0].replace(" ", "_") for x in _sorted_list])
177
- if escape_parens:
178
- _sorted_string = _sorted_string.replace("(", "\\(").replace(")", "\\)")
179
- return _sorted_string
180
-
181
- def predict(
182
- self,
183
- image,
184
- model_repo,
185
- general_thresh,
186
- general_mcut_enabled,
187
- character_thresh,
188
- character_mcut_enabled,
189
- sort_by_confidence_enabled,
190
- sort_descending_enabled,
191
- remove_underline_enabled,
192
- escape_parens_enabled
193
- ):
194
- self.load_model(model_repo)
195
-
196
- image = self.prepare_image(image)
197
-
198
- input_name = self.model.get_inputs()[0].name
199
- label_name = self.model.get_outputs()[0].name
200
- preds = self.model.run([label_name], {input_name: image})[0]
201
-
202
- labels = list(zip(self.tag_names, preds[0].astype(float)))
203
-
204
- # First 4 labels are actually ratings: pick one with argmax
205
- ratings_names = [labels[i] for i in self.rating_indexes]
206
- rating = dict(ratings_names)
207
-
208
- # Then we have general tags: pick any where prediction confidence > threshold
209
- general_names = [labels[i] for i in self.general_indexes]
210
-
211
- if general_mcut_enabled:
212
- general_probs = np.array([x[1] for x in general_names])
213
- general_thresh = mcut_threshold(general_probs)
214
-
215
- general_res = [x for x in general_names if x[1] > general_thresh]
216
- general_res = dict(general_res)
217
-
218
- # Everything else is characters: pick any where prediction confidence > threshold
219
- character_names = [labels[i] for i in self.character_indexes]
220
-
221
- if character_mcut_enabled:
222
- character_probs = np.array([x[1] for x in character_names])
223
- character_thresh = mcut_threshold(character_probs)
224
- character_thresh = max(0.15, character_thresh)
225
-
226
- character_res = [x for x in character_names if x[1] > character_thresh]
227
- character_res = dict(character_res)
228
-
229
- sorted_general_strings = self.tag_dict_to_sorted_string(
230
- general_res,
231
- sort_by_confidence=sort_by_confidence_enabled,
232
- descending=sort_descending_enabled,
233
- remove_underlines=remove_underline_enabled,
234
- escape_parens=escape_parens_enabled
235
- )
236
- sorted_character_strings = self.tag_dict_to_sorted_string(
237
- character_res,
238
- sort_by_confidence=sort_by_confidence_enabled,
239
- descending=sort_descending_enabled,
240
- remove_underlines=remove_underline_enabled,
241
- escape_parens=escape_parens_enabled
242
- )
243
-
244
- return sorted_general_strings, sorted_character_strings, rating, character_res, general_res
245
-
246
-
247
- def main():
248
- args = parse_args()
249
-
250
- predictor = Predictor()
251
-
252
- dropdown_list = [
253
- SWINV2_MODEL_DSV3_REPO,
254
- CONV_MODEL_DSV3_REPO,
255
- VIT_MODEL_DSV3_REPO,
256
- VIT_LARGE_MODEL_DSV3_REPO,
257
- EVA02_LARGE_MODEL_DSV3_REPO,
258
- # MOAT_MODEL_DSV2_REPO,
259
- # SWIN_MODEL_DSV2_REPO,
260
- # CONV_MODEL_DSV2_REPO,
261
- # CONV2_MODEL_DSV2_REPO,
262
- # VIT_MODEL_DSV2_REPO,
263
- ]
264
-
265
- with gr.Blocks(title=TITLE, theme=gr.themes.Soft(primary_hue="teal")) as demo:
266
- with gr.Column():
267
- gr.Markdown(
268
- value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
269
- )
270
- gr.Markdown(value=DESCRIPTION)
271
- with gr.Row():
272
- with gr.Column(variant="panel"):
273
- image = gr.Image(type="pil", image_mode="RGBA", label="Input")
274
- model_repo = gr.Dropdown(
275
- dropdown_list,
276
- value=SWINV2_MODEL_DSV3_REPO,
277
- label="Model",
278
- )
279
- with gr.Row():
280
- general_thresh = gr.Slider(
281
- 0,
282
- 1,
283
- step=args.score_slider_step,
284
- value=args.score_general_threshold,
285
- label="General Tags Threshold",
286
- scale=3,
287
- )
288
- general_mcut_enabled = gr.Checkbox(
289
- value=False,
290
- label="Use MCut threshold",
291
- scale=1,
292
- )
293
- with gr.Row():
294
- character_thresh = gr.Slider(
295
- 0,
296
- 1,
297
- step=args.score_slider_step,
298
- value=args.score_character_threshold,
299
- label="Character Tags Threshold",
300
- scale=3,
301
- )
302
- character_mcut_enabled = gr.Checkbox(
303
- value=False,
304
- label="Use MCut threshold",
305
- scale=1,
306
- )
307
- with gr.Row():
308
- clear = gr.ClearButton(
309
- components=[
310
- image,
311
- model_repo,
312
- general_thresh,
313
- general_mcut_enabled,
314
- character_thresh,
315
- character_mcut_enabled,
316
- ],
317
- variant="secondary",
318
- size="lg",
319
- )
320
- submit = gr.Button(value="Submit", variant="primary", size="lg")
321
- with gr.Column(variant="panel"):
322
- with gr.Row():
323
- sort_by_confidence_enabled = gr.Checkbox(
324
- value=True if args.sort_tag_string_by_confidence else False,
325
- label="Sort By Confidence"
326
- )
327
- sort_descending_enabled = gr.Checkbox(
328
- value=False,
329
- label="Descending"
330
- )
331
- with gr.Row():
332
- remove_underline_enabled = gr.Checkbox(
333
- value=True,
334
- label="Remove Tag Underlines"
335
- )
336
- escape_parens_enabled = gr.Checkbox(
337
- value=False,
338
- label="Escape Parens"
339
- )
340
- sorted_general_strings = gr.Textbox(
341
- label="Output (string)",
342
- show_copy_button=True
343
- )
344
- sorted_character_strings = gr.Textbox(
345
- label="Characters (string)",
346
- show_copy_button=True
347
- )
348
- rating = gr.Label(label="Rating")
349
- character_res = gr.Label(label="Output (characters)")
350
- general_res = gr.Label(label="Output (tags)")
351
- clear.add(
352
- [
353
- sorted_general_strings,
354
- rating,
355
- character_res,
356
- general_res,
357
- ]
358
- )
359
-
360
- submit.click(
361
- predictor.predict,
362
- inputs=[
363
- image,
364
- model_repo,
365
- general_thresh,
366
- general_mcut_enabled,
367
- character_thresh,
368
- character_mcut_enabled,
369
- sort_by_confidence_enabled,
370
- sort_descending_enabled,
371
- remove_underline_enabled,
372
- escape_parens_enabled
373
- ],
374
- outputs=[sorted_general_strings, sorted_character_strings,
375
- rating, character_res, general_res],
376
- )
377
-
378
- demo.queue(max_size=10)
379
- demo.launch(share=args.share)
380
-
381
-
382
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  main()
 
1
+ import argparse
2
+ import gradio as gr
3
+ import huggingface_hub
4
+ import numpy as np
5
+ import onnxruntime as rt
6
+ import pandas as pd
7
+ from PIL import Image
8
+
9
# Text shown at the top of the web page.
TITLE = "Image Tagger"
DESCRIPTION = (
    "Modified from: "
    "[SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger)"
    " (8279aed)"
)

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models (currently disabled):
# MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
# SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
# CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
# CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
# VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Names of the files fetched from each model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tag names that keep their underscores when labels are loaded.
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
52
+
53
+
54
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options that configure the demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            exactly what a zero-argument call did before this parameter
            existed.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold``, ``sort_tag_string_by_confidence``
        and ``share``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--score-slider-step",
        type=float,
        default=0.05,
        help="step size of the two threshold sliders",
    )
    parser.add_argument(
        "--score-general-threshold",
        type=float,
        default=0.35,
        help="initial value of the general-tags threshold slider",
    )
    parser.add_argument(
        "--score-character-threshold",
        type=float,
        default=0.80,
        help="initial value of the character-tags threshold slider",
    )
    parser.add_argument(
        "--sort-tag-string-by-confidence",
        action="store_true",
        help="tick the 'Sort By Confidence' checkbox at start-up",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="forwarded to demo.launch(share=...)",
    )
    return parser.parse_args(argv)
62
+
63
+
64
def load_labels(dataframe) -> tuple[list[str], list, list, list]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column (tag strings) and a
            ``category`` column (integer category codes).

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` has one entry per row; the three
        index lists hold the row positions whose category is 9, 0 and 4
        respectively.
    """
    # Underscores become spaces, except in kaomoji tags where the underscore
    # is part of the emoticon itself.
    tag_names = [
        name if name in kaomojis else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    categories = dataframe["category"]
    rating_indexes = list(np.where(categories == 9)[0])
    general_indexes = list(np.where(categories == 0)[0])
    character_indexes = list(np.where(categories == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
75
+
76
+
77
def mcut_threshold(probs):
    """Compute the Maximum Cut (MCut) threshold for a set of scores.

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are sorted in descending order, the largest gap between two
    neighbouring scores is located, and the midpoint of that gap is returned.

    Args:
        probs: 1-D array-like of scores (a NumPy array or any sequence that
            ``np.asarray`` accepts).

    Returns:
        The midpoint between the two neighbouring scores with the largest gap.

    Raises:
        ValueError: If fewer than two scores are given, because no gap exists
            to cut at.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        # Without this guard the argmax below fails on an empty difference
        # array with a message that does not point at the real cause.
        raise ValueError(
            "mcut_threshold needs at least two scores, got %d" % probs.size
        )

    sorted_probs = np.sort(probs)[::-1]
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    cut = gaps.argmax()
    return (sorted_probs[cut] + sorted_probs[cut + 1]) / 2
89
+
90
+
91
class Predictor:
    """Loads a tagger model on demand and turns an image into tag predictions.

    The most recently loaded model and its label table are kept on the
    instance, so consecutive predictions with the same repository skip the
    download and session set-up.
    """

    def __init__(self):
        # Every attribute that load_model fills in is declared here so the
        # object is fully formed before the first prediction.
        self.model = None
        self.model_target_size = None
        self.last_loaded_repo = None
        self.tag_names = []
        self.rating_indexes = []
        self.general_indexes = []
        self.character_indexes = []

    def download_model(self, model_repo):
        """Fetch the label CSV and the ONNX model of ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by
            ``huggingface_hub.hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model; a no-op if it already is."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        # Build everything in locals first. If any step raises, the instance
        # keeps its previous, self-consistent model/label pair instead of
        # ending up with new labels attached to the old model.
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            pd.read_csv(csv_path)
        )
        model = rt.InferenceSession(model_path)
        # Only the second entry of the 4-element input shape is used as the
        # target size (presumably a square NHWC input - TODO confirm).
        _, height, _, _ = model.get_inputs()[0].shape

        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into the float32 batch array fed to the model.

        The image is flattened onto a white background, padded to a square,
        resized to the model's target size and reordered from RGB to BGR.

        Returns:
            Array with a leading batch axis of length 1.
        """
        target_size = self.model_target_size

        # alpha_composite only accepts RGBA images.
        if image.mode != "RGBA":
            image = image.convert("RGBA")

        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square, centring the original on a white canvas.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def tag_dict_to_sorted_string(self, dict_res: dict, sort_by_confidence, descending,
                                  remove_underlines, escape_parens, comma_sep):
        """Render a ``{tag: confidence}`` dict as a single tag string.

        Args:
            dict_res: Mapping of tag name to confidence.
            sort_by_confidence: Sort by confidence if true, otherwise
                alphabetically by tag name.
            descending: Reverse the chosen sort order.
            remove_underlines: Keep the spaces in tag names if true; otherwise
                replace every space with an underscore.
            escape_parens: Prefix ``(`` and ``)`` with a backslash.
            comma_sep: Join tags with ``", "`` if true, otherwise with a
                single space.

        Returns:
            The formatted tag string (empty for an empty dict).
        """
        separator = ", " if comma_sep else " "

        if sort_by_confidence:
            ordered = sorted(
                dict_res.items(),
                key=lambda item: item[1],
                reverse=descending,
            )
        else:
            ordered = sorted(dict_res.items(), reverse=descending)

        names = [name for name, _ in ordered]
        if not remove_underlines:
            # Add back underlines
            names = [name.replace(" ", "_") for name in names]

        tag_string = separator.join(names)
        if escape_parens:
            tag_string = tag_string.replace("(", "\\(").replace(")", "\\)")
        return tag_string

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        sort_by_confidence_enabled,
        sort_descending_enabled,
        preset_checkboxgroup
    ):
        """Tag ``image`` with the model from ``model_repo``.

        Args:
            image: PIL image to tag.
            model_repo: Repository id of the model to use.
            general_thresh: Minimum confidence for general tags; replaced by
                the MCut threshold when ``general_mcut_enabled`` is true.
            general_mcut_enabled: Use MCut for the general tags.
            character_thresh: Minimum confidence for character tags; replaced
                by ``max(0.15, MCut)`` when ``character_mcut_enabled`` is true.
            character_mcut_enabled: Use MCut for the character tags.
            sort_by_confidence_enabled: Passed to tag_dict_to_sorted_string.
            sort_descending_enabled: Passed to tag_dict_to_sorted_string.
            preset_checkboxgroup: Indices of the ticked formatting options
                (0 = remove underlines, 1 = escape parens, 2 = comma
                separator).

        Returns:
            ``(general_string, character_string, rating, character_res,
            general_res)`` where the last three are ``{name: confidence}``
            dicts.

        Raises:
            gr.Error: If no image was provided.
        """
        if image is None:
            # Fail with a readable message instead of an AttributeError from
            # deep inside prepare_image.
            raise gr.Error("Please provide an image before submitting.")

        # Decouple the checkgroup status into 3
        selected = set(preset_checkboxgroup or ())
        remove_underline_enabled, escape_parens_enabled, comma_sep_enabled = [
            i in selected for i in range(3)
        ]

        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings: every rating label is returned with its score.
        rating = dict(labels[i] for i in self.rating_indexes)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([score for _, score in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = {
            name: score for name, score in general_names if score > general_thresh
        }

        # Character tags: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([score for _, score in character_names])
            character_thresh = mcut_threshold(character_probs)
            # Never let MCut push the character threshold below 0.15.
            character_thresh = max(0.15, character_thresh)

        character_res = {
            name: score for name, score in character_names if score > character_thresh
        }

        sorted_general_strings = self.tag_dict_to_sorted_string(
            general_res,
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            remove_underlines=remove_underline_enabled,
            escape_parens=escape_parens_enabled,
            comma_sep=comma_sep_enabled
        )
        sorted_character_strings = self.tag_dict_to_sorted_string(
            character_res,
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            remove_underlines=remove_underline_enabled,
            escape_parens=escape_parens_enabled,
            comma_sep=comma_sep_enabled
        )

        return sorted_general_strings, sorted_character_strings, rating, character_res, general_res
253
+
254
+
255
def main():
    """Build the Gradio interface, wire it to a Predictor and launch it.

    Command-line options from ``parse_args`` supply the slider step, the
    initial thresholds, the initial state of the "Sort By Confidence"
    checkbox and the ``share`` flag passed to ``demo.launch``.
    """
    args = parse_args()

    predictor = Predictor()

    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # MOAT_MODEL_DSV2_REPO,
        # SWIN_MODEL_DSV2_REPO,
        # CONV_MODEL_DSV2_REPO,
        # CONV2_MODEL_DSV2_REPO,
        # VIT_MODEL_DSV2_REPO,
    ]

    # Define widget update functions

    # The position of each label is significant: the checkbox group reports
    # indices, and Predictor.predict reads index 0/1/2 as remove-underlines /
    # escape-parens / comma-separator.
    PRESET_CHECKBOX_CHOICES = ["Remove Underlines", "Escape Parens", "Comma Separator"]
    PRESET_CHECKBOX_DICT = {
        "Normal": [PRESET_CHECKBOX_CHOICES[i] for i in [0, 2]],
        "Booru": [],
    }

    def update_preset_checkboxes(preset_radio, preset_checkbox_indices):
        """Change checkboxgroup according to the radio selected preset.

        Presets without an entry in PRESET_CHECKBOX_DICT (i.e. "Custom")
        leave the current selection untouched.
        """
        current_checks = [
            PRESET_CHECKBOX_CHOICES[i] for i in (preset_checkbox_indices or [])
        ]
        return PRESET_CHECKBOX_DICT.get(preset_radio, current_checks)

    def update_tag_preset():
        """Whenever the checkboxgroup is manually changed, set preset to 'Custom'."""
        return "Custom"

    with gr.Blocks(title=TITLE, theme=gr.themes.Soft(primary_hue="teal")) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
        with gr.Row():
            # Left panel: inputs.
            with gr.Column(variant="panel"):
                submit = gr.Button(value="Submit", variant="primary")
                image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                model_repo = gr.Dropdown(
                    dropdown_list,
                    value=SWINV2_MODEL_DSV3_REPO,
                    label="Model",
                )
                with gr.Row():
                    general_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_general_threshold,
                        label="General Tags Threshold",
                        scale=3,
                    )
                    general_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                with gr.Row():
                    character_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_character_threshold,
                        label="Character Tags Threshold",
                        scale=3,
                    )
                    character_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                with gr.Row():
                    clear = gr.ClearButton(
                        components=[
                            image,
                            model_repo,
                            general_thresh,
                            general_mcut_enabled,
                            character_thresh,
                            character_mcut_enabled,
                        ],
                        variant="secondary"
                    )
            # Right panel: formatting options and outputs.
            with gr.Column(variant="panel"):
                default_tag_preset = "Normal"
                with gr.Row():
                    tag_format_preset = gr.Radio(
                        ["Normal", "Booru", "Custom"],
                        value=default_tag_preset,
                        label="Tagging Format Presets"
                    )
                with gr.Row():
                    preset_checkboxgroup = gr.CheckboxGroup(
                        choices=PRESET_CHECKBOX_CHOICES,
                        value=PRESET_CHECKBOX_DICT[default_tag_preset],
                        type='index',
                        show_label=False
                    )

                with gr.Row():
                    sort_by_confidence_enabled = gr.Checkbox(
                        value=bool(args.sort_tag_string_by_confidence),
                        label="Sort By Confidence"
                    )
                    sort_descending_enabled = gr.Checkbox(
                        value=False,
                        label="Descending"
                    )
                sorted_general_strings = gr.Textbox(
                    label="Output (string)",
                    show_copy_button=True
                )
                sorted_character_strings = gr.Textbox(
                    label="Characters (string)",
                    show_copy_button=True
                )
                rating = gr.Label(label="Rating")
                character_res = gr.Label(label="Output (characters)")
                general_res = gr.Label(label="Output (tags)")
                # Every output widget is registered with the Clear button,
                # including the character string, which was previously left
                # behind after a clear.
                clear.add(
                    [
                        sorted_general_strings,
                        sorted_character_strings,
                        rating,
                        character_res,
                        general_res,
                    ]
                )

        # Update gradio widgets
        tag_format_preset.change(
            fn=update_preset_checkboxes,
            inputs=[tag_format_preset, preset_checkboxgroup],
            outputs=preset_checkboxgroup
        )
        preset_checkboxgroup.input(
            fn=update_tag_preset,
            outputs=tag_format_preset
        )

        # The order of inputs/outputs must match the parameter order and the
        # return tuple of Predictor.predict.
        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                sort_by_confidence_enabled,
                sort_descending_enabled,
                preset_checkboxgroup
            ],
            outputs=[sorted_general_strings, sorted_character_strings,
                     rating, character_res, general_res],
        )

    demo.queue(max_size=10)
    demo.launch(share=args.share)
421
+
422
+
423
# Script entry point: start the app only when this file is run directly.
if __name__ == "__main__":
    main()