p1atdev committed on
Commit
31d9259
1 Parent(s): 6fa0b33

feat: add optimized models, use tokenizer chat template and better ui

Browse files
Files changed (1) hide show
  1. app.py +166 -75
app.py CHANGED
@@ -4,13 +4,14 @@ import os
4
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
7
 
8
  import gradio as gr
9
 
10
  MODEL_NAME = (
11
  os.environ.get("MODEL_NAME")
12
  if os.environ.get("MODEL_NAME") is not None
13
- else "p1atdev/dart-test-3-sft-1"
14
  )
15
  HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
16
 
@@ -21,16 +22,32 @@ tokenizer = AutoTokenizer.from_pretrained(
21
  trust_remote_code=True,
22
  token=HF_READ_TOKEN,
23
  )
24
- model = AutoModelForCausalLM.from_pretrained(
25
- MODEL_NAME,
26
- trust_remote_code=True,
27
- token=HF_READ_TOKEN,
28
- )
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  try:
31
- model = torch.compile(model)
32
  except:
33
- print("torch compile not supported")
 
 
 
 
 
34
 
35
  BOS = "<|bos|>"
36
  EOS = "<|eos|>"
@@ -45,6 +62,11 @@ GENERAL_EOS = "</general>"
45
 
46
  INPUT_END = "<|input_end|>"
47
 
 
 
 
 
 
48
  RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
49
  RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
50
  COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
@@ -54,9 +76,6 @@ CHARACTER_EOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_EOS)
54
  GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
55
  GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
56
 
57
- INPUT_END_ID = tokenizer.convert_tokens_to_ids(INPUT_END)
58
-
59
-
60
  assert isinstance(RATING_BOS_ID, int)
61
  assert isinstance(RATING_EOS_ID, int)
62
  assert isinstance(COPYRIGHT_BOS_ID, int)
@@ -65,7 +84,6 @@ assert isinstance(CHARACTER_BOS_ID, int)
65
  assert isinstance(CHARACTER_EOS_ID, int)
66
  assert isinstance(GENERAL_BOS_ID, int)
67
  assert isinstance(GENERAL_EOS_ID, int)
68
- assert isinstance(INPUT_END_ID, int)
69
 
70
  SPECIAL_TAGS = [
71
  BOS,
@@ -79,6 +97,10 @@ SPECIAL_TAGS = [
79
  GENERAL_BOS,
80
  GENERAL_EOS,
81
  INPUT_END,
 
 
 
 
82
  ]
83
 
84
  SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
@@ -95,6 +117,13 @@ RATING_TAGS = {
95
  }
96
  RATING_TAG_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in RATING_TAGS.items()}
97
 
 
 
 
 
 
 
 
98
 
99
  def load_tags(path: str | Path):
100
  if isinstance(path, str):
@@ -115,34 +144,10 @@ PEOPLE_TAG_IDS_LIST = tokenizer.convert_tokens_to_ids(PEOPLE_TAGS_LIST)
115
  assert isinstance(PEOPLE_TAG_IDS_LIST, list)
116
 
117
 
118
- def compose_prompt(
119
- rating: str = "rating:sfw, rating:general",
120
- copyright: str = "",
121
- character: str = "",
122
- general: str = "",
123
- ):
124
- return "".join(
125
- [
126
- BOS,
127
- RATING_BOS,
128
- rating,
129
- RATING_EOS,
130
- COPYRIGHT_BOS,
131
- copyright,
132
- COPYRIGHT_EOS,
133
- CHARACTER_BOS,
134
- character,
135
- CHARACTER_EOS,
136
- GENERAL_BOS,
137
- general,
138
- INPUT_END,
139
- ]
140
- )
141
-
142
-
143
  @torch.no_grad()
144
  def generate(
145
  input_text: str,
 
146
  max_new_tokens: int = 128,
147
  min_new_tokens: int = 0,
148
  do_sample: bool = True,
@@ -157,17 +162,17 @@ def generate(
157
  inputs = tokenizer(
158
  input_text,
159
  return_tensors="pt",
160
- ).input_ids
161
  negative_inputs = (
162
  tokenizer(
163
  negative_input_text,
164
  return_tensors="pt",
165
- ).input_ids
166
  if negative_input_text is not None
167
  else None
168
  )
169
 
170
- generated = model.generate(
171
  inputs,
172
  max_new_tokens=max_new_tokens,
173
  min_new_tokens=min_new_tokens,
@@ -270,12 +275,14 @@ def handle_inputs(
270
  do_cfg: bool = False,
271
  cfg_scale: float = 1.5,
272
  negative_tags: str = "",
 
273
  max_new_tokens: int = 128,
274
  min_new_tokens: int = 0,
275
  temperature: float = 1.0,
276
  top_p: float = 1.0,
277
  top_k: int = 20,
278
  num_beams: int = 1,
 
279
  ):
280
  """
281
  Returns:
@@ -286,6 +293,9 @@ def handle_inputs(
286
  input_prompt_raw,
287
  output_tags_raw,
288
  elapsed_time,
 
 
 
289
  ]
290
  """
291
 
@@ -294,18 +304,28 @@ def handle_inputs(
294
  copyright_tags = ", ".join(copyright_tags_list)
295
  character_tags = ", ".join(character_tags_list)
296
 
297
- prompt = compose_prompt(
298
- rating=prepare_rating_tags(rating_tags),
299
- copyright=copyright_tags,
300
- character=character_tags,
301
- general=general_tags,
 
 
 
 
 
 
302
  )
303
 
304
- negative_prompt = compose_prompt(
305
- rating=prepare_rating_tags(rating_tags),
306
- copyright="",
307
- character="",
308
- general=negative_tags,
 
 
 
 
309
  )
310
 
311
  bad_words_ids = tokenizer.encode_plus(
@@ -314,6 +334,7 @@ def handle_inputs(
314
 
315
  generated_ids = generate(
316
  prompt,
 
317
  max_new_tokens=max_new_tokens,
318
  min_new_tokens=min_new_tokens,
319
  do_sample=True,
@@ -334,6 +355,9 @@ def handle_inputs(
334
  end_time = time.time()
335
  elapsed_time = f"Elapsed: {(end_time - start_time) * 1000:.2f} ms"
336
 
 
 
 
337
  return [
338
  decoded_normal,
339
  decoded_general_only,
@@ -341,13 +365,44 @@ def handle_inputs(
341
  prompt,
342
  decoded_raw,
343
  elapsed_time,
 
 
 
344
  ]
345
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  def demo():
348
  with gr.Blocks() as ui:
 
 
 
 
 
349
  with gr.Row():
350
  with gr.Column():
 
 
 
 
 
 
 
 
 
351
  with gr.Group():
352
  rating_dropdown = gr.Dropdown(
353
  label="Rating",
@@ -419,26 +474,29 @@ def demo():
419
 
420
  with gr.Group():
421
  general_tags_textbox = gr.Textbox(
422
- label="General tags",
 
423
  placeholder="1girl, ...",
424
  lines=4,
425
  )
426
 
427
  ban_tags_textbox = gr.Textbox(
428
- label="Ban tags",
429
- placeholder="",
430
  value="",
 
431
  lines=2,
432
  )
433
 
434
- with gr.Accordion(label="Generation config", open=False):
 
 
435
  with gr.Group():
436
  do_cfg_check = gr.Checkbox(
437
  label="Do CFG (Classifier Free Guidance)",
438
  value=False,
439
  )
440
  cfg_scale_slider = gr.Slider(
441
- label="Max new tokens",
442
  maximum=3.0,
443
  minimum=0.1,
444
  step=0.1,
@@ -463,6 +521,13 @@ def demo():
463
  outputs=[cfg_scale_slider, negative_tags_textbox],
464
  )
465
 
 
 
 
 
 
 
 
466
  with gr.Group():
467
  max_new_tokens_slider = gr.Slider(
468
  label="Max new tokens",
@@ -507,27 +572,44 @@ def demo():
507
  value=1,
508
  )
509
 
510
- generate_btn = gr.Button("Generate", variant="primary")
511
-
512
  with gr.Column():
513
- output_tags_natural = gr.Textbox(
514
- label="Generation result",
515
- # placeholder="tags will be here",
516
- interactive=False,
517
- )
518
-
519
- output_tags_general_only = gr.Textbox(
520
- label="General tags only",
521
- interactive=False,
522
- )
523
-
524
- output_tags_animagine = gr.Textbox(
525
- label="Output tags (AnimagineXL v3 style order)",
526
- # placeholder="tags will be here in Animagine v3 style order",
527
- interactive=False,
528
- )
529
 
530
- elapsed_time_md = gr.Markdown(value="Waiting to generate...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
 
532
  with gr.Accordion(label="Metadata", open=False):
533
  input_prompt_raw = gr.Textbox(
@@ -542,6 +624,8 @@ def demo():
542
  lines=4,
543
  )
544
 
 
 
545
  copyright_tags_mode_dropdown.change(
546
  on_change_copyright_tags_dropdouwn,
547
  inputs=[copyright_tags_mode_dropdown],
@@ -564,12 +648,14 @@ def demo():
564
  do_cfg_check,
565
  cfg_scale_slider,
566
  negative_tags_textbox,
 
567
  max_new_tokens_slider,
568
  min_new_tokens_slider,
569
  temperature_slider,
570
  top_p_slider,
571
  top_k_slider,
572
  num_beams_slider,
 
573
  ],
574
  outputs=[
575
  output_tags_natural,
@@ -578,10 +664,15 @@ def demo():
578
  input_prompt_raw,
579
  output_tags_raw,
580
  elapsed_time_md,
 
 
 
581
  ],
582
  )
583
 
584
- ui.launch()
 
 
585
 
586
 
587
  if __name__ == "__main__":
 
4
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from optimum.onnxruntime import ORTModelForCausalLM
8
 
9
  import gradio as gr
10
 
11
  MODEL_NAME = (
12
  os.environ.get("MODEL_NAME")
13
  if os.environ.get("MODEL_NAME") is not None
14
+ else "p1atdev/dart-v1-sft"
15
  )
16
  HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
17
 
 
22
  trust_remote_code=True,
23
  token=HF_READ_TOKEN,
24
  )
25
+ model = {
26
+ "default": AutoModelForCausalLM.from_pretrained(
27
+ MODEL_NAME,
28
+ token=HF_READ_TOKEN,
29
+ ),
30
+ "ort": ORTModelForCausalLM.from_pretrained(MODEL_NAME),
31
+ "ort_qantized": ORTModelForCausalLM.from_pretrained(
32
+ MODEL_NAME, file_name="model_quantized.onnx"
33
+ ),
34
+ }
35
+
36
+ MODEL_BACKEND_MAP = {
37
+ "Default": "default",
38
+ "ONNX (normal)": "ort",
39
+ "ONNX (quantized)": "ort_qantized",
40
+ }
41
 
42
  try:
43
+ model["default"].to("cuda")
44
  except:
45
+ print("No GPU")
46
+
47
+ try:
48
+ model["default"] = torch.compile(model["default"])
49
+ except:
50
+ print("torch.compile is not supported")
51
 
52
  BOS = "<|bos|>"
53
  EOS = "<|eos|>"
 
62
 
63
  INPUT_END = "<|input_end|>"
64
 
65
+ LENGTH_VERY_SHORT = "<|very_short|>"
66
+ LENGTH_SHORT = "<|short|>"
67
+ LENGTH_LONG = "<|long|>"
68
+ LENGTH_VERY_LONG = "<|very_long|>"
69
+
70
  RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
71
  RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
72
  COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
 
76
  GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
77
  GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
78
 
 
 
 
79
  assert isinstance(RATING_BOS_ID, int)
80
  assert isinstance(RATING_EOS_ID, int)
81
  assert isinstance(COPYRIGHT_BOS_ID, int)
 
84
  assert isinstance(CHARACTER_EOS_ID, int)
85
  assert isinstance(GENERAL_BOS_ID, int)
86
  assert isinstance(GENERAL_EOS_ID, int)
 
87
 
88
  SPECIAL_TAGS = [
89
  BOS,
 
97
  GENERAL_BOS,
98
  GENERAL_EOS,
99
  INPUT_END,
100
+ LENGTH_VERY_SHORT,
101
+ LENGTH_SHORT,
102
+ LENGTH_LONG,
103
+ LENGTH_VERY_LONG,
104
  ]
105
 
106
  SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
 
117
  }
118
  RATING_TAG_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in RATING_TAGS.items()}
119
 
120
+ LENGTH_TAGS = {
121
+ "very short": LENGTH_VERY_SHORT,
122
+ "short": LENGTH_SHORT,
123
+ "long": LENGTH_LONG,
124
+ "very long": LENGTH_VERY_LONG,
125
+ }
126
+
127
 
128
  def load_tags(path: str | Path):
129
  if isinstance(path, str):
 
144
  assert isinstance(PEOPLE_TAG_IDS_LIST, list)
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  @torch.no_grad()
148
  def generate(
149
  input_text: str,
150
+ model_backend: str,
151
  max_new_tokens: int = 128,
152
  min_new_tokens: int = 0,
153
  do_sample: bool = True,
 
162
  inputs = tokenizer(
163
  input_text,
164
  return_tensors="pt",
165
+ ).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
166
  negative_inputs = (
167
  tokenizer(
168
  negative_input_text,
169
  return_tensors="pt",
170
+ ).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
171
  if negative_input_text is not None
172
  else None
173
  )
174
 
175
+ generated = model[MODEL_BACKEND_MAP[model_backend]].generate(
176
  inputs,
177
  max_new_tokens=max_new_tokens,
178
  min_new_tokens=min_new_tokens,
 
275
  do_cfg: bool = False,
276
  cfg_scale: float = 1.5,
277
  negative_tags: str = "",
278
+ total_token_length: str = "long",
279
  max_new_tokens: int = 128,
280
  min_new_tokens: int = 0,
281
  temperature: float = 1.0,
282
  top_p: float = 1.0,
283
  top_k: int = 20,
284
  num_beams: int = 1,
285
+ model_backend: str = "ONNX (quantized)",
286
  ):
287
  """
288
  Returns:
 
293
  input_prompt_raw,
294
  output_tags_raw,
295
  elapsed_time,
296
+ output_tags_natural_copy_btn,
297
+ output_tags_general_only_copy_btn,
298
+ output_tags_animagine_copy_btn
299
  ]
300
  """
301
 
 
304
  copyright_tags = ", ".join(copyright_tags_list)
305
  character_tags = ", ".join(character_tags_list)
306
 
307
+ token_length_tag = LENGTH_TAGS[total_token_length]
308
+
309
+ prompt: str = tokenizer.apply_chat_template(
310
+ { # type: ignore
311
+ "rating": prepare_rating_tags(rating_tags),
312
+ "copyright": copyright_tags,
313
+ "character": character_tags,
314
+ "general": general_tags,
315
+ "length": token_length_tag,
316
+ },
317
+ tokenize=False,
318
  )
319
 
320
+ negative_prompt: str = tokenizer.apply_chat_template(
321
+ { # type: ignore
322
+ "rating": prepare_rating_tags(rating_tags),
323
+ "copyright": "",
324
+ "character": "",
325
+ "general": negative_tags,
326
+ "length": token_length_tag,
327
+ },
328
+ tokenize=False,
329
  )
330
 
331
  bad_words_ids = tokenizer.encode_plus(
 
334
 
335
  generated_ids = generate(
336
  prompt,
337
+ model_backend=model_backend,
338
  max_new_tokens=max_new_tokens,
339
  min_new_tokens=min_new_tokens,
340
  do_sample=True,
 
355
  end_time = time.time()
356
  elapsed_time = f"Elapsed: {(end_time - start_time) * 1000:.2f} ms"
357
 
358
+ # update visibility of buttons
359
+ set_visible = gr.update(visible=True)
360
+
361
  return [
362
  decoded_normal,
363
  decoded_general_only,
 
365
  prompt,
366
  decoded_raw,
367
  elapsed_time,
368
+ set_visible,
369
+ set_visible,
370
+ set_visible,
371
  ]
372
 
373
 
374
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
375
+ def copy_text(_text: None):
376
+ gr.Info("Copied!")
377
+
378
+
379
+ COPY_ACTION_JS = """\
380
+ (inputs, _outputs) => {
381
+ // inputs is the string value of the input_text
382
+ if (inputs.trim() !== "") {
383
+ navigator.clipboard.writeText(inputs);
384
+ }
385
+ }"""
386
+
387
+
388
  def demo():
389
  with gr.Blocks() as ui:
390
+ gr.Markdown(
391
+ """\
392
+ # Danbooru Tags Transformer Demo """
393
+ )
394
+
395
  with gr.Row():
396
  with gr.Column():
397
+
398
+ with gr.Group():
399
+ model_backend_radio = gr.Radio(
400
+ label="Model backend",
401
+ choices=list(MODEL_BACKEND_MAP.keys()),
402
+ value="ONNX (quantized)",
403
+ interactive=True,
404
+ )
405
+
406
  with gr.Group():
407
  rating_dropdown = gr.Dropdown(
408
  label="Rating",
 
474
 
475
  with gr.Group():
476
  general_tags_textbox = gr.Textbox(
477
+ label="General tags (the condition to generate tags)",
478
+ value="",
479
  placeholder="1girl, ...",
480
  lines=4,
481
  )
482
 
483
  ban_tags_textbox = gr.Textbox(
484
+ label="Ban tags (tags in this field never appear in generation)",
 
485
  value="",
486
+ placeholder="official alternate cosutme, english text,...",
487
  lines=2,
488
  )
489
 
490
+ generate_btn = gr.Button("Generate", variant="primary")
491
+
492
+ with gr.Accordion(label="Generation config (advanced)", open=False):
493
  with gr.Group():
494
  do_cfg_check = gr.Checkbox(
495
  label="Do CFG (Classifier Free Guidance)",
496
  value=False,
497
  )
498
  cfg_scale_slider = gr.Slider(
499
+ label="CFG scale",
500
  maximum=3.0,
501
  minimum=0.1,
502
  step=0.1,
 
521
  outputs=[cfg_scale_slider, negative_tags_textbox],
522
  )
523
 
524
+ with gr.Group():
525
+ total_token_length_radio = gr.Radio(
526
+ label="Total token length",
527
+ choices=list(LENGTH_TAGS.keys()),
528
+ value="long",
529
+ )
530
+
531
  with gr.Group():
532
  max_new_tokens_slider = gr.Slider(
533
  label="Max new tokens",
 
572
  value=1,
573
  )
574
 
 
 
575
  with gr.Column():
576
+ with gr.Group():
577
+ output_tags_natural = gr.Textbox(
578
+ label="Generation result",
579
+ # placeholder="tags will be here",
580
+ interactive=False,
581
+ )
582
+ output_tags_natural_copy_btn = gr.Button("Copy", visible=False)
583
+ output_tags_natural_copy_btn.click(
584
+ fn=copy_text,
585
+ inputs=[output_tags_natural],
586
+ js=COPY_ACTION_JS,
587
+ )
 
 
 
 
588
 
589
+ with gr.Group():
590
+ output_tags_general_only = gr.Textbox(
591
+ label="General tags only (sorted)",
592
+ interactive=False,
593
+ )
594
+ output_tags_general_only_copy_btn = gr.Button("Copy", visible=False)
595
+ output_tags_general_only_copy_btn.click(
596
+ fn=copy_text,
597
+ inputs=[output_tags_general_only],
598
+ js=COPY_ACTION_JS,
599
+ )
600
+
601
+ with gr.Group():
602
+ output_tags_animagine = gr.Textbox(
603
+ label="Output tags (AnimagineXL v3 style order)",
604
+ # placeholder="tags will be here in Animagine v3 style order",
605
+ interactive=False,
606
+ )
607
+ output_tags_animagine_copy_btn = gr.Button("Copy", visible=False)
608
+ output_tags_animagine_copy_btn.click(
609
+ fn=copy_text,
610
+ inputs=[output_tags_animagine],
611
+ js=COPY_ACTION_JS,
612
+ )
613
 
614
  with gr.Accordion(label="Metadata", open=False):
615
  input_prompt_raw = gr.Textbox(
 
624
  lines=4,
625
  )
626
 
627
+ elapsed_time_md = gr.Markdown(value="Waiting to generate...")
628
+
629
  copyright_tags_mode_dropdown.change(
630
  on_change_copyright_tags_dropdouwn,
631
  inputs=[copyright_tags_mode_dropdown],
 
648
  do_cfg_check,
649
  cfg_scale_slider,
650
  negative_tags_textbox,
651
+ total_token_length_radio,
652
  max_new_tokens_slider,
653
  min_new_tokens_slider,
654
  temperature_slider,
655
  top_p_slider,
656
  top_k_slider,
657
  num_beams_slider,
658
+ model_backend_radio,
659
  ],
660
  outputs=[
661
  output_tags_natural,
 
664
  input_prompt_raw,
665
  output_tags_raw,
666
  elapsed_time_md,
667
+ output_tags_natural_copy_btn,
668
+ output_tags_general_only_copy_btn,
669
+ output_tags_animagine_copy_btn,
670
  ],
671
  )
672
 
673
+ ui.launch(
674
+ share=True,
675
+ )
676
 
677
 
678
  if __name__ == "__main__":