Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from dataclasses import dataclass | |
import spaces | |
import torch | |
from huggingface_hub import hf_hub_download | |
from diffusers import StableDiffusionXLPipeline, FluxPipeline | |
# Preferred torch device for this process: CUDA when torch reports a usable GPU, else CPU.
# NOTE(review): nothing visible in this file reads `device`; generate_images hard-codes "cuda".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@dataclass
class GradioArgs:
    """Run configuration handed to ``binary_mask_eval`` by the demo's init button.

    ``@dataclass`` generates the keyword constructor that ``on_prune_click``
    relies on (``GradioArgs(prompt=..., seed=[...], num_intervention_steps=...)``)
    and makes Python call ``__post_init__`` after construction.

    Only ``prompt``, ``seed`` and ``num_intervention_steps`` are set by this
    demo; every other field keeps the default declared here.
    NOTE(review): ``binary_mask_eval`` does not currently read any field.
    """

    seed: list = None  # list of seeds; replaced by [44] in __post_init__ when omitted
    prompt: str = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        """Fill in the default seed list.

        Building the list here instead of using a list default gives each
        instance its own list, so mutating one config never leaks into another.
        """
        if self.seed is None:
            self.seed = [44]
# Hub repository holding the EcoDiff-pruned pipeline components.
_PRUNED_MODELS_REPO = "zhangyang-0123/EcoDiffPrunedModels"

# model key -> (base pipeline repo, pipeline attribute that the pruned component
# replaces, path of the pickled pruned component inside _PRUNED_MODELS_REPO)
_PRUNED_MODEL_SPECS = {
    "sdxl": ("stabilityai/stable-diffusion-xl-base-1.0", "unet", "model/sdxl/sdxl.pkl"),
    "flux": ("black-forest-labs/FLUX.1-schnell", "transformer", "model/flux/flux.pkl"),
}


def _load_base_pipeline(model):
    """Load the unmodified bfloat16 pipeline for ``model`` ("sdxl" or "flux") onto the CPU."""
    pipeline_cls = StableDiffusionXLPipeline if model == "sdxl" else FluxPipeline
    repo_id = _PRUNED_MODEL_SPECS[model][0]
    return pipeline_cls.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cpu")


def binary_mask_eval(args, model):
    """Load an original pipeline and an EcoDiff-pruned copy of it, both on the CPU.

    Args:
        args: run configuration. NOTE(review): currently unused by this function.
        model: "sdxl" or "flux" (case-insensitive).

    Returns:
        ``(pipe, pruned_pipe)``: the unmodified pipeline, and a second,
        independently loaded pipeline whose UNet (SDXL) or transformer (FLUX)
        has been replaced by the pruned component downloaded from the Hub.

    Raises:
        ValueError: if ``model`` is not a supported model name.
    """
    key = model.lower()
    # Fail fast with a clear message; previously an unknown name surfaced later
    # as an UnboundLocalError.
    if key not in _PRUNED_MODEL_SPECS:
        supported = ", ".join(sorted(_PRUNED_MODEL_SPECS))
        raise ValueError(f"Unsupported model {model!r}; expected one of: {supported}")
    _, component, pickle_path = _PRUNED_MODEL_SPECS[key]

    pruned_pipe = _load_base_pipeline(key)
    # The pickle holds a whole module object (it is installed directly as a
    # pipeline component), so it needs full unpickling. torch>=2.6 defaults to
    # weights_only=True, which would reject it. Unpickling can execute
    # arbitrary code, so only load from a repository you trust.
    pruned_component = torch.load(
        hf_hub_download(_PRUNED_MODELS_REPO, pickle_path),
        map_location="cpu",
        weights_only=False,
    )
    setattr(pruned_pipe, component, pruned_component)

    # Reload the original model as a separate pipeline object so that swapping
    # the component above does not affect it.
    pipe = _load_base_pipeline(key)
    print("prune complete")
    return pipe, pruned_pipe
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Generate one image from the original and one from the pruned pipeline.

    Both pipelines are moved to CUDA first (and left there), then each is run
    with the same prompt, seed and number of inference steps.

    Returns:
        ``(original_image, ecodiff_image)``: the first image from each pipeline.
    """
    pipelines = (pipe, pruned_pipe)
    for pipeline in pipelines:
        pipeline.to("cuda")

    outputs = []
    for pipeline in pipelines:
        # A fresh, identically seeded generator per run makes the two outputs
        # directly comparable.
        rng = torch.Generator("cuda").manual_seed(seed)
        result = pipeline(prompt=prompt, generator=rng, num_inference_steps=steps)
        outputs.append(result.images[0])
    return tuple(outputs)
def on_prune_click(prompt, seed, steps, model):
    """Gradio handler for the init button: load the original and pruned pipelines.

    Returns:
        The two pipelines from ``binary_mask_eval``, followed by the
        HighlightedText value that marks the models as initialized.
    """
    run_config = GradioArgs(
        prompt=prompt,
        seed=[seed],
        num_intervention_steps=steps,
    )
    original_pipe, pruned_pipe = binary_mask_eval(run_config, model)
    return original_pipe, pruned_pipe, [("Model Initialized", "green")]
def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    """Gradio handler for the generate button and prompt submission.

    Returns:
        ``(original_image, ecodiff_image)`` as produced by ``generate_images``.

    Raises:
        gr.Error: if the pipelines have not been initialized yet. The user then
        sees an actionable message instead of an AttributeError on ``None``.
    """
    # The pipeline State holders are None until the init button has been used.
    if pipe is None or pruned_pipe is None:
        raise gr.Error("Please initialize the original and pruned models before generating images.")
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image
# Markdown/HTML banner rendered at the top of the demo.
# NOTE(review): the title and all three links (paper, model, code) point at the
# OminiControl project, while this demo compares SDXL/FLUX with EcoDiff-pruned
# models. It looks copied from another Space; confirm the intended links.
header = """
# OminiControl / FLUX
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def create_demo():
    """Build the Gradio UI that compares an original pipeline with its EcoDiff-pruned copy.

    Layout, top to bottom:
    - the header banner;
    - an initialization row (model choice, pruning-ratio text, status indicator, init button);
    - a generation row (prompt, seed, steps, generate button);
    - example prompts;
    - the two output images side by side.

    The loaded pipelines are kept in ``gr.State`` holders: the init button fills
    them and the generate handlers read them.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(header)
        with gr.Row():
            gr.Markdown(
                """
**Note: Please first initialize the model before generating images. This may take a while to fully load.**
"""
            )
        # Initialization controls: pick a model, then load both pipelines.
        with gr.Row():
            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
            # Informational only: pruning_ratio is not passed to any handler below.
            pruning_ratio = gr.Text("20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2)
            # Holds a list of (text, label) pairs; on_prune_click returns a value of the same shape.
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
        with gr.Row():
            gr.Markdown(
                """
**Generate images with the original model and the pruned model. May take up to 1 minute due to dynamic allocation of GPU.**
**Note: we prune on step-distilled FLUX, you should use step 5 (instead of 50) for FLUX generation.**
"""
            )
        # Generation inputs shared by both pipelines.
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                value="A clock tower floating in a sea of clouds",
                scale=3,
            )
            # precision=0 requests an integer seed for torch.Generator.manual_seed.
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            # The default of 50 steps targets SDXL; the note above asks FLUX users to use 5.
            steps = gr.Slider(
                label="Number of Steps",
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                scale=1,
            )
            generate_btn = gr.Button("Generate Images")
        # Clicking an example copies it into the prompt box.
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
                "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")
        # Holders for the two loaded pipelines. Both start as None until
        # prune_btn stores the pipelines returned by on_prune_click.
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)
        # Submitting the prompt box and clicking "Generate Images" run the same handler.
        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        # Loads both pipelines for the selected model and updates the status indicator.
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps, model_choice],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
    return demo
if __name__ == "__main__":
    # When run as a script: build the UI, keep it in the module-level `demo`
    # name, and serve it. share=True asks Gradio for a public share link.
    (demo := create_demo()).launch(share=True)