from pathlib import Path
import time
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from optimum.onnxruntime import ORTModelForCausalLM
import gradio as gr
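# Model repo and Hugging Face read token come from environment variables
# (e.g. Space secrets); fall back to the public dart-v1-sft checkpoint.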
MODEL_NAME = os.environ.get("MODEL_NAME", "p1atdev/dart-v1-sft")
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
assert isinstance(MODEL_NAME, str)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
token=HF_READ_TOKEN,
)
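# Load the same checkpoint with three backends: plain PyTorch (transformers),
# ONNX Runtime, and a quantized ONNX export.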
model = {
"default": AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
use_auth_token=HF_READ_TOKEN,
),
"ort": ORTModelForCausalLM.from_pretrained(
MODEL_NAME,
use_auth_token=HF_READ_TOKEN,
),
"ort_qantized": ORTModelForCausalLM.from_pretrained(
MODEL_NAME,
file_name="model_quantized.onnx",
token=HF_READ_TOKEN,
),
}
MODEL_BACKEND_MAP = {
"Default": "default",
"ONNX (normal)": "ort",
"ONNX (quantized)": "ort_qantized",
}
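# Move the PyTorch backend to the GPU and compile it when available;
# otherwise fall back with a message.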
try:
    model["default"].to("cuda")
except Exception:
    print("No GPU")
try:
    model["default"] = torch.compile(model["default"])
except Exception:
    print("torch.compile is not supported")
BOS = "<|bos|>"
EOS = "<|eos|>"
RATING_BOS = "<rating>"
RATING_EOS = "</rating>"
COPYRIGHT_BOS = "<copyright>"
COPYRIGHT_EOS = "</copyright>"
CHARACTER_BOS = "<character>"
CHARACTER_EOS = "</character>"
GENERAL_BOS = "<general>"
GENERAL_EOS = "</general>"
INPUT_END = "<|input_end|>"
LENGTH_VERY_SHORT = "<|very_short|>"
LENGTH_SHORT = "<|short|>"
LENGTH_LONG = "<|long|>"
LENGTH_VERY_LONG = "<|very_long|>"
RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
COPYRIGHT_EOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_EOS)
CHARACTER_BOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_BOS)
CHARACTER_EOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_EOS)
GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
assert isinstance(RATING_BOS_ID, int)
assert isinstance(RATING_EOS_ID, int)
assert isinstance(COPYRIGHT_BOS_ID, int)
assert isinstance(COPYRIGHT_EOS_ID, int)
assert isinstance(CHARACTER_BOS_ID, int)
assert isinstance(CHARACTER_EOS_ID, int)
assert isinstance(GENERAL_BOS_ID, int)
assert isinstance(GENERAL_EOS_ID, int)
SPECIAL_TAGS = [
BOS,
EOS,
RATING_BOS,
RATING_EOS,
COPYRIGHT_BOS,
COPYRIGHT_EOS,
CHARACTER_BOS,
CHARACTER_EOS,
GENERAL_BOS,
GENERAL_EOS,
INPUT_END,
LENGTH_VERY_SHORT,
LENGTH_SHORT,
LENGTH_LONG,
LENGTH_VERY_LONG,
]
SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
assert isinstance(SPECIAL_TAG_IDS, list)
assert all([token_id != tokenizer.unk_token_id for token_id in SPECIAL_TAG_IDS])
RATING_TAGS = {
"sfw": "rating:sfw",
"nsfw": "rating:nsfw",
"general": "rating:general",
"sensitive": "rating:sensitive",
"questionable": "rating:questionable",
"explicit": "rating:explicit",
}
RATING_TAG_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in RATING_TAGS.items()}
LENGTH_TAGS = {
"very short": LENGTH_VERY_SHORT,
"short": LENGTH_SHORT,
"long": LENGTH_LONG,
"very long": LENGTH_VERY_LONG,
}
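# Tag lists are loaded from plain-text files (one tag per line); they populate the
# copyright/character dropdowns and identify people tags (1girl, 1boy, no humans, ...).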
def load_tags(path: str | Path):
if isinstance(path, str):
path = Path(path)
with open(path, "r", encoding="utf-8") as file:
lines = [line.strip() for line in file.readlines() if line.strip() != ""]
return lines
COPYRIGHT_TAGS_LIST: list[str] = load_tags("./tags/copyright.txt")
CHARACTER_TAGS_LIST: list[str] = load_tags("./tags/character.txt")
PEOPLE_TAGS_LIST: list[str] = load_tags("./tags/people.txt")
PEOPLE_TAG_IDS_LIST = tokenizer.convert_tokens_to_ids(PEOPLE_TAGS_LIST)
assert isinstance(PEOPLE_TAG_IDS_LIST, list)
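# Generate tag tokens with the selected backend; passing a negative prompt
# enables classifier-free guidance via negative_prompt_ids and guidance_scale.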
@torch.no_grad()
def generate(
input_text: str,
model_backend: str,
max_new_tokens: int = 128,
min_new_tokens: int = 0,
do_sample: bool = True,
temperature: float = 1.0,
top_p: float = 1,
top_k: int = 20,
num_beams: int = 1,
bad_words_ids: list[int] | None = None,
cfg_scale: float = 1.5,
negative_input_text: str | None = None,
) -> list[int]:
inputs = tokenizer(
input_text,
return_tensors="pt",
).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
negative_inputs = (
tokenizer(
negative_input_text,
return_tensors="pt",
).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
if negative_input_text is not None
else None
)
generated = model[MODEL_BACKEND_MAP[model_backend]].generate(
inputs,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
bad_words_ids=(
[[token] for token in bad_words_ids] if bad_words_ids is not None else None
),
negative_prompt_ids=negative_inputs,
guidance_scale=cfg_scale,
no_repeat_ngram_size=1,
)[0]
return generated.tolist()
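# Decoding helpers: plain text, sorted general tags only, and AnimagineXL v3 ordering.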
def decode_normal(token_ids: list[int], skip_special_tokens: bool = True):
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def decode_general_only(token_ids: list[int]):
token_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
tags = [tag for tag in decoded.split(", ")]
tags = sorted(tags)
return ", ".join(tags)
def split_people_tokens_part(token_ids: list[int]):
people_tokens = []
other_tokens = []
for token in token_ids:
if token in PEOPLE_TAG_IDS_LIST:
people_tokens.append(token)
else:
other_tokens.append(token)
return people_tokens, other_tokens
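# Rebuild the output in AnimagineXL v3 prompt order:
# people, character, copyright, remaining general tags, then rating.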
def decode_animagine(token_ids: list[int]):
def get_part(eos_token_id: int, remains_part: list[int]):
part = []
for i, token_id in enumerate(remains_part):
if token_id == eos_token_id:
return part, remains_part[i:]
part.append(token_id)
raise Exception("The provided EOS token was not found in the token_ids.")
# get each part
rating_part, remains = get_part(RATING_EOS_ID, token_ids)
copyright_part, remains = get_part(COPYRIGHT_EOS_ID, remains)
character_part, remains = get_part(CHARACTER_EOS_ID, remains)
general_part, _ = get_part(GENERAL_EOS_ID, remains)
    # separate people tags (1girl, 1boy, no humans, ...)
people_part, other_general_part = split_people_tokens_part(general_part)
# remove "rating:sfw"
rating_part = [token for token in rating_part if token != RATING_TAG_IDS["sfw"]]
# AnimagineXL v3 style order
rearranged_tokens = (
people_part + character_part + copyright_part + other_general_part + rating_part
)
rearranged_tokens = [
token for token in rearranged_tokens if token not in SPECIAL_TAG_IDS
]
decoded = tokenizer.decode(rearranged_tokens, skip_special_tokens=True)
# fix "nsfw" tag
decoded = decoded.replace("rating:nsfw", "nsfw")
return decoded
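# Expand a single rating choice into "parent, child" form,
# e.g. "general" -> "rating:sfw, rating:general".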
def prepare_rating_tags(rating: str):
tag = RATING_TAGS[rating]
if tag in [RATING_TAGS["general"], RATING_TAGS["sensitive"]]:
parent = RATING_TAGS["sfw"]
else:
parent = RATING_TAGS["nsfw"]
return f"{parent}, {tag}"
def handle_inputs(
rating_tags: str,
copyright_tags_list: list[str],
character_tags_list: list[str],
general_tags: str,
ban_tags: str,
do_cfg: bool = False,
cfg_scale: float = 1.5,
negative_tags: str = "",
total_token_length: str = "long",
max_new_tokens: int = 128,
min_new_tokens: int = 0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 20,
num_beams: int = 1,
model_backend: str = "ONNX (quantized)",
):
"""
Returns:
[
output_tags_natural,
output_tags_general_only,
output_tags_animagine,
input_prompt_raw,
output_tags_raw,
elapsed_time,
output_tags_natural_copy_btn,
output_tags_general_only_copy_btn,
output_tags_animagine_copy_btn
]
"""
start_time = time.time()
copyright_tags = ", ".join(copyright_tags_list)
character_tags = ", ".join(character_tags_list)
token_length_tag = LENGTH_TAGS[total_token_length]
prompt: str = tokenizer.apply_chat_template(
{ # type: ignore
"rating": prepare_rating_tags(rating_tags),
"copyright": copyright_tags,
"character": character_tags,
"general": general_tags,
"length": token_length_tag,
},
tokenize=False,
)
negative_prompt: str = tokenizer.apply_chat_template(
{ # type: ignore
"rating": prepare_rating_tags(rating_tags),
"copyright": "",
"character": "",
"general": negative_tags,
"length": token_length_tag,
},
tokenize=False,
)
bad_words_ids = tokenizer.encode_plus(
ban_tags if negative_tags.strip() == "" else ban_tags + ", " + negative_tags
).input_ids
generated_ids = generate(
prompt,
model_backend=model_backend,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
cfg_scale=cfg_scale,
negative_input_text=negative_prompt if do_cfg else None,
)
decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
decoded_general_only = decode_general_only(generated_ids)
decoded_animagine = decode_animagine(generated_ids)
decoded_raw = decode_normal(generated_ids, skip_special_tokens=False)
end_time = time.time()
elapsed_time = f"Elapsed: {(end_time - start_time) * 1000:.2f} ms"
# update visibility of buttons
set_visible = gr.update(visible=True)
return [
decoded_normal,
decoded_general_only,
decoded_animagine,
prompt,
decoded_raw,
elapsed_time,
set_visible,
set_visible,
set_visible,
]
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
def copy_text(_text: str):
gr.Info("Copied!")
COPY_ACTION_JS = """\
(inputs, _outputs) => {
// inputs is the string value of the input_text
if (inputs.trim() !== "") {
navigator.clipboard.writeText(inputs);
}
}"""
def demo():
with gr.Blocks() as ui:
gr.Markdown(
"""\
# Danbooru Tags Transformer Demo """
)
with gr.Row():
with gr.Column():
with gr.Group():
model_backend_radio = gr.Radio(
label="Model backend",
choices=list(MODEL_BACKEND_MAP.keys()),
value="ONNX (quantized)",
interactive=True,
)
with gr.Group():
rating_dropdown = gr.Dropdown(
label="Rating",
choices=[
"general",
"sensitive",
"questionable",
"explicit",
],
value="general",
)
with gr.Group():
copyright_tags_mode_dropdown = gr.Dropdown(
label="Copyright tags mode",
choices=[
"None",
"Original",
# "Auto", # TODO: implement these modes
# "Random",
"Custom",
],
value="None",
interactive=True,
)
copyright_tags_dropdown = gr.Dropdown(
label="Copyright tags",
choices=COPYRIGHT_TAGS_LIST, # type: ignore
value=[],
multiselect=True,
visible=False,
)
                def on_change_copyright_tags_dropdown(mode: str):
kwargs: dict = {"visible": mode == "Custom"}
if mode == "Original":
kwargs["value"] = ["original"]
elif mode == "None":
kwargs["value"] = []
return gr.update(**kwargs)
with gr.Group():
character_tags_mode_dropdown = gr.Dropdown(
label="Character tags mode",
choices=[
"None",
# "Auto", # TODO: implement these modes
# "Random",
"Custom",
],
value="None",
interactive=True,
)
character_tags_dropdown = gr.Dropdown(
label="Character tags",
choices=CHARACTER_TAGS_LIST, # type: ignore
value=[],
multiselect=True,
visible=False,
)
                def on_change_character_tags_dropdown(mode: str):
kwargs: dict = {"visible": mode == "Custom"}
if mode == "None":
kwargs["value"] = []
return gr.update(**kwargs)
with gr.Group():
general_tags_textbox = gr.Textbox(
label="General tags (the condition to generate tags)",
value="",
placeholder="1girl, ...",
lines=4,
)
ban_tags_textbox = gr.Textbox(
label="Ban tags (tags in this field never appear in generation)",
value="",
placeholder="official alternate cosutme, english text,...",
lines=2,
)
generate_btn = gr.Button("Generate", variant="primary")
with gr.Accordion(label="Generation config (advanced)", open=False):
with gr.Group():
do_cfg_check = gr.Checkbox(
label="Do CFG (Classifier Free Guidance)",
value=False,
)
cfg_scale_slider = gr.Slider(
label="CFG scale",
maximum=3.0,
minimum=0.1,
step=0.1,
value=1.5,
visible=False,
)
negative_tags_textbox = gr.Textbox(
label="Negative prompt",
placeholder="simple background, ...",
value="",
lines=2,
visible=False,
)
def on_change_do_cfg_check(do_cfg: bool):
kwargs: dict = {"visible": do_cfg}
return gr.update(**kwargs), gr.update(**kwargs)
do_cfg_check.change(
on_change_do_cfg_check,
inputs=[do_cfg_check],
outputs=[cfg_scale_slider, negative_tags_textbox],
)
with gr.Group():
total_token_length_radio = gr.Radio(
label="Total token length",
choices=list(LENGTH_TAGS.keys()),
value="long",
)
with gr.Group():
max_new_tokens_slider = gr.Slider(
label="Max new tokens",
maximum=256,
minimum=1,
step=1,
value=128,
)
min_new_tokens_slider = gr.Slider(
label="Min new tokens",
maximum=255,
minimum=0,
step=1,
value=0,
)
temperature_slider = gr.Slider(
label="Temperature (larger is more random)",
maximum=1.0,
minimum=0.0,
step=0.1,
value=1.0,
)
top_p_slider = gr.Slider(
label="Top p (larger is more random)",
maximum=1.0,
minimum=0.0,
step=0.1,
value=1.0,
)
top_k_slider = gr.Slider(
label="Top k (larger is more random)",
maximum=500,
minimum=1,
step=1,
value=100,
)
num_beams_slider = gr.Slider(
label="Number of beams (smaller is more random)",
maximum=10,
minimum=1,
step=1,
value=1,
)
with gr.Column():
with gr.Group():
output_tags_natural = gr.Textbox(
label="Generation result",
# placeholder="tags will be here",
interactive=False,
)
output_tags_natural_copy_btn = gr.Button("Copy", visible=False)
output_tags_natural_copy_btn.click(
fn=copy_text,
inputs=[output_tags_natural],
js=COPY_ACTION_JS,
)
with gr.Group():
output_tags_general_only = gr.Textbox(
label="General tags only (sorted)",
interactive=False,
)
output_tags_general_only_copy_btn = gr.Button("Copy", visible=False)
output_tags_general_only_copy_btn.click(
fn=copy_text,
inputs=[output_tags_general_only],
js=COPY_ACTION_JS,
)
with gr.Group():
output_tags_animagine = gr.Textbox(
label="Output tags (AnimagineXL v3 style order)",
# placeholder="tags will be here in Animagine v3 style order",
interactive=False,
)
output_tags_animagine_copy_btn = gr.Button("Copy", visible=False)
output_tags_animagine_copy_btn.click(
fn=copy_text,
inputs=[output_tags_animagine],
js=COPY_ACTION_JS,
)
with gr.Accordion(label="Metadata", open=False):
input_prompt_raw = gr.Textbox(
label="Input prompt (raw)",
interactive=False,
lines=4,
)
output_tags_raw = gr.Textbox(
label="Output tags (raw)",
interactive=False,
lines=4,
)
elapsed_time_md = gr.Markdown(value="Waiting to generate...")
copyright_tags_mode_dropdown.change(
            on_change_copyright_tags_dropdown,
inputs=[copyright_tags_mode_dropdown],
outputs=[copyright_tags_dropdown],
)
character_tags_mode_dropdown.change(
            on_change_character_tags_dropdown,
inputs=[character_tags_mode_dropdown],
outputs=[character_tags_dropdown],
)
generate_btn.click(
handle_inputs,
inputs=[
rating_dropdown,
copyright_tags_dropdown,
character_tags_dropdown,
general_tags_textbox,
ban_tags_textbox,
do_cfg_check,
cfg_scale_slider,
negative_tags_textbox,
total_token_length_radio,
max_new_tokens_slider,
min_new_tokens_slider,
temperature_slider,
top_p_slider,
top_k_slider,
num_beams_slider,
model_backend_radio,
],
outputs=[
output_tags_natural,
output_tags_general_only,
output_tags_animagine,
input_prompt_raw,
output_tags_raw,
elapsed_time_md,
output_tags_natural_copy_btn,
output_tags_general_only_copy_btn,
output_tags_animagine_copy_btn,
],
)
ui.launch(
share=True,
)
if __name__ == "__main__":
demo()