File size: 5,939 Bytes
6ae32bf
690d6e4
 
 
 
 
 
 
6ae32bf
690d6e4
 
 
772f8c2
 
6ae32bf
690d6e4
 
 
 
 
885aeb8
6ae32bf
 
 
 
 
 
690d6e4
 
 
 
 
 
 
 
6ae32bf
c5bb7ca
70bee57
6ae32bf
 
70bee57
 
6ae32bf
70bee57
 
6ae32bf
70bee57
 
6ae32bf
70bee57
 
6ae32bf
 
 
 
 
 
690d6e4
 
 
 
 
6bdb935
6ae32bf
 
 
f6bc3e9
 
c5bb7ca
 
6bdb935
c5bb7ca
 
 
 
 
 
 
 
6ae32bf
690d6e4
 
 
6ae32bf
effd684
690d6e4
 
 
 
 
 
 
 
 
 
 
 
 
 
6ae32bf
9d40320
70bee57
 
 
9d40320
effd684
690d6e4
 
 
 
 
 
 
 
 
 
c5bb7ca
690d6e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
import random
import uuid

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer

# Both models are loaded once, at import time, and shared by every request.
# The diffusion pipeline is pinned to the first CUDA device; the LLM's placement
# is left to `device_map="auto"`.
device = torch.device("cuda:0")

llm = AutoModelForCausalLM.from_pretrained(
    "Azure99/blossom-v5.1-9b",
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5.1-9b")
diffusion_pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
).to(device)


def get_input_ids(inst, bot_prefix):
    """Tokenize a single-turn chat prompt for the Blossom LLM.

    The instruction is placed in the human turn, and ``bot_prefix`` is written
    at the start of the bot turn so that generation continues from it.

    Args:
        inst: Text of the human turn.
        bot_prefix: Text the bot's reply is forced to begin with.

    Returns:
        The token ids returned by ``tokenizer.encode`` (special tokens added).
    """
    system_line = ("A chat between a human and an artificial intelligence bot. "
                   "The bot gives helpful, detailed, and polite answers to the human's questions.\n")
    conversation = f"{system_line}|Human|: {inst}\n|Bot|: {bot_prefix}"
    return tokenizer.encode(conversation, add_special_tokens=True)


def save_image(img):
    """Write ``img`` to a fresh, randomly named PNG in the working directory.

    Args:
        img: Any object with a ``save(path)`` method (e.g. a PIL image).

    Returns:
        The relative filename the image was saved under.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename


# Few-shot instruction for the LLM, written in Chinese. In English it asks the
# model to:
#   1. extract the picture description ("description") from the user's
#      [drawing request];
#   2. translate that description into English ("en_description");
#   3. expand the English text with plenty of detail that matches human
#      intuition ("expanded_description").
# The [output] must be one complete JSON object containing those three string
# fields, with no explanation or unrelated content. Four worked examples follow.
# `$USER_PROMPT` is a placeholder that `generate` replaces with the
# JSON-encoded user prompt.
LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
[输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
    
下面是一些示例:
[作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
[输出]->{"description": "落霞与孤鹜齐飞,秋水共长天一色", "en_description": "The setting sun and the solitary duck fly together, the autumn water shares a single hue with the vast sky", "expanded_description": "A lone duck gracefully gliding across the tranquil surface of a shimmering lake, bathed in the warm golden glow of the setting sun, creating a breathtaking scene of natural beauty and tranquility."}
    
[作画要求]->"原神中的可莉"
[输出]->{"description": "原神中的可莉", "en_description": "Klee in Genshin Impact", "expanded_description": "An artistic portrait of Klee from Genshin Impact, standing in a vibrant meadow with colorful explosions of her elemental abilities in the background."}
    
[作画要求]->"create an image for me. a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
[输出]->{"description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "en_description": "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone", "expanded_description": "A close-up portrait of an elegant woman with rich brown skin, wearing a stunning transparent, prismatic, and intricately detailed Nemes headdress, striking a confident and alluring over-the-shoulder pose."}
    
[作画要求]->"一只高贵的柯基犬,素描画风格\n根据上面的描述生成一张图片吧!"
[输出]->{"description": "一只高贵的柯基犬,素描画风格", "en_description": "A noble corgi dog, sketch style", "expanded_description": "A majestic corgi with a regal bearing, depicted in a detailed and intricate pencil sketch, capturing the essence of its noble lineage and dignified presence."}
    
[作画要求]->$USER_PROMPT
[输出]->'''

# Pre-filled start of the bot's reply. It forces the model to open a JSON object
# whose first key is "description". `generate` prepends it back onto the decoded
# completion so the text can be parsed as JSON.
BOT_PREFIX = '{"description": "'


def _parse_llm_result(llm_result, fallback):
    """Extract the English and expanded descriptions from the LLM's JSON reply.

    Args:
        llm_result: Decoded LLM output, expected to start with a JSON object
            containing ``en_description`` and ``expanded_description`` strings.
        fallback: Text used for any field that cannot be recovered.

    Returns:
        A ``(en_prompt, expanded_prompt)`` tuple. Each element is the parsed
        field when it is a non-blank string, otherwise ``fallback``.
    """
    try:
        # raw_decode parses the first JSON value and ignores anything the
        # sampled model appends after it; json.loads would reject that.
        parsed, _ = json.JSONDecoder().raw_decode(llm_result)
    except json.JSONDecodeError:
        parsed = None
    if not isinstance(parsed, dict):
        print("error, fallback to original prompt")
        return fallback, fallback

    def _field(key):
        # Each field falls back independently, so one bad field does not
        # discard the other.
        value = parsed.get(key)
        if isinstance(value, str) and value.strip():
            return value
        print(f"error, missing {key!r}, fallback to original prompt")
        return fallback

    return _field("en_description"), _field("expanded_description")


@spaces.GPU(enable_queue=True)
def generate(
        prompt: str,
        progress=gr.Progress(track_tqdm=True),
):
    """Rewrite ``prompt`` with the LLM, then render two images with Playground v2.5.

    The LLM is asked (via LLM_PROMPT) for an English translation and an
    expanded description of the request. One image is generated from each.
    Any field the LLM reply does not provide falls back to the raw prompt.

    Args:
        prompt: The user's drawing request, in any language.
        progress: Gradio progress tracker supplied by Gradio. It is not
            referenced in this body.

    Returns:
        Paths of the saved PNG files: the expanded-description image first,
        then the English-description image.
    """
    input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
    input_tensor = torch.tensor([input_ids]).to(llm.device)
    output_ids = llm.generate(
        input_ids=input_tensor,
        # A single unpadded sequence: every position is real, so pass an
        # explicit all-ones mask rather than letting transformers infer one.
        attention_mask=torch.ones_like(input_tensor),
        do_sample=True,
        max_new_tokens=512,
        temperature=0.5,
        top_p=0.85,
        top_k=50,
        repetition_penalty=1.05,
    )
    # Keep only the newly generated tokens. Re-attach the forced prefix so the
    # text is a complete JSON object again.
    completion_ids = output_ids.cpu()[0][len(input_ids):]
    llm_result = BOT_PREFIX + tokenizer.decode(completion_ids, skip_special_tokens=True)
    print("----------")
    print(prompt)
    print(llm_result)
    en_prompt, expanded_prompt = _parse_llm_result(llm_result, prompt)

    seed = random.randint(0, 2147483647)
    generator = torch.Generator().manual_seed(seed)

    images = diffusion_pipe(
        prompt=[expanded_prompt, en_prompt],
        negative_prompt=None,
        width=1024,
        height=1024,
        guidance_scale=3,
        num_inference_steps=25,
        generator=generator,
        num_images_per_prompt=1,
        use_resolution_binning=True,
        output_type="pil",
    ).images

    return [save_image(img) for img in images]


# Page CSS: caps the app width at 560px and centers headings.
css = (
    "\n"
    "    .gradio-container{max-width: 560px !important}\n"
    "    h1{text-align:center}\n"
    "    "
)

# Layout: a title, a row holding the prompt box and Run button, and a gallery
# with two columns for the generated images.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Blossom & Playground v2.5")
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                placeholder="Enter your prompt",
                label="Prompt",
                show_label=False,
                container=False,
                max_lines=1,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(show_label=False, label="Result", rows=1, columns=2)

    # Submitting the textbox and clicking Run both call `generate`; the
    # endpoint is also exposed under the API name "run".
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=[prompt],
        outputs=[result],
        api_name="run",
    )

if __name__ == "__main__":
    # Enable the request queue (max_size=20), then start the Gradio server.
    app = demo.queue(max_size=20)
    app.launch()