File size: 3,688 Bytes
79fcfbd
 
78ea680
79fcfbd
ee69d30
79fcfbd
 
 
 
78ea680
79fcfbd
4bc008c
79fcfbd
 
 
 
78ea680
 
 
 
 
 
 
 
79fcfbd
 
 
 
ee69d30
 
 
 
 
 
 
 
79fcfbd
 
 
 
 
 
 
 
 
 
 
 
 
 
ee69d30
 
 
 
 
 
 
 
 
79fcfbd
 
ee69d30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79fcfbd
ee69d30
 
79fcfbd
 
 
ee69d30
 
 
79fcfbd
ee69d30
79fcfbd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

import spaces  # [uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # Replace with the model you would like to use

torch_dtype = torch.bfloat16

# Reference pipeline: the unmodified base checkpoint.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Pruned pipeline: the same base checkpoint with its denoiser replaced by the
# EcoDiff-pruned module downloaded from the Hub.
pruned_pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# SD-XL pipelines run their denoiser from the `unet` component (they have no
# `transformer` component), so the pruned module must replace `unet`. Assigning
# it to any other attribute leaves it unused and outside `.to(device)`.
# NOTE(review): assumes sdxl.pkl holds a pickled UNet module -- confirm against
# the model repository.
# SECURITY: `weights_only=False` unpickles arbitrary objects and can execute
# code; only acceptable because this Hub repository is trusted. It is required
# on torch >= 2.6, where `weights_only` defaults to True and a fully pickled
# module would otherwise fail to load.
pruned_pipe.unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
    weights_only=False,
).to(dtype=torch_dtype)  # match the rest of the pipeline; no-op if already bf16
pruned_pipe = pruned_pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU  # run this call on ZeroGPU hardware
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Generate one image from each pipeline with the same prompt, seed and steps.

    Each pipeline gets its own freshly seeded generator with the same seed, so
    both draw from identical random streams and the outputs are comparable.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: Random seed; converted to ``int`` before seeding.
        steps: Number of inference steps; converted to ``int``.
        pipe: Reference (unpruned) diffusion pipeline.
        pruned_pipe: Pruned diffusion pipeline to compare against.

    Returns:
        ``(original_image, ecodiff_image)``: the first image produced by each
        pipeline.
    """
    # UI widgets may hand back floats; seeding and step counts need ints.
    seed = int(seed)
    steps = int(steps)

    # Build generators on the module-level `device` instead of hard-coding
    # "cuda", so the demo also runs on CPU-only machines.
    generator = torch.Generator(device).manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]

    # Re-seed so the pruned pipeline starts from the same random state.
    generator = torch.Generator(device).manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image


# Sample prompts.
# NOTE(review): not referenced by the UI code, which passes its own inline
# list to gr.Examples.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Centers the layout and caps its width; applied to the element with id "col-container".
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
# Markdown/HTML banner rendered at the top of the page, with links to the
# paper, the pruned model weights and the source code.
header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned SD-XL (20% Pruning Ratio)
# Under Construction!!!
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )
        generate_btn = gr.Button("Generate Images")
    gr.Examples(
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
        inputs=[prompt],
    )
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    def _generate(prompt_text, seed_value, num_steps):
        """Run the reference and the pruned pipeline on the current UI inputs."""
        # Gradio `inputs` must be components, so the two pipelines are bound
        # here rather than listed as inputs; the pruned pipeline is passed as
        # the second model so the comparison actually uses it.
        return generate_images(prompt_text, seed_value, num_steps, pipe, pruned_pipe)

    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=_generate,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
        # Keep the API endpoint named after the generation function.
        api_name="generate_images",
    )

if __name__ == "__main__":
    demo.launch()