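# Gradio demo app for SDPrompt-RetNet-v2-beta: streams Danbooru-tag prompt
# completions from a RetNet language model running on CPU.
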
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    GenerationConfig,
)
from threading import Thread
import gradio as gr
import time

MARKDOWN_DESCRIPTION = """\
# SDPrompt-RetNet-v2-beta

This is a demo of [SDPrompt-RetNet-v2-beta](https://huggingface.co/isek-ai/SDPrompt-RetNet-v2-beta), a RetNet model pretrained on the [danbooru tags](https://huggingface.co/datasets/isek-ai/danbooru-tags-2016-2023) dataset.

The model can only complete tags *after* the input text, so start with a tag such as `1girl` or `1boy, 1girl`. It also generates tags in alphabetical order, so the tags in your input should be in alphabetical order as well.
"""

MODEL_NAME = "isek-ai/SDPrompt-RetNet-v2-beta"
DEVICE = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    trust_remote_code=True,
)
model.config.use_cache = True
model.eval()
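
# Try to speed up inference with torch.compile; if compilation is not
# supported in this environment, fall back to the eager model.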
try:
    model = torch.compile(model)
    print("torch.compile is supported")
    print("warming up...")
    _ = model.generate(
        inputs=tokenizer(
            "1girl, arms behind back, brown eyes, brown hair, cat ears",
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)["input_ids"],
        max_new_tokens=128,
    )
    print("warmed up")
except Exception as e:
    print(e)
    print("torch.compile is not supported")
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=False,
    skip_special_tokens=True,
)
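

# model.generate() blocks until generation finishes, so it runs in a background
# thread while the streamer yields decoded text as tokens are produced.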
async def gen_stream(
    input_text: str,
    generation_config: GenerationConfig,
) -> TextIteratorStreamer:
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    config = dict(
        inputs=inputs["input_ids"],
        streamer=streamer,
        generation_config=generation_config,
    )
    thread = Thread(target=model.generate, kwargs=config)
    thread.start()
    return streamer
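

# Async generator used by Gradio: yields the partial text plus a status message
# for every streamed chunk, then a final message with the elapsed time.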
async def generate(
    input_text: str,
    max_new_tokens: int = 128,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 20,
    top_p: float = 0.95,
    repetition_penalty: float = 1.5,
    no_repeat_ngram_size: int = 3,
):
    start = time.time()
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=1,  # setting this to 2 or more currently raises an error (beam search is not supported)
    )
    stream = await gen_stream(input_text, generation_config)
    generated_text = ""
    for words in stream:
        if words is None or words == "":
            continue
        generated_text += words
        yield generated_text, "Generating..."
    elapsed = time.time() - start
    yield generated_text, f"Elapsed: {elapsed:.2f} sec."


# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
def copy_text(_text: str):
    gr.Info("Copied!")


COPY_ACTION_JS = """\
(inputs, _outputs) => {
    // `inputs` is the current value of the generated-prompt textbox
    navigator.clipboard.writeText(inputs);
}"""
def demo():
    with gr.Blocks() as ui:
        gr.Markdown(MARKDOWN_DESCRIPTION)

        with gr.Column():
            input_text = gr.Textbox(label="Input", value="1girl")

            with gr.Group():
                with gr.Column():
                    with gr.Group():
                        output_text = gr.Text(
                            label="Generated prompt",
                            interactive=False,
                        )
                        copy_btn = gr.Button(
                            value="Copy prompt",
                            variant="secondary",
                        )
                    elapsed = gr.Markdown("")

            with gr.Column():
                btn = gr.Button(value="Generate!", variant="primary")

                with gr.Accordion(label="Advanced settings", open=False):
                    with gr.Group():
                        with gr.Column():
                            max_new_tokens_slider = gr.Slider(
                                label="Max new tokens",
                                minimum=4,
                                maximum=512,
                                step=4,
                                value=160,
                            )
                            do_sample_checkbox = gr.Checkbox(
                                label="Do sample",
                                value=True,
                            )
                            temperature_slider = gr.Slider(
                                label="Temperature",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                value=1.0,
                            )
                            top_k_slider = gr.Slider(
                                label="Top k",
                                minimum=0,
                                maximum=100,
                                step=1,
                                value=20,
                            )
                            top_p_slider = gr.Slider(
                                label="Top p",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                value=0.95,
                            )
                            repetition_penalty_slider = gr.Slider(
                                label="Repetition penalty",
                                minimum=0.0,
                                maximum=2.0,
                                step=0.1,
                                value=1.1,
                            )
                            no_repeat_ngram_size_slider = gr.Slider(
                                label="No repeat n-gram size",
                                minimum=0,
                                maximum=10,
                                step=1,
                                value=6,
                            )

        gr.Examples(
            [
                "1girl, blue hair",
                "1girl, animal ears",
                "1girl, detached sleeves",
                "1boy, 1girl",
                "1other",
            ],
            inputs=[input_text],
        )
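
        # Stream partial generations into the output textbox as the model produces tokens.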
        btn.click(
            fn=generate,
            inputs=[
                input_text,
                max_new_tokens_slider,
                do_sample_checkbox,
                temperature_slider,
                top_k_slider,
                top_p_slider,
                repetition_penalty_slider,
                no_repeat_ngram_size_slider,
            ],
            outputs=[output_text, elapsed],
        )
        copy_btn.click(
            fn=copy_text,
            inputs=[output_text],
            js=COPY_ACTION_JS,
        )

    ui.queue()
    ui.launch(debug=True)


if __name__ == "__main__":
    try:
        demo()
    except KeyboardInterrupt:
        pass