p1atdev committed on
Commit
ad52e8b
1 Parent(s): 6277c6e

feat: change for sft version

Browse files
Files changed (1) hide show
  1. app.py +93 -8
app.py CHANGED
@@ -10,7 +10,7 @@ import gradio as gr
10
  MODEL_NAME = (
11
  os.environ.get("MODEL_NAME")
12
  if os.environ.get("MODEL_NAME") is not None
13
- else "p1atdev/dart-test-1"
14
  )
15
  HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
16
 
@@ -43,6 +43,8 @@ CHARACTER_EOS = "</character>"
43
  GENERAL_BOS = "<general>"
44
  GENERAL_EOS = "</general>"
45
 
 
 
46
  RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
47
  RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
48
  COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
@@ -52,6 +54,9 @@ CHARACTER_EOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_EOS)
52
  GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
53
  GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
54
 
 
 
 
55
  assert isinstance(RATING_BOS_ID, int)
56
  assert isinstance(RATING_EOS_ID, int)
57
  assert isinstance(COPYRIGHT_BOS_ID, int)
@@ -60,6 +65,7 @@ assert isinstance(CHARACTER_BOS_ID, int)
60
  assert isinstance(CHARACTER_EOS_ID, int)
61
  assert isinstance(GENERAL_BOS_ID, int)
62
  assert isinstance(GENERAL_EOS_ID, int)
 
63
 
64
  SPECIAL_TAGS = [
65
  BOS,
@@ -72,11 +78,12 @@ SPECIAL_TAGS = [
72
  CHARACTER_EOS,
73
  GENERAL_BOS,
74
  GENERAL_EOS,
 
75
  ]
76
 
77
  SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
78
  assert isinstance(SPECIAL_TAG_IDS, list)
79
-
80
 
81
  RATING_TAGS = {
82
  "sfw": "rating:sfw",
@@ -128,29 +135,42 @@ def compose_prompt(
128
  CHARACTER_EOS,
129
  GENERAL_BOS,
130
  general,
 
131
  ]
132
  )
133
 
134
 
135
  @torch.no_grad()
136
  def generate(
137
- input_text,
138
- max_new_tokens=128,
 
139
  do_sample: bool = True,
140
  temperature: float = 1.0,
141
  top_p: float = 1,
142
  top_k: int = 20,
143
  num_beams: int = 1,
144
  bad_words_ids: list[int] | None = None,
 
 
145
  ) -> list[int]:
146
  inputs = tokenizer(
147
  input_text,
148
  return_tensors="pt",
149
  ).input_ids
 
 
 
 
 
 
 
 
150
 
151
  generated = model.generate(
152
  inputs,
153
  max_new_tokens=max_new_tokens,
 
154
  do_sample=do_sample,
155
  temperature=temperature,
156
  top_p=top_p,
@@ -159,6 +179,8 @@ def generate(
159
  bad_words_ids=(
160
  [[token] for token in bad_words_ids] if bad_words_ids is not None else None
161
  ),
 
 
162
  no_repeat_ngram_size=1,
163
  )[0]
164
 
@@ -171,7 +193,10 @@ def decode_normal(token_ids: list[int], skip_special_tokens: bool = True):
171
 
172
  def decode_general_only(token_ids: list[int]):
173
  token_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
174
- return tokenizer.decode(token_ids, skip_special_tokens=True)
 
 
 
175
 
176
 
177
  def split_people_tokens_part(token_ids: list[int]):
@@ -242,7 +267,11 @@ def handle_inputs(
242
  character_tags_list: list[str],
243
  general_tags: str,
244
  ban_tags: str,
 
 
 
245
  max_new_tokens: int = 128,
 
246
  temperature: float = 1.0,
247
  top_p: float = 1.0,
248
  top_k: int = 20,
@@ -272,17 +301,29 @@ def handle_inputs(
272
  general=general_tags,
273
  )
274
 
275
- bad_words_ids = tokenizer.encode_plus(ban_tags).input_ids
 
 
 
 
 
 
 
 
 
276
 
277
  generated_ids = generate(
278
  prompt,
279
  max_new_tokens=max_new_tokens,
 
280
  do_sample=True,
281
  temperature=temperature,
282
  top_p=top_p,
283
  top_k=top_k,
284
  num_beams=num_beams,
285
  bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
 
 
286
  )
287
 
288
  decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
@@ -334,7 +375,7 @@ def demo():
334
  )
335
  copyright_tags_dropdown = gr.Dropdown(
336
  label="Copyright tags",
337
- choices=COPYRIGHT_TAGS_LIST,
338
  value=[],
339
  multiselect=True,
340
  visible=False,
@@ -363,7 +404,7 @@ def demo():
363
  )
364
  character_tags_dropdown = gr.Dropdown(
365
  label="Character tags",
366
- choices=CHARACTER_TAGS_LIST,
367
  value=[],
368
  multiselect=True,
369
  visible=False,
@@ -371,6 +412,8 @@ def demo():
371
 
372
  def on_change_character_tags_dropdouwn(mode: str):
373
  kwargs: dict = {"visible": mode == "Custom"}
 
 
374
 
375
  return gr.update(**kwargs)
376
 
@@ -389,6 +432,37 @@ def demo():
389
  )
390
 
391
  with gr.Accordion(label="Generation config", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  with gr.Group():
393
  max_new_tokens_slider = gr.Slider(
394
  label="Max new tokens",
@@ -397,6 +471,13 @@ def demo():
397
  step=1,
398
  value=128,
399
  )
 
 
 
 
 
 
 
400
  temperature_slider = gr.Slider(
401
  label="Temperature (larger is more random)",
402
  maximum=1.0,
@@ -480,7 +561,11 @@ def demo():
480
  character_tags_dropdown,
481
  general_tags_textbox,
482
  ban_tags_textbox,
 
 
 
483
  max_new_tokens_slider,
 
484
  temperature_slider,
485
  top_p_slider,
486
  top_k_slider,
 
10
# Model repo to load; override with the MODEL_NAME environment variable.
# os.environ.get with a default replaces the original double lookup
# (`get(...) if get(...) is not None else ...`) — env values are never None.
MODEL_NAME = os.environ.get("MODEL_NAME", "p1atdev/dart-test-3-sft-1")
# Optional Hugging Face token for reading gated/private repos (None if unset).
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
16
 
 
43
  GENERAL_BOS = "<general>"
44
  GENERAL_EOS = "</general>"
45
 
46
+ INPUT_END = "<|input_end|>"
47
+
48
  RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
49
  RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
50
  COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
 
54
  GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
55
  GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
56
 
57
+ INPUT_END_ID = tokenizer.convert_tokens_to_ids(INPUT_END)
58
+
59
+
60
  assert isinstance(RATING_BOS_ID, int)
61
  assert isinstance(RATING_EOS_ID, int)
62
  assert isinstance(COPYRIGHT_BOS_ID, int)
 
65
  assert isinstance(CHARACTER_EOS_ID, int)
66
  assert isinstance(GENERAL_BOS_ID, int)
67
  assert isinstance(GENERAL_EOS_ID, int)
68
+ assert isinstance(INPUT_END_ID, int)
69
 
70
  SPECIAL_TAGS = [
71
  BOS,
 
78
  CHARACTER_EOS,
79
  GENERAL_BOS,
80
  GENERAL_EOS,
81
+ INPUT_END,
82
  ]
83
 
84
  SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
85
  assert isinstance(SPECIAL_TAG_IDS, list)
86
+ assert all([token_id != tokenizer.unk_token_id for token_id in SPECIAL_TAG_IDS])
87
 
88
  RATING_TAGS = {
89
  "sfw": "rating:sfw",
 
135
  CHARACTER_EOS,
136
  GENERAL_BOS,
137
  general,
138
+ INPUT_END,
139
  ]
140
  )
141
 
142
 
143
  @torch.no_grad()
144
  def generate(
145
+ input_text: str,
146
+ max_new_tokens: int = 128,
147
+ min_new_tokens: int = 0,
148
  do_sample: bool = True,
149
  temperature: float = 1.0,
150
  top_p: float = 1,
151
  top_k: int = 20,
152
  num_beams: int = 1,
153
  bad_words_ids: list[int] | None = None,
154
+ cfg_scale: float = 1.5,
155
+ negative_input_text: str | None = None,
156
  ) -> list[int]:
157
  inputs = tokenizer(
158
  input_text,
159
  return_tensors="pt",
160
  ).input_ids
161
+ negative_inputs = (
162
+ tokenizer(
163
+ negative_input_text,
164
+ return_tensors="pt",
165
+ ).input_ids
166
+ if negative_input_text is not None
167
+ else None
168
+ )
169
 
170
  generated = model.generate(
171
  inputs,
172
  max_new_tokens=max_new_tokens,
173
+ min_new_tokens=min_new_tokens,
174
  do_sample=do_sample,
175
  temperature=temperature,
176
  top_p=top_p,
 
179
  bad_words_ids=(
180
  [[token] for token in bad_words_ids] if bad_words_ids is not None else None
181
  ),
182
+ negative_prompt_ids=negative_inputs,
183
+ guidance_scale=cfg_scale,
184
  no_repeat_ngram_size=1,
185
  )[0]
186
 
 
193
 
194
def decode_general_only(token_ids: list[int]) -> str:
    """Decode only the general-tag section of ``token_ids``.

    Drops everything before the ``<general>`` marker, decodes the remainder,
    and returns the tags alphabetically sorted as a comma-separated string.

    Raises:
        ValueError: if ``GENERAL_BOS_ID`` does not occur in ``token_ids``.
    """
    token_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
    decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
    # sorted() over the split replaces the original identity comprehension
    # plus separate sort; ", " is the tag separator the tokenizer emits.
    return ", ".join(sorted(decoded.split(", ")))
200
 
201
 
202
  def split_people_tokens_part(token_ids: list[int]):
 
267
  character_tags_list: list[str],
268
  general_tags: str,
269
  ban_tags: str,
270
+ do_cfg: bool = False,
271
+ cfg_scale: float = 1.5,
272
+ negative_tags: str = "",
273
  max_new_tokens: int = 128,
274
+ min_new_tokens: int = 0,
275
  temperature: float = 1.0,
276
  top_p: float = 1.0,
277
  top_k: int = 20,
 
301
  general=general_tags,
302
  )
303
 
304
# Negative prompt for CFG: same rating header as the positive prompt, but
# only the user's negative tags in the general section.
negative_prompt = compose_prompt(
    rating=prepare_rating_tags(rating_tags),
    copyright="",
    character="",
    general=negative_tags,
)

# Ban both the explicit ban list and the negative tags from normal sampling.
# Joining only the non-empty parts avoids encoding a stray leading ", "
# when ban_tags is empty but negative_tags is not (the original concatenation
# produced ", <negative_tags>" in that case).
_ban_parts = [tags for tags in (ban_tags, negative_tags) if tags.strip()]
bad_words_ids = tokenizer.encode_plus(", ".join(_ban_parts)).input_ids
314
 
315
  generated_ids = generate(
316
  prompt,
317
  max_new_tokens=max_new_tokens,
318
+ min_new_tokens=min_new_tokens,
319
  do_sample=True,
320
  temperature=temperature,
321
  top_p=top_p,
322
  top_k=top_k,
323
  num_beams=num_beams,
324
  bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
325
+ cfg_scale=cfg_scale,
326
+ negative_input_text=negative_prompt if do_cfg else None,
327
  )
328
 
329
  decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
 
375
  )
376
  copyright_tags_dropdown = gr.Dropdown(
377
  label="Copyright tags",
378
+ choices=COPYRIGHT_TAGS_LIST, # type: ignore
379
  value=[],
380
  multiselect=True,
381
  visible=False,
 
404
  )
405
  character_tags_dropdown = gr.Dropdown(
406
  label="Character tags",
407
+ choices=CHARACTER_TAGS_LIST, # type: ignore
408
  value=[],
409
  multiselect=True,
410
  visible=False,
 
412
 
413
def on_change_character_tags_dropdouwn(mode: str):
    """Show the custom character-tag dropdown only in "Custom" mode.

    Selecting "None" additionally clears any previously chosen tags.
    """
    is_custom = mode == "Custom"
    if mode == "None":
        # Hide and reset the selection when the user opts out entirely.
        return gr.update(visible=is_custom, value=[])
    return gr.update(visible=is_custom)
419
 
 
432
  )
433
 
434
  with gr.Accordion(label="Generation config", open=False):
435
# Classifier-Free-Guidance controls; the scale slider and negative-prompt
# box stay hidden until the checkbox enables CFG.
with gr.Group():
    do_cfg_check = gr.Checkbox(
        label="Do CFG (Classifier Free Guidance)",
        value=False,
    )
    cfg_scale_slider = gr.Slider(
        # Fix: this slider sets the guidance scale; the original label
        # "Max new tokens" was copy-pasted from the token slider below.
        label="CFG scale",
        maximum=3.0,
        minimum=0.1,
        step=0.1,
        value=1.5,
        visible=False,
    )
    negative_tags_textbox = gr.Textbox(
        label="Negative prompt",
        placeholder="simple background, ...",
        value="",
        lines=2,
        visible=False,
    )
455
+
456
def on_change_do_cfg_check(do_cfg: bool):
    """Show the CFG scale slider and negative-prompt box only while CFG is on."""
    return gr.update(visible=do_cfg), gr.update(visible=do_cfg)

do_cfg_check.change(
    on_change_do_cfg_check,
    inputs=[do_cfg_check],
    outputs=[cfg_scale_slider, negative_tags_textbox],
)
465
+
466
  with gr.Group():
467
  max_new_tokens_slider = gr.Slider(
468
  label="Max new tokens",
 
471
  step=1,
472
  value=128,
473
  )
474
# Lower bound on the number of generated tag tokens; 0 leaves it unconstrained.
min_new_tokens_slider = gr.Slider(
    label="Min new tokens",
    minimum=0,
    maximum=255,
    value=0,
    step=1,
)
481
  temperature_slider = gr.Slider(
482
  label="Temperature (larger is more random)",
483
  maximum=1.0,
 
561
  character_tags_dropdown,
562
  general_tags_textbox,
563
  ban_tags_textbox,
564
+ do_cfg_check,
565
+ cfg_scale_slider,
566
+ negative_tags_textbox,
567
  max_new_tokens_slider,
568
+ min_new_tokens_slider,
569
  temperature_slider,
570
  top_p_slider,
571
  top_k_slider,