File size: 3,678 Bytes
aaead26
 
 
7fb327a
aaead26
 
7fb327a
aaead26
 
 
7fb327a
 
aaead26
7fb327a
aaead26
 
7fb327a
 
 
 
 
 
 
aaead26
 
7fb327a
 
aaead26
 
7fb327a
 
 
 
 
 
 
 
 
 
aaead26
 
 
7fb327a
aaead26
 
 
7fb327a
aaead26
 
 
 
 
 
 
 
 
7fb327a
 
aaead26
7fb327a
 
 
 
 
aaead26
7fb327a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaead26
7fb327a
 
aaead26
 
 
7fb327a
aaead26
7fb327a
aaead26
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import numpy as np
import random
from huggingface_hub import hf_hub_download

# import spaces #[uncomment to use ZeroGPU]
from diffusers import FluxPipeline
import torch

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace to the model you would like to use
torch_dtype = torch.bfloat16

# Baseline (unpruned) FLUX pipeline.
pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Load the EcoDiff-pruned transformer. The checkpoint is a whole pickled module
# (it is used directly as the pipeline's transformer), not a state dict, so
# `weights_only=False` must be explicit: torch>=2.6 defaults to True and would
# refuse to unpickle it.
# SECURITY: unpickling executes arbitrary code -- only load this file from a
# repository you trust.
pruned_transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)

# Pass the pruned transformer as a component override so the second pipeline
# never loads the full-size original transformer just to discard it.
pruned_pipe = FluxPipeline.from_pretrained(
    model_repo_id,
    transformer=pruned_transformer,
    torch_dtype=torch_dtype,
)
pruned_pipe = pruned_pipe.to(device)


MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# @spaces.GPU
def generate_images(prompt, seed, steps):
    """Render ``prompt`` with both the original and the EcoDiff-pruned pipelines.

    Each pipeline is driven by a freshly seeded generator using the same seed,
    so the two outputs start from the same initial noise and differ only
    because of the pruned transformer.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: Seed for the random generator (from a ``gr.Number`` widget).
        steps: Number of inference steps (from a ``gr.Slider`` widget).

    Returns:
        Tuple ``(original_image, ecodiff_image)``: the first image produced by
        ``pipe`` and by ``pruned_pipe`` respectively.
    """
    # Gradio numeric widgets may deliver floats; manual_seed and the step
    # count both expect plain ints.
    seed = int(seed)
    steps = int(steps)

    # Moving the pipelines inside the call presumably keeps this compatible
    # with the commented-out @spaces.GPU decorator. Using the module-level
    # `device` instead of a hard-coded "cuda" keeps CPU-only hosts working.
    pipe.to(device)
    pruned_pipe.to(device)

    generator = torch.Generator(device).manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]

    # Re-seed a new generator so the pruned run starts from identical noise.
    generator = torch.Generator(device).manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image


# Sample prompts offered in the UI's examples panel.
examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]

# Custom stylesheet handed to gr.Blocks; centers and caps the width of an
# element with id "col-container".
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Markdown/HTML banner shown at the top of the page: title plus badges linking
# to the paper and the pruned-model repository.
header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)

<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""

# Side-by-side comparison UI: one prompt/seed/steps triple drives both the
# original and the pruned pipeline.
# NOTE(review): `css` styles `#col-container`, but no component below sets
# elem_id="col-container", so that rule currently has no visible effect.
with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        # precision=0 makes the widget round its value to an integer.
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )
        generate_btn = gr.Button("Generate Images")
    # Reuse the module-level prompt list rather than duplicating it inline.
    gr.Examples(examples=examples, inputs=[prompt])
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")
    # Either clicking the button or pressing Enter in the prompt box runs
    # both pipelines and fills the two image panes.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()