File size: 6,236 Bytes
2ea65a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96469d5
 
 
2ea65a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96469d5
2ea65a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b39ce8e
2ea65a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
from torchvision.utils import make_grid
import math
from PIL import Image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
import gradio as gr
from imagenet_class_data import IMAGENET_1K_CLASSES
from download import find_model
from models import DiT_XL_2


def load_model(image_size=256):
    """Build a DiT-XL/2 model for ``image_size`` and load its pretrained weights.

    The checkpoint ``DiT-XL-2-{image_size}x{image_size}.pt`` is resolved through
    ``find_model``. The model is created on the module-level ``device`` and put
    into eval mode.

    Args:
        image_size: Output image resolution in pixels. Only 256 and 512 are
            supported.

    Returns:
        The DiT-XL/2 model with weights loaded, in eval mode.

    Raises:
        ValueError: If ``image_size`` is not 256 or 512.
    """
    # Use an explicit exception rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unsupported size reach the checkpoint lookup.
    if image_size not in (256, 512):
        raise ValueError(f"image_size must be 256 or 512, got {image_size!r}")
    # The model operates on latents at 1/8 of the pixel resolution.
    latent_size = image_size // 8
    model = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Inference-only script: turn autograd off. PyTorch grad mode is thread-local,
# so this setting applies to the thread importing the module.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Remember which resolution / VAE is currently loaded; generate() compares the
# UI request against these and swaps the model or VAE when they differ.
current_image_size = 256
current_vae_model = "stabilityai/sd-vae-ft-mse"

# Load the initial DiT model and VAE decoder that match the values above.
model = load_model(image_size=current_image_size)
vae = AutoencoderKL.from_pretrained(current_vae_model).to(device)


def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, n, seed):
    """Sample class-conditional DiT images and return them as PIL images.

    Swaps the module-level DiT ``model`` and VAE decoder ``vae`` when the
    requested resolution or VAE differs from the one currently loaded.

    Args:
        image_size: Resolution label of the form ``"<W>x<H>"`` (e.g. ``"256x256"``).
            Only the width is parsed and passed to ``load_model``.
        vae_model: Pretrained VAE identifier for ``AutoencoderKL.from_pretrained``.
        class_label: Integer class index to condition on.
        cfg_scale: Classifier-free guidance scale forwarded to ``model.forward_with_cfg``.
        num_sampling_steps: Number of sampling steps passed to ``create_diffusion``.
        n: Number of images to generate.
        seed: Seed for ``torch.manual_seed``.

    Returns:
        A list of ``n`` ``PIL.Image.Image`` objects.
    """
    global model, vae, current_image_size, current_vae_model

    image_size = int(image_size.split("x")[0])
    # UI number/slider widgets may deliver whole numbers as floats; coerce so the
    # step-count string for create_diffusion and the tensor sizes are plain ints.
    n = int(n)
    num_sampling_steps = int(num_sampling_steps)
    seed = int(seed)

    # Swap the DiT model when the resolution changed (or a previous load failed).
    if model is None or image_size != current_image_size:
        # Drop the old weights before loading new ones to limit peak memory. Rebinding
        # to None (instead of ``del``) keeps the global defined if loading raises, and
        # the ``model is None`` check above makes the next call retry the load.
        model = None
        model = load_model(image_size=image_size)
        current_image_size = image_size

    # Swap the VAE decoder when a different one was requested (or a previous load failed).
    if vae is None or vae_model != current_vae_model:
        if vae is not None and device == "cuda":
            vae.to("cpu")
        vae = None
        vae = AutoencoderKL.from_pretrained(vae_model).to(device)
        # Record what is loaded so later calls compare against the right VAE; without
        # this, switching back to the initial VAE would keep using the other one.
        current_vae_model = vae_model

    # Seed PyTorch:
    torch.manual_seed(seed)

    # Setup diffusion
    diffusion = create_diffusion(str(num_sampling_steps))

    # PyTorch grad mode is thread-local, so the module-level set_grad_enabled(False)
    # does not cover worker threads that may run this handler; disable autograd here
    # so the VAE decode does not build a graph.
    with torch.no_grad():
        # Create sampling noise in latent space (1/8 of the pixel resolution):
        latent_size = image_size // 8
        z = torch.randn(n, 4, latent_size, latent_size, device=device)
        y = torch.tensor([class_label] * n, device=device)

        # Setup classifier-free guidance: duplicate the noise and pair the second half
        # with label 1000, used here as the null (unconditional) class.
        z = torch.cat([z, z], 0)
        y_null = torch.tensor([1000] * n, device=device)
        y = torch.cat([y, y_null], 0)
        model_kwargs = dict(y=y, cfg_scale=cfg_scale)

        # Sample images:
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,
            device=device
        )
        samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        # NOTE(review): 0.18215 is presumably the latent scale factor the model was
        # trained with — confirm against the training script before changing VAEs.
        samples = vae.decode(samples / 0.18215).sample

    # Map decoder output to uint8 pixels (x * 127.5 + 128 then truncation rounds
    # the [-1, 1] range into [0, 255]) and convert to PIL images:
    samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    samples = [Image.fromarray(sample) for sample in samples]
    return samples


# Text shown at the top of the demo page.
description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''

duplicate = '''Skip the queue by duplicating this space and upgrading to GPU in settings
<a href="https://huggingface.co/spaces/wpeebles/DiT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>'''

project_links = '''
<p style="text-align: center">
<a href="https://www.wpeebles.com/DiT.html">Project Page</a> &#183;
<a href="http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb">Colab</a> &#183;
<a href="http://arxiv.org/abs/2212.09748">Paper</a> &#183;
<a href="https://github.com/facebookresearch/DiT">GitHub</a></p>'''

# Example rows, one per (resolution, ImageNet class name, seed). Every row uses
# the "stabilityai/sd-vae-ft-mse" VAE, guidance scale 4.0, 200 sampling steps
# and 4 samples. Columns follow the parameter order of generate().
examples = [
    [resolution, "stabilityai/sd-vae-ft-mse", class_name, 4.0, 200, 4, seed]
    for resolution, class_name, seed in (
        ("512x512", "golden retriever", 1000),
        ("512x512", "macaw", 1),
        ("512x512", "balloon", 1),
        ("512x512", "cliff, drop, drop-off", 7),
        ("512x512", "Pembroke, Pembroke Welsh corgi", 0),
        ("256x256", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 1),
        ("256x256", "teddy, teddy bear", 3),
        ("256x256", "cheeseburger", 2),
    )
]

# Build and launch the Gradio UI: a header, a "Generate" tab with sampling
# controls on the left and an image gallery on the right, and a row of examples.
# Components are laid out in the order they are created inside the nested
# context managers, so statement order here determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>Scalable Diffusion Models with Transformers (DiT)</h1>")
    gr.Markdown(project_links)
    gr.Markdown(description)
    gr.Markdown(duplicate)

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                # Left column: generation controls.
                with gr.Column():
                    with gr.Row():
                        # NOTE(review): gr.inputs.* with default= is the legacy Gradio API
                        # (deprecated in 3.x, removed in 4.x) — confirm the pinned Gradio version.
                        # The resolution label is parsed by generate() as "<W>x<H>".
                        image_size = gr.inputs.Radio(choices=["256x256", "512x512"], default="256x256", label='DiT Model Resolution')
                        vae_model = gr.inputs.Radio(choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
                                                    default="stabilityai/sd-vae-ft-mse", label='VAE Decoder')
                    with gr.Row():
                        # type="index" hands generate() the position of the chosen name in
                        # this list as class_label.
                        # NOTE(review): assumes IMAGENET_1K_CLASSES.values() is ordered by
                        # class index — confirm against imagenet_class_data.
                        i1k_class = gr.inputs.Dropdown(
                            list(IMAGENET_1K_CLASSES.values()),
                            default='golden retriever',
                            type="index", label='ImageNet-1K Class'
                        )
                    cfg_scale = gr.inputs.Slider(minimum=1, maximum=25, step=0.1, default=4.0, label='Classifier-free Guidance Scale')
                    steps = gr.inputs.Slider(minimum=4, maximum=1000, step=1, default=75, label='Sampling Steps')
                    n = gr.inputs.Slider(minimum=1, maximum=16, step=1, default=1, label='Number of Samples')
                    seed = gr.inputs.Number(default=0, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                # Right column: results gallery (two-column grid via the legacy .style() API).
                with gr.Column():
                    output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
                    # The inputs list order must match generate()'s positional parameters.
                    button.click(generate, inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed], outputs=[output])
            with gr.Row():
                # cache_examples=True asks Gradio to precompute generate() outputs for
                # every row of `examples`; those rows mix 512x512 and 256x256, so this
                # exercises both model resolutions.
                ex = gr.Examples(examples=examples, fn=generate,
                                 inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, n, seed],
                                 outputs=[output],
                                 cache_examples=True)

    # Enable Gradio's request queue, then start the server.
    demo.queue()
    demo.launch()