Spaces:
Runtime error
Runtime error
feat: change for sft version
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ import gradio as gr
|
|
10 |
MODEL_NAME = (
|
11 |
os.environ.get("MODEL_NAME")
|
12 |
if os.environ.get("MODEL_NAME") is not None
|
13 |
-
else "p1atdev/dart-test-1"
|
14 |
)
|
15 |
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
|
16 |
|
@@ -43,6 +43,8 @@ CHARACTER_EOS = "</character>"
|
|
43 |
GENERAL_BOS = "<general>"
|
44 |
GENERAL_EOS = "</general>"
|
45 |
|
|
|
|
|
46 |
RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
|
47 |
RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
|
48 |
COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
|
@@ -52,6 +54,9 @@ CHARACTER_EOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_EOS)
|
|
52 |
GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
|
53 |
GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
|
54 |
|
|
|
|
|
|
|
55 |
assert isinstance(RATING_BOS_ID, int)
|
56 |
assert isinstance(RATING_EOS_ID, int)
|
57 |
assert isinstance(COPYRIGHT_BOS_ID, int)
|
@@ -60,6 +65,7 @@ assert isinstance(CHARACTER_BOS_ID, int)
|
|
60 |
assert isinstance(CHARACTER_EOS_ID, int)
|
61 |
assert isinstance(GENERAL_BOS_ID, int)
|
62 |
assert isinstance(GENERAL_EOS_ID, int)
|
|
|
63 |
|
64 |
SPECIAL_TAGS = [
|
65 |
BOS,
|
@@ -72,11 +78,12 @@ SPECIAL_TAGS = [
|
|
72 |
CHARACTER_EOS,
|
73 |
GENERAL_BOS,
|
74 |
GENERAL_EOS,
|
|
|
75 |
]
|
76 |
|
77 |
SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
|
78 |
assert isinstance(SPECIAL_TAG_IDS, list)
|
79 |
-
|
80 |
|
81 |
RATING_TAGS = {
|
82 |
"sfw": "rating:sfw",
|
@@ -128,29 +135,42 @@ def compose_prompt(
|
|
128 |
CHARACTER_EOS,
|
129 |
GENERAL_BOS,
|
130 |
general,
|
|
|
131 |
]
|
132 |
)
|
133 |
|
134 |
|
135 |
@torch.no_grad()
|
136 |
def generate(
|
137 |
-
input_text,
|
138 |
-
max_new_tokens=128,
|
|
|
139 |
do_sample: bool = True,
|
140 |
temperature: float = 1.0,
|
141 |
top_p: float = 1,
|
142 |
top_k: int = 20,
|
143 |
num_beams: int = 1,
|
144 |
bad_words_ids: list[int] | None = None,
|
|
|
|
|
145 |
) -> list[int]:
|
146 |
inputs = tokenizer(
|
147 |
input_text,
|
148 |
return_tensors="pt",
|
149 |
).input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
generated = model.generate(
|
152 |
inputs,
|
153 |
max_new_tokens=max_new_tokens,
|
|
|
154 |
do_sample=do_sample,
|
155 |
temperature=temperature,
|
156 |
top_p=top_p,
|
@@ -159,6 +179,8 @@ def generate(
|
|
159 |
bad_words_ids=(
|
160 |
[[token] for token in bad_words_ids] if bad_words_ids is not None else None
|
161 |
),
|
|
|
|
|
162 |
no_repeat_ngram_size=1,
|
163 |
)[0]
|
164 |
|
@@ -171,7 +193,10 @@ def decode_normal(token_ids: list[int], skip_special_tokens: bool = True):
|
|
171 |
|
172 |
def decode_general_only(token_ids: list[int]):
|
173 |
token_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
|
174 |
-
|
|
|
|
|
|
|
175 |
|
176 |
|
177 |
def split_people_tokens_part(token_ids: list[int]):
|
@@ -242,7 +267,11 @@ def handle_inputs(
|
|
242 |
character_tags_list: list[str],
|
243 |
general_tags: str,
|
244 |
ban_tags: str,
|
|
|
|
|
|
|
245 |
max_new_tokens: int = 128,
|
|
|
246 |
temperature: float = 1.0,
|
247 |
top_p: float = 1.0,
|
248 |
top_k: int = 20,
|
@@ -272,17 +301,29 @@ def handle_inputs(
|
|
272 |
general=general_tags,
|
273 |
)
|
274 |
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
generated_ids = generate(
|
278 |
prompt,
|
279 |
max_new_tokens=max_new_tokens,
|
|
|
280 |
do_sample=True,
|
281 |
temperature=temperature,
|
282 |
top_p=top_p,
|
283 |
top_k=top_k,
|
284 |
num_beams=num_beams,
|
285 |
bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
|
|
|
|
|
286 |
)
|
287 |
|
288 |
decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
|
@@ -334,7 +375,7 @@ def demo():
|
|
334 |
)
|
335 |
copyright_tags_dropdown = gr.Dropdown(
|
336 |
label="Copyright tags",
|
337 |
-
choices=COPYRIGHT_TAGS_LIST,
|
338 |
value=[],
|
339 |
multiselect=True,
|
340 |
visible=False,
|
@@ -363,7 +404,7 @@ def demo():
|
|
363 |
)
|
364 |
character_tags_dropdown = gr.Dropdown(
|
365 |
label="Character tags",
|
366 |
-
choices=CHARACTER_TAGS_LIST,
|
367 |
value=[],
|
368 |
multiselect=True,
|
369 |
visible=False,
|
@@ -371,6 +412,8 @@ def demo():
|
|
371 |
|
372 |
def on_change_character_tags_dropdouwn(mode: str):
|
373 |
kwargs: dict = {"visible": mode == "Custom"}
|
|
|
|
|
374 |
|
375 |
return gr.update(**kwargs)
|
376 |
|
@@ -389,6 +432,37 @@ def demo():
|
|
389 |
)
|
390 |
|
391 |
with gr.Accordion(label="Generation config", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
with gr.Group():
|
393 |
max_new_tokens_slider = gr.Slider(
|
394 |
label="Max new tokens",
|
@@ -397,6 +471,13 @@ def demo():
|
|
397 |
step=1,
|
398 |
value=128,
|
399 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
temperature_slider = gr.Slider(
|
401 |
label="Temperature (larger is more random)",
|
402 |
maximum=1.0,
|
@@ -480,7 +561,11 @@ def demo():
|
|
480 |
character_tags_dropdown,
|
481 |
general_tags_textbox,
|
482 |
ban_tags_textbox,
|
|
|
|
|
|
|
483 |
max_new_tokens_slider,
|
|
|
484 |
temperature_slider,
|
485 |
top_p_slider,
|
486 |
top_k_slider,
|
|
|
10 |
# Checkpoint to load: the MODEL_NAME environment variable when it is set,
# otherwise the default SFT checkpoint. A single lookup with a default
# replaces the previous double lookup of the same variable.
MODEL_NAME = os.environ.get("MODEL_NAME", "p1atdev/dart-test-3-sft-1")
|
15 |
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
|
16 |
|
|
|
43 |
GENERAL_BOS = "<general>"
|
44 |
GENERAL_EOS = "</general>"
|
45 |
|
46 |
+
INPUT_END = "<|input_end|>"
|
47 |
+
|
48 |
RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
|
49 |
RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
|
50 |
COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
|
|
|
54 |
GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
|
55 |
GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
|
56 |
|
57 |
+
INPUT_END_ID = tokenizer.convert_tokens_to_ids(INPUT_END)
|
58 |
+
|
59 |
+
|
60 |
assert isinstance(RATING_BOS_ID, int)
|
61 |
assert isinstance(RATING_EOS_ID, int)
|
62 |
assert isinstance(COPYRIGHT_BOS_ID, int)
|
|
|
65 |
assert isinstance(CHARACTER_EOS_ID, int)
|
66 |
assert isinstance(GENERAL_BOS_ID, int)
|
67 |
assert isinstance(GENERAL_EOS_ID, int)
|
68 |
+
assert isinstance(INPUT_END_ID, int)
|
69 |
|
70 |
SPECIAL_TAGS = [
|
71 |
BOS,
|
|
|
78 |
CHARACTER_EOS,
|
79 |
GENERAL_BOS,
|
80 |
GENERAL_EOS,
|
81 |
+
INPUT_END,
|
82 |
]
|
83 |
|
84 |
# Resolve every special tag to its token id in one tokenizer call.
SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
assert isinstance(SPECIAL_TAG_IDS, list)
# Fail fast if any special tag resolved to the tokenizer's unknown-token id.
# A generator lets ``all`` stop at the first offending id instead of building
# a full list first.
assert all(token_id != tokenizer.unk_token_id for token_id in SPECIAL_TAG_IDS)
|
87 |
|
88 |
RATING_TAGS = {
|
89 |
"sfw": "rating:sfw",
|
|
|
135 |
CHARACTER_EOS,
|
136 |
GENERAL_BOS,
|
137 |
general,
|
138 |
+
INPUT_END,
|
139 |
]
|
140 |
)
|
141 |
|
142 |
|
143 |
@torch.no_grad()
|
144 |
def generate(
|
145 |
+
input_text: str,
|
146 |
+
max_new_tokens: int = 128,
|
147 |
+
min_new_tokens: int = 0,
|
148 |
do_sample: bool = True,
|
149 |
temperature: float = 1.0,
|
150 |
top_p: float = 1,
|
151 |
top_k: int = 20,
|
152 |
num_beams: int = 1,
|
153 |
bad_words_ids: list[int] | None = None,
|
154 |
+
cfg_scale: float = 1.5,
|
155 |
+
negative_input_text: str | None = None,
|
156 |
) -> list[int]:
|
157 |
inputs = tokenizer(
|
158 |
input_text,
|
159 |
return_tensors="pt",
|
160 |
).input_ids
|
161 |
+
negative_inputs = (
|
162 |
+
tokenizer(
|
163 |
+
negative_input_text,
|
164 |
+
return_tensors="pt",
|
165 |
+
).input_ids
|
166 |
+
if negative_input_text is not None
|
167 |
+
else None
|
168 |
+
)
|
169 |
|
170 |
generated = model.generate(
|
171 |
inputs,
|
172 |
max_new_tokens=max_new_tokens,
|
173 |
+
min_new_tokens=min_new_tokens,
|
174 |
do_sample=do_sample,
|
175 |
temperature=temperature,
|
176 |
top_p=top_p,
|
|
|
179 |
bad_words_ids=(
|
180 |
[[token] for token in bad_words_ids] if bad_words_ids is not None else None
|
181 |
),
|
182 |
+
negative_prompt_ids=negative_inputs,
|
183 |
+
guidance_scale=cfg_scale,
|
184 |
no_repeat_ngram_size=1,
|
185 |
)[0]
|
186 |
|
|
|
193 |
|
194 |
def decode_general_only(token_ids: list[int]):
    """Decode the general-tag part of a generation into a sorted tag string.

    Everything before the first ``GENERAL_BOS_ID`` is discarded. The rest is
    decoded with special tokens skipped, split on ``", "``, sorted, and
    joined back with ``", "``.

    Args:
        token_ids: Generated token ids containing ``GENERAL_BOS_ID``.

    Returns:
        The decoded tags in sorted order, separated by ``", "``.

    Raises:
        ValueError: If ``GENERAL_BOS_ID`` does not occur in ``token_ids``.
    """
    general_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
    decoded = tokenizer.decode(general_ids, skip_special_tokens=True)
    # ``sorted`` already returns a new list, so no intermediate copy is needed.
    return ", ".join(sorted(decoded.split(", ")))
|
200 |
|
201 |
|
202 |
def split_people_tokens_part(token_ids: list[int]):
|
|
|
267 |
character_tags_list: list[str],
|
268 |
general_tags: str,
|
269 |
ban_tags: str,
|
270 |
+
do_cfg: bool = False,
|
271 |
+
cfg_scale: float = 1.5,
|
272 |
+
negative_tags: str = "",
|
273 |
max_new_tokens: int = 128,
|
274 |
+
min_new_tokens: int = 0,
|
275 |
temperature: float = 1.0,
|
276 |
top_p: float = 1.0,
|
277 |
top_k: int = 20,
|
|
|
301 |
general=general_tags,
|
302 |
)
|
303 |
|
304 |
+
negative_prompt = compose_prompt(
|
305 |
+
rating=prepare_rating_tags(rating_tags),
|
306 |
+
copyright="",
|
307 |
+
character="",
|
308 |
+
general=negative_tags,
|
309 |
+
)
|
310 |
+
|
311 |
+
bad_words_ids = tokenizer.encode_plus(
|
312 |
+
ban_tags if negative_tags.strip() == "" else ban_tags + ", " + negative_tags
|
313 |
+
).input_ids
|
314 |
|
315 |
generated_ids = generate(
|
316 |
prompt,
|
317 |
max_new_tokens=max_new_tokens,
|
318 |
+
min_new_tokens=min_new_tokens,
|
319 |
do_sample=True,
|
320 |
temperature=temperature,
|
321 |
top_p=top_p,
|
322 |
top_k=top_k,
|
323 |
num_beams=num_beams,
|
324 |
bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
|
325 |
+
cfg_scale=cfg_scale,
|
326 |
+
negative_input_text=negative_prompt if do_cfg else None,
|
327 |
)
|
328 |
|
329 |
decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
|
|
|
375 |
)
|
376 |
copyright_tags_dropdown = gr.Dropdown(
|
377 |
label="Copyright tags",
|
378 |
+
choices=COPYRIGHT_TAGS_LIST, # type: ignore
|
379 |
value=[],
|
380 |
multiselect=True,
|
381 |
visible=False,
|
|
|
404 |
)
|
405 |
character_tags_dropdown = gr.Dropdown(
|
406 |
label="Character tags",
|
407 |
+
choices=CHARACTER_TAGS_LIST, # type: ignore
|
408 |
value=[],
|
409 |
multiselect=True,
|
410 |
visible=False,
|
|
|
412 |
|
413 |
def on_change_character_tags_dropdouwn(mode: str):
    """Build the update for the character-tags dropdown when the mode changes.

    The dropdown is visible only when ``mode`` is ``"Custom"``. When ``mode``
    is ``"None"``, its current selection is also reset to an empty list.
    """
    show_dropdown = mode == "Custom"
    if mode == "None":
        # Clear any previously selected tags along with hiding the control.
        return gr.update(visible=show_dropdown, value=[])
    return gr.update(visible=show_dropdown)
|
419 |
|
|
|
432 |
)
|
433 |
|
434 |
with gr.Accordion(label="Generation config", open=False):
|
435 |
+
with gr.Group():
|
436 |
+
do_cfg_check = gr.Checkbox(
|
437 |
+
label="Do CFG (Classifier Free Guidance)",
|
438 |
+
value=False,
|
439 |
+
)
|
440 |
+
# Slider for the CFG scale value; it starts hidden (visible=False).
# The label previously read "Max new tokens", duplicating the label of the
# max-new-tokens slider and mislabelling this control in the UI.
cfg_scale_slider = gr.Slider(
    label="CFG scale",
    maximum=3.0,
    minimum=0.1,
    step=0.1,
    value=1.5,
    visible=False,
)
|
448 |
+
negative_tags_textbox = gr.Textbox(
|
449 |
+
label="Negative prompt",
|
450 |
+
placeholder="simple background, ...",
|
451 |
+
value="",
|
452 |
+
lines=2,
|
453 |
+
visible=False,
|
454 |
+
)
|
455 |
+
|
456 |
+
def on_change_do_cfg_check(do_cfg: bool):
    """Return two visibility updates that follow the CFG checkbox state.

    Both returned updates set ``visible`` to ``do_cfg``; one update object is
    created for each output component the handler is wired to.
    """
    first_update = gr.update(visible=do_cfg)
    second_update = gr.update(visible=do_cfg)
    return first_update, second_update
|
459 |
+
|
460 |
+
do_cfg_check.change(
|
461 |
+
on_change_do_cfg_check,
|
462 |
+
inputs=[do_cfg_check],
|
463 |
+
outputs=[cfg_scale_slider, negative_tags_textbox],
|
464 |
+
)
|
465 |
+
|
466 |
with gr.Group():
|
467 |
max_new_tokens_slider = gr.Slider(
|
468 |
label="Max new tokens",
|
|
|
471 |
step=1,
|
472 |
value=128,
|
473 |
)
|
474 |
+
min_new_tokens_slider = gr.Slider(
|
475 |
+
label="Min new tokens",
|
476 |
+
maximum=255,
|
477 |
+
minimum=0,
|
478 |
+
step=1,
|
479 |
+
value=0,
|
480 |
+
)
|
481 |
temperature_slider = gr.Slider(
|
482 |
label="Temperature (larger is more random)",
|
483 |
maximum=1.0,
|
|
|
561 |
character_tags_dropdown,
|
562 |
general_tags_textbox,
|
563 |
ban_tags_textbox,
|
564 |
+
do_cfg_check,
|
565 |
+
cfg_scale_slider,
|
566 |
+
negative_tags_textbox,
|
567 |
max_new_tokens_slider,
|
568 |
+
min_new_tokens_slider,
|
569 |
temperature_slider,
|
570 |
top_p_slider,
|
571 |
top_k_slider,
|