# app.py from Azure99's Hugging Face Space.
# Hub page metadata at capture time: last commit 772f8c2 ("Update app.py", verified);
# malware scan: no virus found; file size 5.94 kB.
import json
import random
import uuid
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository ids for the two models this app serves.
_LLM_REPO = "Azure99/blossom-v5.1-9b"
_DIFFUSION_REPO = "playgroundai/playground-v2.5-1024px-aesthetic"

device = torch.device("cuda:0")

# Chat LLM used to translate/expand user prompts, loaded in fp16; with
# device_map="auto" transformers chooses where the weights are placed.
llm = AutoModelForCausalLM.from_pretrained(
    _LLM_REPO,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_LLM_REPO)

# Text-to-image pipeline, loaded from the repo's fp16 safetensors variant and
# moved explicitly onto `device`.
diffusion_pipe = DiffusionPipeline.from_pretrained(
    _DIFFUSION_REPO,
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
    variant="fp16",
).to(device)
def get_input_ids(inst, bot_prefix):
    """Tokenize a single-turn Blossom chat prompt.

    Args:
        inst: the human turn, placed after ``|Human|:``.
        bot_prefix: text placed after ``|Bot|:`` so that generation continues
            from it (the caller re-attaches it to the decoded output).

    Returns:
        The token ids from ``tokenizer.encode`` with special tokens added.
    """
    system_header = (
        "A chat between a human and an artificial intelligence bot. "
        "The bot gives helpful, detailed, and polite answers to the human's questions.\n"
    )
    dialogue = f"|Human|: {inst}\n|Bot|: {bot_prefix}"
    return tokenizer.encode(system_header + dialogue, add_special_tokens=True)
def save_image(img):
    """Save *img* as a PNG with a random UUID filename in the working directory.

    Args:
        img: any object exposing ``save(path)``; here, PIL images from the pipeline.

    Returns:
        The filename that was written (``<uuid4>.png``).
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
# Few-shot instruction sent to the LLM (written in Chinese). English gist:
#   "Extract the picture description (description) from the input [drawing
#   request], translate it to English (en_description), then expand
#   en_description (expanded_description) with enough detail that still matches
#   human first intuition. The [output] is one JSON object with the three string
#   fields description, en_description and expanded_description; output only a
#   complete JSON, with no explanation or unrelated content. Examples follow:"
# Four request -> output examples follow. The placeholder $USER_PROMPT is
# substituted by generate() with the JSON-encoded user prompt.
# NOTE(review): example 3 copies user typos ("should pose", "nemeses") verbatim;
# editing them changes the prompt the model sees.
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
[输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
下面是一些示例:
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
[输出]->{"description": "落霞与孤鹜齐飞,秋水共长天一色", "en_description": "The setting sun and the solitary duck fly together, the autumn water shares a single hue with the vast sky", "expanded_description": "A lone duck gracefully gliding across the tranquil surface of a shimmering lake, bathed in the warm golden glow of the setting sun, creating a breathtaking scene of natural beauty and tranquility."}
[作画要求]->"原神中的可莉"
[输出]->{"description": "原神中的可莉", "en_description": "Klee in Genshin Impact", "expanded_description": "An artistic portrait of Klee from Genshin Impact, standing in a vibrant meadow with colorful explosions of her elemental abilities in the background."}
[作画要求]->"create an image for me. a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
[输出]->{"description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "en_description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "expanded_description": "A close-up portrait of an elegant woman with rich brown skin, wearing a stunning transparent, prismatic, and intricately detailed Nemes headdress, striking a confident and alluring over-the-shoulder pose."}
[作画要求]->"一只高贵的柯基犬,素描画风格\n根据上面的描述生成一张图片吧!"
[输出]->{"description": "一只高贵的柯基犬,素描画风格", "en_description": "A noble corgi dog, sketch style", "expanded_description": "A majestic corgi with a regal bearing, depicted in a detailed and intricate pencil sketch, capturing the essence of its noble lineage and dignified presence."}
[作画要求]->$USER_PROMPT
[输出]->'''
# Seeded as the start of the bot's reply so the model continues directly inside
# the JSON object; generate() prepends it to the decoded output before parsing.
BOT_PREFIX = '{"description": "'
def _parse_llm_result(llm_result, fallback):
    """Extract ``(en_prompt, expanded_prompt)`` from the LLM's JSON reply.

    Only the first complete JSON value at the start of ``llm_result`` is
    decoded, so any extra text the sampled model produced after the object is
    ignored instead of making the whole reply unusable.

    Args:
        llm_result: decoded LLM output, expected to start with a JSON object
            holding string fields ``en_description`` and ``expanded_description``.
        fallback: prompt to use when the reply cannot be used (the user's input).

    Returns:
        ``(en_prompt, expanded_prompt)``. An invalid/missing/empty
        ``en_description`` becomes ``fallback``; an invalid/missing/empty
        ``expanded_description`` becomes the resolved English prompt. If the
        reply is not a JSON object at all, both are ``fallback``.
    """
    try:
        data, _ = json.JSONDecoder().raw_decode(llm_result.lstrip())
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = None
    if not isinstance(data, dict):
        print("error, fallback to original prompt")
        return fallback, fallback

    def _text_field(key, default):
        # Only accept non-blank strings; anything else would break the pipeline prompt.
        value = data.get(key)
        return value if isinstance(value, str) and value.strip() else default

    en_prompt = _text_field("en_description", fallback)
    expanded_prompt = _text_field("expanded_description", en_prompt)
    return en_prompt, expanded_prompt


@spaces.GPU(enable_queue=True)
def generate(
    prompt: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Turn a user drawing request into two images.

    The LLM is prompted with the few-shot LLM_PROMPT to return JSON with an
    English translation and an expanded English description of *prompt*.
    Both strings are then rendered in a single Playground v2.5 batch
    (requested at 1024x1024) with a fresh random seed. If the LLM reply is
    unusable, the original prompt is used instead.

    Args:
        prompt: free-form user request.
        progress: Gradio progress tracker; never used directly, but with
            ``track_tqdm=True`` Gradio reports the pipelines' tqdm progress.

    Returns:
        Paths of the saved PNGs, in prompt order:
        ``[expanded-description image, English-translation image]``.
    """
    llm_input = LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False))
    input_ids = get_input_ids(llm_input, BOT_PREFIX)
    output_ids = llm.generate(
        input_ids=torch.tensor([input_ids]).to(llm.device),
        do_sample=True,
        max_new_tokens=512,
        temperature=0.5,
        top_p=0.85,
        top_k=50,
        repetition_penalty=1.05,
    )
    # Keep only the newly generated tokens; they continue after BOT_PREFIX,
    # so the prefix is re-attached to form a complete JSON object.
    new_ids = output_ids.cpu()[0][len(input_ids):]
    llm_result = BOT_PREFIX + tokenizer.decode(new_ids, skip_special_tokens=True)
    print("----------")
    print(prompt)
    print(llm_result)

    en_prompt, expanded_prompt = _parse_llm_result(llm_result, prompt)

    # New seed per request; one CPU generator drives the whole two-prompt batch.
    seed = random.randint(0, 2147483647)
    generator = torch.Generator().manual_seed(seed)
    images = diffusion_pipe(
        prompt=[expanded_prompt, en_prompt],
        negative_prompt=None,
        width=1024,
        height=1024,
        guidance_scale=3,
        num_inference_steps=25,
        generator=generator,
        num_images_per_prompt=1,
        use_resolution_binning=True,
        output_type="pil",
    ).images
    return [save_image(img) for img in images]
# Narrow, centered layout: cap the app width and center the page title.
css = (
    "\n"
    ".gradio-container{max-width: 560px !important}\n"
    "h1{text-align:center}\n"
)
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Blossom & Playground v2.5")
    with gr.Group():
        # Prompt box and a compact Run button on one row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # Two columns: generate() returns one image per rewritten prompt.
        result = gr.Gallery(label="Result", columns=2, rows=1, show_label=False)
    # Pressing Enter in the prompt box or clicking Run calls the same handler,
    # also exposed as the "run" API endpoint.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=[prompt],
        outputs=[result],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending jobs) and start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()