CodeChris committed on
Commit
c348d43
·
verified ·
1 Parent(s): 53d40a7

Customize app.py:

Browse files

* Apply gradio theme
* Remove v2 models; remove tagging example image
* Add function tag_dict_to_sorted_string() for reuse
* Add Textbox: sorted_character_strings
* Add Checkboxes: sort_by_confidence, sort_descending, escape_parens
* Add a copy button to tag outputs (`sorted_general_strings`)
* Adjust default tag & character threshold
* Edit Title and Description

Files changed (1) hide show
  1. app.py +372 -339
app.py CHANGED
@@ -1,339 +1,372 @@
1
- import argparse
2
- import os
3
-
4
- import gradio as gr
5
- import huggingface_hub
6
- import numpy as np
7
- import onnxruntime as rt
8
- import pandas as pd
9
- from PIL import Image
10
-
11
- TITLE = "WaifuDiffusion Tagger"
12
- DESCRIPTION = """
13
- Demo for the WaifuDiffusion tagger models
14
-
15
- Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
16
- """
17
-
18
- # Dataset v3 series of models:
19
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
21
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
22
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
23
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
24
-
25
- # Dataset v2 series of models:
26
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
27
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
28
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
-
32
- # Files to download from the repos
33
- MODEL_FILENAME = "model.onnx"
34
- LABEL_FILENAME = "selected_tags.csv"
35
-
36
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
37
- kaomojis = [
38
- "0_0",
39
- "(o)_(o)",
40
- "+_+",
41
- "+_-",
42
- "._.",
43
- "<o>_<o>",
44
- "<|>_<|>",
45
- "=_=",
46
- ">_<",
47
- "3_3",
48
- "6_9",
49
- ">_o",
50
- "@_@",
51
- "^_^",
52
- "o_o",
53
- "u_u",
54
- "x_x",
55
- "|_|",
56
- "||_||",
57
- ]
58
-
59
-
60
- def parse_args() -> argparse.Namespace:
61
- parser = argparse.ArgumentParser()
62
- parser.add_argument("--score-slider-step", type=float, default=0.05)
63
- parser.add_argument("--score-general-threshold", type=float, default=0.35)
64
- parser.add_argument("--score-character-threshold", type=float, default=0.85)
65
- parser.add_argument("--share", action="store_true")
66
- return parser.parse_args()
67
-
68
-
69
- def load_labels(dataframe) -> list[str]:
70
- name_series = dataframe["name"]
71
- name_series = name_series.map(
72
- lambda x: x.replace("_", " ") if x not in kaomojis else x
73
- )
74
- tag_names = name_series.tolist()
75
-
76
- rating_indexes = list(np.where(dataframe["category"] == 9)[0])
77
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
78
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
79
- return tag_names, rating_indexes, general_indexes, character_indexes
80
-
81
-
82
- def mcut_threshold(probs):
83
- """
84
- Maximum Cut Thresholding (MCut)
85
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
86
- for Multi-label Classification. In 11th International Symposium, IDA 2012
87
- (pp. 172-183).
88
- """
89
- sorted_probs = probs[probs.argsort()[::-1]]
90
- difs = sorted_probs[:-1] - sorted_probs[1:]
91
- t = difs.argmax()
92
- thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
93
- return thresh
94
-
95
-
96
- class Predictor:
97
- def __init__(self):
98
- self.model_target_size = None
99
- self.last_loaded_repo = None
100
-
101
- def download_model(self, model_repo):
102
- csv_path = huggingface_hub.hf_hub_download(
103
- model_repo,
104
- LABEL_FILENAME,
105
- )
106
- model_path = huggingface_hub.hf_hub_download(
107
- model_repo,
108
- MODEL_FILENAME,
109
- )
110
- return csv_path, model_path
111
-
112
- def load_model(self, model_repo):
113
- if model_repo == self.last_loaded_repo:
114
- return
115
-
116
- csv_path, model_path = self.download_model(model_repo)
117
-
118
- tags_df = pd.read_csv(csv_path)
119
- sep_tags = load_labels(tags_df)
120
-
121
- self.tag_names = sep_tags[0]
122
- self.rating_indexes = sep_tags[1]
123
- self.general_indexes = sep_tags[2]
124
- self.character_indexes = sep_tags[3]
125
-
126
- model = rt.InferenceSession(model_path)
127
- _, height, width, _ = model.get_inputs()[0].shape
128
- self.model_target_size = height
129
-
130
- self.last_loaded_repo = model_repo
131
- self.model = model
132
-
133
- def prepare_image(self, image):
134
- target_size = self.model_target_size
135
-
136
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
137
- canvas.alpha_composite(image)
138
- image = canvas.convert("RGB")
139
-
140
- # Pad image to square
141
- image_shape = image.size
142
- max_dim = max(image_shape)
143
- pad_left = (max_dim - image_shape[0]) // 2
144
- pad_top = (max_dim - image_shape[1]) // 2
145
-
146
- padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
147
- padded_image.paste(image, (pad_left, pad_top))
148
-
149
- # Resize
150
- if max_dim != target_size:
151
- padded_image = padded_image.resize(
152
- (target_size, target_size),
153
- Image.BICUBIC,
154
- )
155
-
156
- # Convert to numpy array
157
- image_array = np.asarray(padded_image, dtype=np.float32)
158
-
159
- # Convert PIL-native RGB to BGR
160
- image_array = image_array[:, :, ::-1]
161
-
162
- return np.expand_dims(image_array, axis=0)
163
-
164
- def predict(
165
- self,
166
- image,
167
- model_repo,
168
- general_thresh,
169
- general_mcut_enabled,
170
- character_thresh,
171
- character_mcut_enabled,
172
- ):
173
- self.load_model(model_repo)
174
-
175
- image = self.prepare_image(image)
176
-
177
- input_name = self.model.get_inputs()[0].name
178
- label_name = self.model.get_outputs()[0].name
179
- preds = self.model.run([label_name], {input_name: image})[0]
180
-
181
- labels = list(zip(self.tag_names, preds[0].astype(float)))
182
-
183
- # First 4 labels are actually ratings: pick one with argmax
184
- ratings_names = [labels[i] for i in self.rating_indexes]
185
- rating = dict(ratings_names)
186
-
187
- # Then we have general tags: pick any where prediction confidence > threshold
188
- general_names = [labels[i] for i in self.general_indexes]
189
-
190
- if general_mcut_enabled:
191
- general_probs = np.array([x[1] for x in general_names])
192
- general_thresh = mcut_threshold(general_probs)
193
-
194
- general_res = [x for x in general_names if x[1] > general_thresh]
195
- general_res = dict(general_res)
196
-
197
- # Everything else is characters: pick any where prediction confidence > threshold
198
- character_names = [labels[i] for i in self.character_indexes]
199
-
200
- if character_mcut_enabled:
201
- character_probs = np.array([x[1] for x in character_names])
202
- character_thresh = mcut_threshold(character_probs)
203
- character_thresh = max(0.15, character_thresh)
204
-
205
- character_res = [x for x in character_names if x[1] > character_thresh]
206
- character_res = dict(character_res)
207
-
208
- sorted_general_strings = sorted(
209
- general_res.items(),
210
- key=lambda x: x[1],
211
- reverse=True,
212
- )
213
- sorted_general_strings = [x[0] for x in sorted_general_strings]
214
- sorted_general_strings = (
215
- ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
216
- )
217
-
218
- return sorted_general_strings, rating, character_res, general_res
219
-
220
-
221
- def main():
222
- args = parse_args()
223
-
224
- predictor = Predictor()
225
-
226
- dropdown_list = [
227
- SWINV2_MODEL_DSV3_REPO,
228
- CONV_MODEL_DSV3_REPO,
229
- VIT_MODEL_DSV3_REPO,
230
- VIT_LARGE_MODEL_DSV3_REPO,
231
- EVA02_LARGE_MODEL_DSV3_REPO,
232
- MOAT_MODEL_DSV2_REPO,
233
- SWIN_MODEL_DSV2_REPO,
234
- CONV_MODEL_DSV2_REPO,
235
- CONV2_MODEL_DSV2_REPO,
236
- VIT_MODEL_DSV2_REPO,
237
- ]
238
-
239
- with gr.Blocks(title=TITLE) as demo:
240
- with gr.Column():
241
- gr.Markdown(
242
- value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
243
- )
244
- gr.Markdown(value=DESCRIPTION)
245
- with gr.Row():
246
- with gr.Column(variant="panel"):
247
- image = gr.Image(type="pil", image_mode="RGBA", label="Input")
248
- model_repo = gr.Dropdown(
249
- dropdown_list,
250
- value=SWINV2_MODEL_DSV3_REPO,
251
- label="Model",
252
- )
253
- with gr.Row():
254
- general_thresh = gr.Slider(
255
- 0,
256
- 1,
257
- step=args.score_slider_step,
258
- value=args.score_general_threshold,
259
- label="General Tags Threshold",
260
- scale=3,
261
- )
262
- general_mcut_enabled = gr.Checkbox(
263
- value=False,
264
- label="Use MCut threshold",
265
- scale=1,
266
- )
267
- with gr.Row():
268
- character_thresh = gr.Slider(
269
- 0,
270
- 1,
271
- step=args.score_slider_step,
272
- value=args.score_character_threshold,
273
- label="Character Tags Threshold",
274
- scale=3,
275
- )
276
- character_mcut_enabled = gr.Checkbox(
277
- value=False,
278
- label="Use MCut threshold",
279
- scale=1,
280
- )
281
- with gr.Row():
282
- clear = gr.ClearButton(
283
- components=[
284
- image,
285
- model_repo,
286
- general_thresh,
287
- general_mcut_enabled,
288
- character_thresh,
289
- character_mcut_enabled,
290
- ],
291
- variant="secondary",
292
- size="lg",
293
- )
294
- submit = gr.Button(value="Submit", variant="primary", size="lg")
295
- with gr.Column(variant="panel"):
296
- sorted_general_strings = gr.Textbox(label="Output (string)")
297
- rating = gr.Label(label="Rating")
298
- character_res = gr.Label(label="Output (characters)")
299
- general_res = gr.Label(label="Output (tags)")
300
- clear.add(
301
- [
302
- sorted_general_strings,
303
- rating,
304
- character_res,
305
- general_res,
306
- ]
307
- )
308
-
309
- submit.click(
310
- predictor.predict,
311
- inputs=[
312
- image,
313
- model_repo,
314
- general_thresh,
315
- general_mcut_enabled,
316
- character_thresh,
317
- character_mcut_enabled,
318
- ],
319
- outputs=[sorted_general_strings, rating, character_res, general_res],
320
- )
321
-
322
- gr.Examples(
323
- [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
324
- inputs=[
325
- image,
326
- model_repo,
327
- general_thresh,
328
- general_mcut_enabled,
329
- character_thresh,
330
- character_mcut_enabled,
331
- ],
332
- )
333
-
334
- demo.queue(max_size=10)
335
- demo.launch()
336
-
337
-
338
- if __name__ == "__main__":
339
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import gradio as gr
5
+ import huggingface_hub
6
+ import numpy as np
7
+ import onnxruntime as rt
8
+ import pandas as pd
9
+ from PIL import Image
10
+
11
+ TITLE = "Image Tagger"
12
+ DESCRIPTION = "Modified from: [SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger)"
13
+
14
+ # Dataset v3 series of models:
15
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
16
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
17
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
18
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
19
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
20
+
21
+ # Dataset v2 series of models:
22
+ # MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
23
+ # SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
24
+ # CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
25
+ # CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
26
+ # VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
27
+
28
+ # Files to download from the repos
29
+ MODEL_FILENAME = "model.onnx"
30
+ LABEL_FILENAME = "selected_tags.csv"
31
+
32
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
33
+ kaomojis = [
34
+ "0_0",
35
+ "(o)_(o)",
36
+ "+_+",
37
+ "+_-",
38
+ "._.",
39
+ "<o>_<o>",
40
+ "<|>_<|>",
41
+ "=_=",
42
+ ">_<",
43
+ "3_3",
44
+ "6_9",
45
+ ">_o",
46
+ "@_@",
47
+ "^_^",
48
+ "o_o",
49
+ "u_u",
50
+ "x_x",
51
+ "|_|",
52
+ "||_||",
53
+ ]
54
+
55
+
56
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options that configure the demo UI.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            pass nothing behave exactly as before.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold``, ``sort_tag_string_by_confidence`` and
        ``share``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.80)
    parser.add_argument("--sort-tag-string-by-confidence", action="store_true")
    parser.add_argument("--share", action="store_true")
    return parser.parse_args(argv)
64
+
65
+
66
def load_labels(dataframe: pd.DataFrame) -> tuple[list[str], list, list, list]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column and a ``category`` column.

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` holds every tag name with
        underscores replaced by spaces, except names listed in ``kaomojis``,
        which are kept verbatim. The three index lists hold the row positions
        (NumPy integers) whose ``category`` equals 9, 0 and 4 respectively.
    """
    # Underscores are part of the kaomoji itself, so those names are left alone.
    tag_names = (
        dataframe["name"]
        .map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
        .tolist()
    )

    category = dataframe["category"]
    rating_indexes = list(np.where(category == 9)[0])
    general_indexes = list(np.where(category == 0)[0])
    character_indexes = list(np.where(category == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
77
+
78
+
79
def mcut_threshold(probs):
    """Compute the Maximum Cut (MCut) threshold for a set of scores.

    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are sorted in descending order, the largest gap between two
    neighbouring scores is located, and the midpoint of that gap is returned.

    Args:
        probs: One-dimensional array-like of scores (at least two values).

    Returns:
        The midpoint between the two adjacent sorted scores with the widest gap.

    Raises:
        ValueError: If fewer than two scores are supplied, since no gap exists.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        # Previously this surfaced as an opaque "argmax of an empty sequence".
        raise ValueError("mcut_threshold requires at least two probabilities")

    sorted_probs = np.sort(probs)[::-1]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2
91
+
92
+
93
class Predictor:
    """Loads a tagger ONNX model on demand and predicts tags for an image.

    The most recently loaded repository is remembered so that repeated
    predictions with the same model skip the download/load step.
    """

    def __init__(self):
        # Both are filled in by load_model() on first use.
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model file for ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by
            ``huggingface_hub.hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load labels and the inference session for ``model_repo``.

        Does nothing when ``model_repo`` is already the loaded repository.
        """
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            tags_df
        )

        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as (_, height, width, _); only the height
        # is kept and used as the square target size.
        _, height, _, _ = model.get_inputs()[0].shape

        # Commit state only after everything loaded, so a failure above cannot
        # leave the labels of one model paired with the session of another.
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into the float32 batch array fed to the model.

        The image is composited onto a white canvas, padded to a square with
        white, resized to ``self.model_target_size`` and returned with a
        leading batch axis and reversed channel order.

        NOTE(review): ``alpha_composite`` presumably requires ``image`` to be
        RGBA; the UI requests ``image_mode="RGBA"`` - confirm for other callers.
        """
        target_size = self.model_target_size

        # Flatten transparency onto white.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square, centring the original.
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize only when the padded square is not already the target size.
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def tag_dict_to_sorted_string(
        self, dict_res: dict, sort_by_confidence, descending, escape_parens
    ):
        """Join the tag names of ``dict_res`` into one comma-separated string.

        Args:
            dict_res: Mapping of tag name to confidence.
            sort_by_confidence: Sort by confidence when true, otherwise
                alphabetically by the ``(name, confidence)`` items.
            descending: Reverse the chosen ordering when true.
            escape_parens: When true, prefix every ``(`` and ``)`` with a
                backslash.

        Returns:
            The tag names joined with ``", "``.
        """
        if sort_by_confidence:
            ordered = sorted(
                dict_res.items(),
                key=lambda x: x[1],
                reverse=descending,
            )
        else:
            ordered = sorted(
                dict_res.items(),
                reverse=descending,
            )
        joined = ", ".join([name for name, _ in ordered])
        if escape_parens:
            # "\\(" is a literal backslash + paren; the bare "\(" form is an
            # invalid escape sequence that newer Python versions warn about.
            joined = joined.replace("(", "\\(").replace(")", "\\)")
        return joined

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        sort_by_confidence_enabled,
        sort_descending_enabled,
        escape_parens_enabled,
    ):
        """Run the selected model on ``image`` and collect the tag results.

        Args:
            image: PIL image to tag.
            model_repo: Repository id of the model to use.
            general_thresh: Minimum confidence (exclusive) for general tags.
            general_mcut_enabled: Replace ``general_thresh`` with the MCut
                threshold of the general scores.
            character_thresh: Minimum confidence (exclusive) for character tags.
            character_mcut_enabled: Replace ``character_thresh`` with the MCut
                threshold of the character scores, floored at 0.15.
            sort_by_confidence_enabled: Forwarded to tag_dict_to_sorted_string.
            sort_descending_enabled: Forwarded to tag_dict_to_sorted_string.
            escape_parens_enabled: Forwarded to tag_dict_to_sorted_string.

        Returns:
            ``(sorted_general_strings, sorted_character_strings, rating,
            character_res, general_res)`` where the first two are strings and
            the last three are name-to-confidence dicts.
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings: every rating score is returned; nothing is filtered here.
        rating = dict([labels[i] for i in self.rating_indexes])

        # General tags: keep any whose confidence is above the threshold.
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = {
            name: score for name, score in general_names if score > general_thresh
        }

        # Character tags: keep any whose confidence is above the threshold.
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)

        character_res = {
            name: score
            for name, score in character_names
            if score > character_thresh
        }

        sorted_general_strings = self.tag_dict_to_sorted_string(
            general_res,
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            escape_parens=escape_parens_enabled,
        )
        sorted_character_strings = self.tag_dict_to_sorted_string(
            character_res,
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            escape_parens=escape_parens_enabled,
        )

        return (
            sorted_general_strings,
            sorted_character_strings,
            rating,
            character_res,
            general_res,
        )
240
+
241
+
242
def main():
    """Build the Gradio interface for the tagger and launch it."""
    # NOTE(review): the layout nesting below was reconstructed from a view of
    # the file without indentation - confirm it against the intended layout.
    args = parse_args()

    predictor = Predictor()

    # Only the dataset-v3 models are offered in the dropdown.
    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
    ]

    with gr.Blocks(title=TITLE, theme=gr.themes.Soft(primary_hue="teal")) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                # Left panel: inputs and prediction settings.
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                # Right panel: output formatting options and results.
                with gr.Column(variant="panel"):
                    with gr.Row():
                        sort_by_confidence_enabled = gr.Checkbox(
                            value=args.sort_tag_string_by_confidence,
                            label="Sort By Confidence",
                        )
                        sort_descending_enabled = gr.Checkbox(
                            value=False,
                            label="Descending",
                        )
                        escape_parens_enabled = gr.Checkbox(
                            value=False,
                            label="Escape Parens",
                        )
                    sorted_general_strings = gr.Textbox(
                        label="Output (string)",
                        show_copy_button=True,
                    )
                    sorted_character_strings = gr.Textbox(
                        label="Character (string)",
                        show_copy_button=True,
                    )
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # Every output is registered with the clear button,
                    # including the character string box, so no stale result
                    # remains after clearing.
                    clear.add(
                        [
                            sorted_general_strings,
                            sorted_character_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        # Input order must match the parameter order of Predictor.predict.
        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                sort_by_confidence_enabled,
                sort_descending_enabled,
                escape_parens_enabled,
            ],
            outputs=[
                sorted_general_strings,
                sorted_character_strings,
                rating,
                character_res,
                general_res,
            ],
        )

    demo.queue(max_size=10)
    # Honour the --share flag, which was parsed but previously ignored.
    demo.launch(share=args.share)
369
+
370
+
371
+ if __name__ == "__main__":
372
+ main()