John6666 commited on
Commit
2b188e2
β€’
1 Parent(s): b3dd22d

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +15 -15
  3. tagger.py +9 -5
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ‘€πŸ“¦
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
 
4
  colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.37.2
8
  app_file: app.py
9
  pinned: false
10
  license: openrail
app.py CHANGED
@@ -102,35 +102,35 @@ def main():
102
  input_ban_tags,
103
  ]
104
 
105
- translate_t2t_text_button.click(translate_prompt, inputs=[t2t_text], outputs=[t2t_text])
106
 
107
  generate_from_text_btn.click(
108
  predict_text_to_tags,
109
- inputs=[t2t_text, t2t_max_tokens, t2t_temperature, t2t_top_k, t2t_top_p, t2t_repeat_penalty],
110
- outputs=[
111
  input_general,
112
  ],
113
  )
114
 
115
- copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
116
- gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
117
  )
118
 
119
  generate_btn.click(
120
  parse_upsampling_output(v2.on_generate),
121
- inputs=[
122
  *v2.input_components,
123
  ],
124
- outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
125
- ).then(
126
- convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
127
- ).then(
128
- insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
129
- ).then(
130
- insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
131
  )
132
- copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
133
- copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)
134
 
135
  ui.launch()
136
 
 
102
  input_ban_tags,
103
  ]
104
 
105
+ translate_t2t_text_button.click(translate_prompt, [t2t_text], [t2t_text])
106
 
107
  generate_from_text_btn.click(
108
  predict_text_to_tags,
109
+ [t2t_text, t2t_max_tokens, t2t_temperature, t2t_top_k, t2t_top_p, t2t_repeat_penalty],
110
+ [
111
  input_general,
112
  ],
113
  )
114
 
115
+ copy_input_btn.click(compose_prompt_to_copy, [input_character, input_copyright, input_general], [input_tags_to_copy]).success(
116
+ gradio_copy_text, [input_tags_to_copy], js=COPY_ACTION_JS,
117
  )
118
 
119
  generate_btn.click(
120
  parse_upsampling_output(v2.on_generate),
121
+ [
122
  *v2.input_components,
123
  ],
124
+ [output_text, elapsed_time_md, copy_btn, copy_btn_pony],
125
+ ).success(
126
+ convert_danbooru_to_e621_prompt, [output_text, tag_type], [output_text_pony],
127
+ ).success(
128
+ insert_recom_prompt, [output_text, dummy_np, recom_animagine], [output_text, dummy_np],
129
+ ).success(
130
+ insert_recom_prompt, [output_text_pony, dummy_np, recom_pony], [output_text_pony, dummy_np],
131
  )
132
+ copy_btn.click(gradio_copy_text, [output_text], js=COPY_ACTION_JS)
133
+ copy_btn_pony.click(gradio_copy_text, [output_text_pony], js=COPY_ACTION_JS)
134
 
135
  ui.launch()
136
 
tagger.py CHANGED
@@ -81,6 +81,14 @@ def character_list_to_series_list(character_list):
81
  return output_series_tag
82
 
83
 
 
 
 
 
 
 
 
 
84
  def danbooru_to_e621(dtag, e621_dict):
85
  def d_to_e(match, e621_dict):
86
  dtag = match.group(0)
@@ -415,28 +423,24 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
415
  results = {
416
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
417
  }
418
-
419
  # rating, character, general
420
  rating, character, general = postprocess_results(
421
  results, general_threshold, character_threshold
422
  )
423
-
424
  prompt = gen_prompt(
425
  list(rating.keys()), list(character.keys()), list(general.keys())
426
  )
427
-
428
  output_series_tag = ""
429
  output_series_list = character_list_to_series_list(character.keys())
430
  if output_series_list:
431
  output_series_tag = output_series_list[0]
432
  else:
433
  output_series_tag = ""
434
-
435
  return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
436
 
437
 
438
  def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
439
- if algo and not "Use WD Tagger" in algo:
440
  return "", "", input_tags, gr.update(interactive=True),
441
  return predict_tags(image, general_threshold, character_threshold)
442
 
 
81
  return output_series_tag
82
 
83
 
84
+ def select_random_character(series: str, character: str):
85
+ from random import randrange
86
+ character_list = list(anime_series_dict.keys())
87
+ character = character_list[randrange(len(character_list) - 1)]
88
+ series = anime_series_dict.get(character.split(",")[0].strip(), "")
89
+ return series, character
90
+
91
+
92
  def danbooru_to_e621(dtag, e621_dict):
93
  def d_to_e(match, e621_dict):
94
  dtag = match.group(0)
 
423
  results = {
424
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
425
  }
 
426
  # rating, character, general
427
  rating, character, general = postprocess_results(
428
  results, general_threshold, character_threshold
429
  )
 
430
  prompt = gen_prompt(
431
  list(rating.keys()), list(character.keys()), list(general.keys())
432
  )
 
433
  output_series_tag = ""
434
  output_series_list = character_list_to_series_list(character.keys())
435
  if output_series_list:
436
  output_series_tag = output_series_list[0]
437
  else:
438
  output_series_tag = ""
 
439
  return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
440
 
441
 
442
  def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
443
+ if not "Use WD Tagger" in algo and len(algo) != 0:
444
  return "", "", input_tags, gr.update(interactive=True),
445
  return predict_tags(image, general_threshold, character_threshold)
446