"""Gradio demo that streams tag completions from SDPrompt-RetNet-v2-beta."""

import asyncio
import time
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

MARKDOWN_DESCRIPTION = """\
# SDPrompt-RetNet-v2-beta

This is a demo of [SDPrompt-RetNet-v2-beta](https://huggingface.co/isek-ai/SDPrompt-RetNet-v2-beta), a pretrained RetNet model trained on [danbooru tags](https://huggingface.co/datasets/isek-ai/danbooru-tags-2016-2023) dataset.

This model can only complete tags after the input text, so you have to start with a tag like `1girl` or `1boy, 1girl`.
Also, this model generates tags in alphabetical order, you should place tags in alphabetical order in the input text.
"""

MODEL_NAME = "isek-ai/SDPrompt-RetNet-v2-beta"
DEVICE = "cpu"

# Sentinel handed to ``next()`` so the end of a stream can be detected without
# letting ``StopIteration`` escape from an executor thread.
_STREAM_END = object()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    trust_remote_code=True,
)
model.config.use_cache = True
model.eval()

# Try to compile the model and run one warm-up generation. ``model`` is only
# rebound to the compiled wrapper once the warm-up has succeeded, so a failure
# at either step leaves the original model in place for serving.
try:
    compiled_model = torch.compile(model)
    print("torch.compile is supported")
    print("warming up...")
    _ = compiled_model.generate(
        inputs=tokenizer(
            "1girl, arms behind back, brown eyes, brown hair, cat ears",
            return_tensors="pt",
            add_special_tokens=False,
        ).to(compiled_model.device)["input_ids"],
        max_new_tokens=128,
    )
    print("warmed up")
except Exception as e:
    print(e)
    print("torch.compile is not supported")
else:
    model = compiled_model


def _new_streamer() -> TextIteratorStreamer:
    """Create a fresh streamer so that requests never share a queue."""
    return TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )


async def gen_stream(
    input_text: str,
    generation_config: GenerationConfig,
) -> TextIteratorStreamer:
    """Start generation in a background thread and return its streamer.

    Args:
        input_text: Prompt to continue; tokenized without special tokens.
        generation_config: Sampling settings passed to ``model.generate``.

    Returns:
        A streamer dedicated to this call. Iterating it yields decoded text
        chunks (prompt included, since ``skip_prompt=False``).
    """
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )
    streamer = _new_streamer()

    def _run_generation() -> None:
        try:
            model.generate(
                inputs=inputs["input_ids"],
                streamer=streamer,
                generation_config=generation_config,
            )
        except Exception:
            # Without this the consumer would wait forever for an
            # end-of-stream signal that a failed generation never sends.
            streamer.end()
            raise

    Thread(target=_run_generation).start()
    return streamer


async def generate(
    input_text: str,
    max_new_tokens: int = 128,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 20,
    top_p: float = 0.95,
    repetition_penalty: float = 1.5,
    no_repeat_ngram_size: int = 3,
):
    """Stream generated text for the UI.

    Args:
        input_text: Prompt to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Whether to sample instead of decoding greedily.
        temperature: Sampling temperature.
        top_k: Top-k filtering value.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.
        no_repeat_ngram_size: Size of n-grams that may not repeat.

    Yields:
        ``(text_so_far, status)`` tuples: one per non-empty chunk with the
        status ``"Generating..."``, then a final tuple carrying the elapsed
        time.
    """
    start = time.time()

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        # Setting 2 or more currently raises an error (beam search is unavailable).
        num_beams=1,
    )

    stream = await gen_stream(input_text, generation_config)
    chunks = iter(stream)
    loop = asyncio.get_running_loop()

    generated_text = ""
    while True:
        # Reading from the streamer blocks until the next chunk is ready, so
        # do it in a worker thread instead of stalling the event loop.
        words = await loop.run_in_executor(None, next, chunks, _STREAM_END)
        if words is _STREAM_END:
            break
        if not words:
            continue
        generated_text += words
        yield generated_text, "Generating..."

    elapsed = time.time() - start
    yield generated_text, f"Elapsed: {elapsed:.2f} sec."


# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
def copy_text(_text: str) -> None:
    """Show a confirmation toast; the copy itself is done by COPY_ACTION_JS."""
    gr.Info("Copied!")


COPY_ACTION_JS = """\
(inputs, _outputs) => {
  // inputs is the string value of the input_text
  navigator.clipboard.writeText(inputs);
}"""


def demo() -> None:
    """Build the Gradio interface, wire up its events, and launch it."""
    with gr.Blocks() as ui:
        gr.Markdown(MARKDOWN_DESCRIPTION)

        with gr.Column():
            input_text = gr.Textbox(label="Input", value="1girl")

            with gr.Group():
                with gr.Column():
                    with gr.Group():
                        output_text = gr.Text(
                            label="Generated prompt",
                            interactive=False,
                        )
                        copy_btn = gr.Button(
                            value="Copy prompt",
                            variant="secondary",
                        )
                    elapsed = gr.Markdown("")

            with gr.Column():
                btn = gr.Button(value="Generate!", variant="primary")

                with gr.Accordion(label="Advanced settings", open=False):
                    with gr.Group():
                        with gr.Column():
                            max_new_tokens_slider = gr.Slider(
                                label="Max new tokens",
                                minimum=4,
                                maximum=512,
                                step=4,
                                value=160,
                            )
                            do_sample_checkbox = gr.Checkbox(
                                label="Do sample",
                                value=True,
                            )
                            temperature_slider = gr.Slider(
                                label="Temperature",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                value=1.0,
                            )
                            top_k_slider = gr.Slider(
                                label="Top k",
                                minimum=0,
                                maximum=100,
                                step=1,
                                value=20,
                            )
                            top_p_slider = gr.Slider(
                                label="Top p",
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                value=0.95,
                            )
                            repetition_penalty_slider = gr.Slider(
                                label="Repetition penalty",
                                minimum=0.0,
                                maximum=2.0,
                                step=0.1,
                                value=1.1,
                            )
                            no_repeat_ngram_size_slider = gr.Slider(
                                label="No repeat n-gram size",
                                minimum=0,
                                maximum=10,
                                step=1,
                                value=6,
                            )

            gr.Examples(
                [
                    "1girl, blue hair",
                    "1girl, animal ears",
                    "1girl, detached sleeves",
                    "1boy, 1girl",
                    "1other",
                ],
                inputs=[input_text],
            )

        btn.click(
            fn=generate,
            inputs=[
                input_text,
                max_new_tokens_slider,
                do_sample_checkbox,
                temperature_slider,
                top_k_slider,
                top_p_slider,
                repetition_penalty_slider,
                no_repeat_ngram_size_slider,
            ],
            outputs=[output_text, elapsed],
        )

        copy_btn.click(
            fn=copy_text,
            inputs=[output_text],
            js=COPY_ACTION_JS,
        )

    ui.queue()
    ui.launch(debug=True)


if __name__ == "__main__":
    try:
        demo()
    except KeyboardInterrupt:
        pass