p1atdev's picture
fix: parameter name
2cc70f5
raw
history blame contribute delete
No virus
6.86 kB
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
TextIteratorStreamer,
GenerationConfig,
)
from threading import Thread
import gradio as gr
import time
# Markdown shown at the top of the demo page. The dataset link is an absolute
# URL: a bare "owner/name" target would be resolved relative to the page that
# hosts the app and lead nowhere.
MARKDOWN_DESCRIPTION = """\
# SDPrompt-RetNet-v2-beta
This is a demo of [SDPrompt-RetNet-v2-beta](https://huggingface.co/isek-ai/SDPrompt-RetNet-v2-beta), a pretrained RetNet model trained on [danbooru tags](https://huggingface.co/datasets/isek-ai/danbooru-tags-2016-2023) dataset.
This model can only complete tags after the input text, so you have to start with a tag like `1girl` or `1boy, 1girl`. Also, this model generates tags in alphabetical order, you should place tags in alphabetical order in the input text.
"""
# Model repository id handed to the `from_pretrained` loaders.
MODEL_NAME = "isek-ai/SDPrompt-RetNet-v2-beta"
# Device the model is placed on (passed as `device_map`).
DEVICE = "cpu"
# Load the tokenizer and the causal LM once, at import time, so every request
# reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    # NOTE(review): this opts in to running modelling code shipped with the
    # model repository -- confirm the repository is trusted.
    trust_remote_code=True,
)
# Enable the model's cache during generation and switch to inference mode.
model.config.use_cache = True
model.eval()
# Optional speed-up: try torch.compile and run one warm-up generation.
# The compiled wrapper is kept in a temporary name and only promoted to
# `model` once the warm-up call has succeeded. Rebinding `model` up front would
# leave the application serving a model that just failed, even though the
# handler below reports that compilation is unsupported.
try:
    _compiled_model = torch.compile(model)
    print("torch.compile is supported")
    print("warming up...")
    _warmup_ids = tokenizer(
        "1girl, arms behind back, brown eyes, brown hair, cat ears",
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)["input_ids"]
    _ = _compiled_model.generate(
        inputs=_warmup_ids,
        max_new_tokens=128,
    )
    print("warmed up")
except Exception as e:
    # Best effort: report the problem and keep using the uncompiled model.
    print(e)
    print("torch.compile is not supported")
else:
    model = _compiled_model
# Module-level text streamer built on the shared tokenizer.
# skip_prompt=False: presumably the prompt text is part of the streamed output.
# skip_special_tokens=True: presumably special tokens are dropped when decoding.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=False,
    skip_special_tokens=True,
)
async def gen_stream(
    input_text: str,
    generation_config: GenerationConfig,
) -> TextIteratorStreamer:
    """Start generation in a background thread and return its text stream.

    Args:
        input_text: Prompt to continue; tokenized without special tokens.
        generation_config: Generation settings forwarded to ``model.generate``.

    Returns:
        A streamer dedicated to this call. Iterating it yields decoded text
        chunks as the background thread produces them.
    """
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )

    # One streamer per call: a shared module-level streamer means a shared
    # queue, so chunks from an abandoned or overlapping generation would leak
    # into the next caller's output.
    request_streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )

    def _run_generation() -> None:
        """Run ``model.generate`` and always terminate the stream on failure."""
        try:
            model.generate(
                inputs=inputs["input_ids"],
                streamer=request_streamer,
                generation_config=generation_config,
            )
        except Exception:
            # Without this the consumer would wait forever on a stream that
            # will never receive its end-of-stream marker.
            request_streamer.end()
            raise

    thread = Thread(target=_run_generation)
    thread.start()

    return request_streamer
async def generate(
    input_text: str,
    max_new_tokens: int = 128,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 20,
    top_p: float = 0.95,
    repetition_penalty: float = 1.5,
    no_repeat_ngram_size: int = 3,
):
    """Stream a completion for ``input_text``.

    Args:
        input_text: Prompt to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample when true, decode greedily otherwise.
        temperature: Sampling temperature; a non-positive value falls back to
            greedy decoding.
        top_k: Top-k sampling cut-off.
        top_p: Nucleus sampling cut-off.
        repetition_penalty: Penalty applied to repeated tokens; must be > 0.
        no_repeat_ngram_size: Size of n-grams that may not repeat.

    Yields:
        ``(text_so_far, status)`` tuples. The status is ``"Generating..."``
        while streaming and the elapsed time on the final item.

    Raises:
        ValueError: If ``repetition_penalty`` is not strictly positive.
    """
    import asyncio

    start = time.time()

    # Normalise numeric types up front: UI widgets may deliver whole numbers as
    # int, and transformers validates some of these values with strict float
    # checks (see the logits-processor documentation).
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0.0:
        # Fail fast instead of letting the background generation die.
        raise ValueError("repetition_penalty must be a strictly positive number")

    # A temperature of zero is not a valid sampling temperature; the
    # conventional meaning is "pick the most likely token", i.e. greedy.
    if temperature <= 0.0:
        do_sample = False

    config_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=bool(do_sample),
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        # For now, setting 2 or more raises an error (beam search is unusable).
        num_beams=1,
    )
    if do_sample:
        # Sampling-only knobs are passed only when they are actually used.
        config_kwargs.update(
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=temperature,
        )
    generation_config = GenerationConfig(**config_kwargs)

    stream = await gen_stream(input_text, generation_config)

    # Pulling the next chunk blocks until the model produces one. Do that in
    # the default executor so this coroutine does not stall the event loop.
    # A sentinel is used because StopIteration cannot be propagated cleanly
    # out of an executor call.
    loop = asyncio.get_running_loop()
    chunks = iter(stream)
    exhausted = object()

    generated_text = ""
    while True:
        words = await loop.run_in_executor(None, next, chunks, exhausted)
        if words is exhausted:
            break
        if not words:
            continue
        generated_text += words
        yield generated_text, "Generating..."

    elapsed = time.time() - start
    yield generated_text, f"Elapsed: {elapsed:.2f} sec."
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
def copy_text(_text: str):
    """Show a confirmation toast; the argument itself is not used here."""
    toast_message = "Copied!"
    gr.Info(toast_message)
# JavaScript snippet passed as `js=` to the copy button's click handler in
# demo(). It writes its first argument to the clipboard; the component wired
# as that input in demo() is `output_text`.
COPY_ACTION_JS = """\
(inputs, _outputs) => {
// inputs is the string value of the input_text
navigator.clipboard.writeText(inputs);
}"""
def demo():
    """Build the Gradio UI, wire its events and launch the app.

    The "Generate!" button streams results from ``generate`` into the output
    box and the elapsed-time label; the "Copy prompt" button copies the
    generated text to the clipboard and shows a toast.
    """
    with gr.Blocks() as ui:
        gr.Markdown(MARKDOWN_DESCRIPTION)

        with gr.Column():
            input_text = gr.Textbox(label="Input", value="1girl")

            with gr.Group():
                with gr.Column():
                    with gr.Group():
                        output_text = gr.Text(
                            label="Generated prompt",
                            interactive=False,
                        )
                        copy_btn = gr.Button(
                            value="Copy prompt",
                            variant="secondary",
                        )
                    elapsed = gr.Markdown("")

        with gr.Column():
            btn = gr.Button(value="Generate!", variant="primary")

            with gr.Accordion(label="Advanced settings", open=False):
                with gr.Group():
                    with gr.Column():
                        max_new_tokens_slider = gr.Slider(
                            label="Max new tokens",
                            minimum=4,
                            maximum=512,
                            step=4,
                            value=160,
                        )
                        do_sample_checkbox = gr.Checkbox(
                            label="Do sample",
                            value=True,
                        )
                        temperature_slider = gr.Slider(
                            label="Temperature",
                            # Zero is not a valid sampling temperature, so the
                            # range starts at the smallest step instead.
                            minimum=0.01,
                            maximum=1.0,
                            step=0.01,
                            value=1.0,
                        )
                        top_k_slider = gr.Slider(
                            label="Top k",
                            minimum=0,
                            maximum=100,
                            step=1,
                            value=20,
                        )
                        top_p_slider = gr.Slider(
                            label="Top p",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.95,
                        )
                        repetition_penalty_slider = gr.Slider(
                            label="Repetition penalty",
                            # The penalty must be strictly positive, so the
                            # range starts at the smallest step instead of 0.
                            minimum=0.1,
                            maximum=2.0,
                            step=0.1,
                            value=1.1,
                        )
                        no_repeat_ngram_size_slider = gr.Slider(
                            label="No repeat n-gram size",
                            minimum=0,
                            maximum=10,
                            step=1,
                            value=6,
                        )

        gr.Examples(
            [
                "1girl, blue hair",
                "1girl, animal ears",
                "1girl, detached sleeves",
                "1boy, 1girl",
                "1other",
            ],
            inputs=[input_text],
        )

        # The input order must match the parameter order of `generate`.
        btn.click(
            fn=generate,
            inputs=[
                input_text,
                max_new_tokens_slider,
                do_sample_checkbox,
                temperature_slider,
                top_k_slider,
                top_p_slider,
                repetition_penalty_slider,
                no_repeat_ngram_size_slider,
            ],
            outputs=[output_text, elapsed],
        )

        # The clipboard write happens in the browser (js); the Python handler
        # only shows the confirmation toast.
        copy_btn.click(
            fn=copy_text,
            inputs=[output_text],
            js=COPY_ACTION_JS,
        )

    ui.queue()
    ui.launch(debug=True)
# Script entry point: build and launch the demo only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    try:
        demo()
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the server; exit without a traceback.
        pass