import copy
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
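
# NitroSD checkpoints are distilled with a shifted timestep schedule. This
# LCMScheduler subclass exposes the scaled-down timesteps to the pipeline (so
# the UNet is conditioned on them), while step() temporarily swaps the
# original timesteps back in so the LCM update itself runs on the unshifted
# schedule.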
class TimestepShiftLCMScheduler(LCMScheduler):
    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        super().set_timesteps(*args, **kwargs)
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (self.timesteps * self.config.shifted_timestep / self.config.num_train_timesteps).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        if self.step_index is None:
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        output = super().step(model_output, timestep, sample, generator, return_dict)
        self.timesteps = self.shifted_timesteps
        return output
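
# The fp16-fix VAE avoids the NaN artifacts the stock SDXL VAE can produce
# when decoding in float16.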
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
repo = "ChenDY/NitroFusion"
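# Both distilled UNets are kept in memory so the UI can switch models with a
# pointer swap. Each variant gets its own scheduler: shifted_timestep is 250
# for Realism and 500 for Vibrant, and original_inference_steps=4 defines the
# 4-step LCM grid that 1-2 step inference samples from.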
unet_realism = pipe.unet
unet_realism.load_state_dict(load_file(hf_hub_download(repo, "nitrosd-realism_unet.safetensors"), device="cuda"))
scheduler_realism = TimestepShiftLCMScheduler.from_pretrained(base_model_id, subfolder="scheduler", shifted_timestep=250)
scheduler_realism.config.original_inference_steps = 4
unet_vibrant = copy.deepcopy(pipe.unet)
unet_vibrant.load_state_dict(load_file(hf_hub_download(repo, "nitrosd-vibrant_unet.safetensors"), device="cuda"))
scheduler_vibrant = TimestepShiftLCMScheduler.from_pretrained(base_model_id, subfolder="scheduler", shifted_timestep=500)
scheduler_vibrant.config.original_inference_steps = 4
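
# `@spaces.GPU` requests a ZeroGPU device for the duration of each call; the
# Space holds no GPU between requests.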
@spaces.GPU
def process_image(model_choice, num_images, height, width, prompt, seed, inference_steps):
    global pipe
    # Switch to the selected model
    if model_choice == "NitroSD-Realism":
        pipe.unet = unet_realism
        pipe.scheduler = scheduler_realism
    elif model_choice == "NitroSD-Vibrant":
        pipe.unet = unet_vibrant
        pipe.scheduler = scheduler_vibrant
    else:
        raise ValueError("Invalid model choice.")
    # Generate the image
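    # guidance_scale=0.0 disables classifier-free guidance, as is typical for
    # few-step distilled models; autocast keeps the denoising pass in fp16.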
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        return pipe(
            prompt=[prompt] * num_images,
            generator=torch.manual_seed(int(seed)),
            num_inference_steps=inference_steps,
            guidance_scale=0.0,
            height=int(height),
            width=int(width),
        ).images
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                gr.Markdown("### NitroFusion Single-Step Text-To-Image")
                model_choice = gr.Dropdown(
                    label="Choose Model",
                    choices=["NitroSD-Realism", "NitroSD-Vibrant"],
                    value="NitroSD-Realism",
                    interactive=True,
                )
                num_images = gr.Slider(
                    label="Number of Images", minimum=1, maximum=4, step=1, value=4, interactive=True
                )
                height = gr.Slider(
                    label="Image Height", minimum=768, maximum=1024, step=8, value=1024, interactive=True
                )
                width = gr.Slider(
                    label="Image Width", minimum=768, maximum=1024, step=8, value=1024, interactive=True
                )
                inference_steps = gr.Slider(
                    label="Inference Steps", minimum=1, maximum=2, step=1, value=1, interactive=True,
                )
                prompt = gr.Text(label="Prompt", value="a photo of a cat", interactive=True)
                seed = gr.Number(label="Seed", value=2024, interactive=True)
                btn = gr.Button(value="Generate Image")
            with gr.Column():
                output = gr.Gallery(height=1024)
    btn.click(
        process_image,
        inputs=[model_choice, num_images, height, width, prompt, seed, inference_steps],
        outputs=[output],
    )
if __name__ == "__main__":
    demo.launch()
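
# Run this file directly to serve the demo locally (Gradio defaults to
# http://127.0.0.1:7860); on a Space it is launched automatically.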