File size: 3,953 Bytes
94c2073
719e8a8
94c2073
 
 
 
 
 
95174b6
 
 
 
 
 
 
 
 
 
 
94c2073
 
 
 
 
 
 
 
 
 
 
 
 
 
719e8a8
94c2073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95174b6
 
94c2073
 
 
55427bc
94c2073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95174b6
94c2073
 
 
 
 
 
 
 
 
55427bc
94c2073
 
 
95174b6
94c2073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b961cc
94c2073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import random
import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline

from scheduling_tcd import TCDScheduler

# Page CSS for the Gradio app: center the <h1>/<h3> headings used in the
# title and description.
css = """
h1 {
    text-align: center;
    display:block;
}
h3 {
    text-align: center;
    display:block;
}
"""

# Model configuration. This runs once at import time; the pipeline is moved
# to `device`.
device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model id passed to from_pretrained
tcd_lora_id = "h1t/TCD-SDXL-LoRA"  # TCD LoRA weights applied on top of the base model

# Load the SDXL base pipeline in half precision, requesting the checkpoint's
# "fp16" weight variant, and move it to `device`.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    variant="fp16"
).to(device)
# Replace the pipeline's default scheduler with TCDScheduler, built from the
# existing scheduler's config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

# Load the TCD LoRA, then fuse it into the pipeline's weights (diffusers
# fuse_lora) so every inference call uses the TCD-adapted model.
pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

@spaces.GPU
def inference(prompt, num_inference_steps=4, seed=-1, eta=0.3):
    """Generate a single image from ``prompt`` with the TCD SDXL pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of sampling steps. Coerced to ``int``
            because UI widgets may deliver numbers as floats, and the
            scheduler needs an integer step count.
        seed: RNG seed. ``None``, ``''`` or ``-1`` selects a random seed;
            any other value is converted with ``int()`` (so float seeds from
            a numeric input box are accepted).
        eta: Forwarded unchanged to the pipeline as ``eta`` (exposed as
            "Gamma" in the UI).

    Returns:
        The first image of the pipeline's output.
    """
    if seed is None or seed == '' or seed == -1:
        # No explicit seed: draw one so the run is still seeded.
        seed = random.randrange(4294967294)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        # Classifier-free guidance is disabled for this distilled model.
        guidance_scale=0,
        eta=eta,
        generator=generator,
    ).images[0]
    return image


# Define style
title = "<h1>Trajectory Consistency Distillation</h1>"
description = "<h3>Official 🤗 Gradio demo for Trajectory Consistency Distillation</h3>"
# NOTE(review): the arXiv link below has no paper id after /abs/ — TODO fill it in.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/' target='_blank'>Trajectory Consistency Distillation</a> | <a href='https://github.com/jabir-zheng/TCD' target='_blank'>Github Repo</a></p>"


default_prompt = " "
# Each example row fills (prompt, num_inference_steps); the seed and gamma
# controls keep whatever values the user currently has.
examples = [
    [
        "Beautiful woman, bubblegum pink, lemon yellow, minty blue, futuristic, high-detail, epic composition, watercolor.",
        4
    ],
    [
        "Beautiful man, bubblegum pink, lemon yellow, minty blue, futuristic, high-detail, epic composition, watercolor.",
        8
    ],
    [
        "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.",
        16
    ],
    [
        "closeup portrait of 1 Persian princess, royal clothing, makeup, jewelry, wind-blown long hair, symmetric, desert, sands, dusty and foggy, sand storm, winds bokeh, depth of field, centered.",
        16
    ],
]

with gr.Blocks(css=css) as demo:
    gr.Markdown(f'# {title}\n### {description}')

    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value=default_prompt)
            num_inference_steps = gr.Slider(
                label='Inference steps',
                minimum=4,
                maximum=16,
                value=8,
                step=1,
            )

            with gr.Accordion("Advanced Options", open=False):
                with gr.Row():
                    with gr.Column():
                        # -1 asks inference() to pick a random seed.
                        seed = gr.Number(label="Random Seed", value=-1)
                    with gr.Column():
                        # Passed to inference() as `eta`.
                        eta = gr.Slider(
                                label='Gamma',
                                minimum=0.,
                                maximum=1.,
                                value=0.3,
                                step=0.1,
                            )

            with gr.Row():
                clear = gr.ClearButton(
                    components=[prompt, num_inference_steps, seed, eta])
                submit = gr.Button(value='Submit')

            # Clicking an example only fills the inputs it has values for;
            # `inputs` must list components (not literal values), and no
            # `outputs` are needed because examples are not run or cached.
            gr.Examples(
                label="Quick Examples",
                examples=examples,
                inputs=[prompt, num_inference_steps],
                cache_examples=False,
            )

        # Right column: generated image.
        with gr.Column():
            outputs = gr.Image(label='Generated Images')

    gr.Markdown(article)

    submit.click(
        fn=inference,
        inputs=[prompt, num_inference_steps, seed, eta],
        outputs=outputs,
    )

demo.launch()