File size: 6,800 Bytes
49f5a92
82d824b
eb3568a
49f5a92
bb7ee19
82d824b
bb7ee19
82d824b
0bb8ff5
82d824b
5e20c42
82d824b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49f5a92
bb7ee19
 
 
 
 
 
 
 
 
48cfefd
bb7ee19
 
 
48cfefd
 
 
bb7ee19
48cfefd
bb7ee19
 
49f5a92
82d824b
bb7ee19
 
 
 
 
48cfefd
49f5a92
82d824b
 
 
5e20c42
eb3568a
82d824b
5e20c42
 
82d824b
5e20c42
48cfefd
5e20c42
48cfefd
82d824b
 
 
bb7ee19
82d824b
bb7ee19
82d824b
49f5a92
 
82d824b
48cfefd
82d824b
 
 
48cfefd
30476a5
13eec0b
30476a5
 
 
 
 
 
 
 
c13880f
48cfefd
 
82d824b
 
3735327
 
82d824b
 
 
d52f064
82d824b
49f5a92
82d824b
48cfefd
 
 
 
bb7ee19
58dde5b
 
 
 
b5d1ac9
 
58dde5b
 
82d824b
bb7ee19
 
 
 
 
82d824b
bb7ee19
 
 
 
 
 
 
 
82d824b
 
 
 
 
 
 
bb7ee19
82d824b
 
 
 
 
 
 
 
 
bb7ee19
82d824b
 
 
 
 
 
 
bb7ee19
82d824b
 
 
 
 
 
 
 
 
49f5a92
 
 
82d824b
48cfefd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import gradio as gr
from dataclasses import dataclass
import spaces
import torch
from huggingface_hub import hf_hub_download

from diffusers import StableDiffusionXLPipeline, FluxPipeline

# Preferred torch device for this process. Note: generate_images hard-codes
# "cuda" rather than reading this value.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


@dataclass
class GradioArgs:
    """Settings for one EcoDiff run, as gathered from the Gradio UI.

    In this file only ``prompt``, ``seed`` and ``num_intervention_steps`` are
    supplied (by ``on_prune_click``); every other field keeps its default.
    """

    # List of integer seeds; a missing value is replaced by [44] after init.
    seed: list = None
    prompt: str = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        """Fill in the default seed with a fresh list for each instance."""
        self.seed = [44] if self.seed is None else self.seed


# Per-model EcoDiff settings: (base checkpoint repo id, pipeline attribute that
# holds the denoiser replaced by the pruned one, path of the pruned denoiser
# inside the EcoDiff model repo).
_ECODIFF_MODELS = {
    "sdxl": ("stabilityai/stable-diffusion-xl-base-1.0", "unet", "model/sdxl/sdxl.pkl"),
    "flux": ("black-forest-labs/FLUX.1-schnell", "transformer", "model/flux/flux.pkl"),
}
_PRUNED_MODELS_REPO = "zhangyang-0123/EcoDiffPrunedModels"


def _load_base_pipeline(model):
    """Load the unmodified bf16 pipeline for ``model`` ("sdxl" or "flux") onto the CPU."""
    repo_id = _ECODIFF_MODELS[model][0]
    pipeline_cls = StableDiffusionXLPipeline if model == "sdxl" else FluxPipeline
    return pipeline_cls.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cpu")


def binary_mask_eval(args, model):
    """Load the original and the EcoDiff-pruned pipeline for ``model``.

    Args:
        args: ``GradioArgs`` for the run. Not read here; kept for interface
            compatibility with existing callers.
        model: "sdxl" or "flux" (case-insensitive).

    Returns:
        ``(pipe, pruned_pipe)``: two bf16 pipelines on the CPU. ``pruned_pipe``
        has its denoiser (``unet`` for SDXL, ``transformer`` for FLUX) replaced
        by the pruned module downloaded from the EcoDiff model repo.

    Raises:
        ValueError: if ``model`` is not a supported model name.
    """
    model = model.lower()
    if model not in _ECODIFF_MODELS:
        # Without this check an unknown name fell through every branch and
        # crashed later with an UnboundLocalError on ``pruned_pipe``.
        raise ValueError(f"Unsupported model {model!r}; expected one of {sorted(_ECODIFF_MODELS)}")
    _, denoiser_attr, pruned_path = _ECODIFF_MODELS[model]

    # Pruned pipeline first, then the original, matching the original load order.
    pruned_pipe = _load_base_pipeline(model)
    # The .pkl holds a whole pickled module (it is assigned directly as the
    # denoiser), not a state dict, so weights_only must be False; torch >= 2.6
    # defaults to True and would refuse it. Unpickling can execute arbitrary
    # code: only load this from the trusted EcoDiff repository.
    pruned_denoiser = torch.load(
        hf_hub_download(_PRUNED_MODELS_REPO, pruned_path),
        map_location="cpu",
        weights_only=False,
    )
    setattr(pruned_pipe, denoiser_attr, pruned_denoiser)

    # A separate, untouched copy of the base pipeline for side-by-side comparison.
    pipe = _load_base_pipeline(model)

    print("prune complete")
    return pipe, pruned_pipe


@spaces.GPU
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Render ``prompt`` with the original and the pruned pipeline on CUDA.

    Both pipelines are moved to "cuda" first. Each run then gets its own CUDA
    generator seeded with ``seed``, so both images start from the same seed.

    Returns:
        ``(original_image, ecodiff_image)``: the first image of each pipeline's output.
    """
    pipelines = (pipe, pruned_pipe)
    for pipeline in pipelines:
        pipeline.to("cuda")

    rendered = []
    for pipeline in pipelines:
        # Fresh generator per run so the second pipeline does not continue the
        # first one's random stream.
        generator = torch.Generator("cuda").manual_seed(seed)
        result = pipeline(prompt=prompt, generator=generator, num_inference_steps=steps)
        rendered.append(result.images[0])
    return tuple(rendered)


def on_prune_click(prompt, seed, steps, model):
    """Handle the initialize button: load both pipelines for the selected model.

    Returns the original pipeline, the pruned pipeline, and the HighlightedText
    value that marks the status label as initialized.
    """
    run_args = GradioArgs(
        prompt=prompt,
        seed=[seed],
        num_intervention_steps=steps,
    )
    original, pruned = binary_mask_eval(run_args, model)
    ready_status = [("Model Initialized", "green")]
    return original, pruned, ready_status


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    """Handle "Generate Images" / prompt submit: render with both pipelines.

    Returns:
        ``(original_image, ecodiff_image)`` for the two output image components.

    Raises:
        gr.Error: if the models have not been initialized yet, so the UI shows a
            readable message instead of an AttributeError on ``None``.
    """
    # Both gr.State values start as None until the initialize button has run.
    # Failing here also happens before the @spaces.GPU-decorated call.
    if pipe is None or pruned_pipe is None:
        raise gr.Error(
            "Please initialize the models first using the "
            "'Initialize Original and Pruned Models' button."
        )
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image


# Markdown shown at the top of the demo page: title plus paper/model/code badges.
header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned Models

![Static Badge](https://img.shields.io/badge/arXiv-Paper-A42C25?link=https://arxiv.org/abs/2412.02852)
![Static Badge](https://img.shields.io/badge/🤗-Model-ffbd45?link=https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels)
![Static Badge](https://img.shields.io/badge/GitHub-Code-blue?logo=github&link=https://github.com/YaNgZhAnG-V5/EcoDiff)
"""
# Markdown pointing to the faster single-model demo Spaces.
# NOTE(review): "fgdfba" is not a valid hex colour ('g' is not a hex digit), so
# shields.io likely falls back to its default colour; confirm the intended value.
header_2 = """
For ⚡ <b>faster</b> ⚡ DEMO on one model only, please visit
![Static Badge](https://img.shields.io/badge/SDXL-fedcba?link=https://huggingface.co/spaces/zhangyang-0123/EcoDiff-SD-XL)
![Static Badge](https://img.shields.io/badge/FLUX-fgdfba?link=https://huggingface.co/spaces/zhangyang-0123/EcoDiff-FLUX-Schnell)
"""


def _model_not_initialized():
    """Event outputs that drop both cached pipelines and mark the status as not initialized."""
    return None, None, [("Model Not Initialized", "red")]


def create_demo():
    """Build the Gradio Blocks UI that compares the original and the EcoDiff-pruned model.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown(header)
        gr.Markdown(header_2)
        with gr.Row():
            gr.Markdown(
                """
                **Note: Please first initialize the model before generating images. This may take a while to fully load.**
                """
            )
        with gr.Row():
            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
            # Informational only: its value is never read, so make it read-only.
            pruning_ratio = gr.Text(
                "20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2, interactive=False
            )
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)

        with gr.Row():
            gr.Markdown(
                """
                **Generate images with the original model and the pruned model. May take up to 1 minute due to dynamic allocation of GPU.**
                
                **Note: we prune on step-distilled FLUX, you should use step 5 (instead of 50) for FLUX generation.**
                """
            )
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                value="A clock tower floating in a sea of clouds",
                scale=3,
            )
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(
                label="Number of Steps",
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                scale=1,
            )
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
                "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        # Pipelines produced by on_prune_click; None until the initialize button runs.
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)

        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps, model_choice],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        # Loaded pipelines belong to the model selected at initialization time.
        # On a model switch, drop them and reset the status, so the label never
        # shows "Model Initialized" while Generate would still run the old model.
        model_choice.change(
            fn=_model_not_initialized,
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI and serve it. share=True also asks
    # Gradio for a public share link.
    demo = create_demo()
    demo.launch(share=True)