ChenDY's picture
add fixed inference steps and title
e71610b
raw
history blame
4.54 kB
import copy
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler that exposes a rescaled ("shifted") timestep schedule.

    After ``set_timesteps`` the public ``timesteps`` attribute holds the
    parent's schedule multiplied by
    ``config.shifted_timestep / config.num_train_timesteps`` and truncated to
    integers. The parent's unmodified schedule is kept in ``origin_timesteps``
    and is swapped back in for the duration of every ``step`` call.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        """Initialise the parent scheduler and record ``shifted_timestep``.

        Args:
            *args: Forwarded unchanged to ``LCMScheduler.__init__``.
            shifted_timestep: Numerator of the rescaling factor applied to the
                schedule in ``set_timesteps``; stored in the scheduler config.
            **kwargs: Forwarded unchanged to ``LCMScheduler.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        """Compute the parent schedule, then publish its shifted version.

        All arguments are forwarded to ``LCMScheduler.set_timesteps``.
        """
        super().set_timesteps(*args, **kwargs)
        # Keep an independent copy of the parent's schedule for use in step().
        self.origin_timesteps = self.timesteps.clone()
        scale = self.config.shifted_timestep / self.config.num_train_timesteps
        self.shifted_timesteps = (self.timesteps * scale).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run one parent ``step`` against the original (unshifted) schedule.

        Args:
            model_output: Forwarded to ``LCMScheduler.step``.
            timestep: Forwarded to ``LCMScheduler.step``; also used to
                initialise the step index on the first call.
            sample: Forwarded to ``LCMScheduler.step``.
            generator: Forwarded to ``LCMScheduler.step``.
            return_dict: Forwarded to ``LCMScheduler.step``.

        Returns:
            Whatever ``LCMScheduler.step`` returns.
        """
        # The step index is resolved while ``self.timesteps`` is still the
        # shifted schedule, i.e. before the swap below.
        # NOTE(review): presumably ``timestep`` is taken from the shifted
        # schedule by the caller -- confirm against the pipeline loop.
        if self.step_index is None:
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        try:
            return super().step(model_output, timestep, sample, generator, return_dict)
        finally:
            # Always restore the shifted schedule, even if the parent raises,
            # so the scheduler is never left exposing the original timesteps.
            self.timesteps = self.shifted_timesteps
# --- Base pipeline ---------------------------------------------------------
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Half-precision VAE loaded from a separate checkpoint and injected into the
# SDXL pipeline in place of the one bundled with the base model.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")

# --- NitroFusion checkpoints -----------------------------------------------
repo = "ChenDY/NitroFusion"

# Realism: the weights are loaded in place into the pipeline's own UNet.
realism_ckpt_path = hf_hub_download(repo, "nitrosd-realism_unet.safetensors")
unet_realism = pipe.unet
unet_realism.load_state_dict(load_file(realism_ckpt_path, device="cuda"))
scheduler_realism = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id,
    subfolder="scheduler",
    shifted_timestep=250,
)
scheduler_realism.config.original_inference_steps = 4

# Vibrant: a separate copy of the UNet so both variants stay resident and can
# be swapped into the pipeline per request.
vibrant_ckpt_path = hf_hub_download(repo, "nitrosd-vibrant_unet.safetensors")
unet_vibrant = copy.deepcopy(pipe.unet)
unet_vibrant.load_state_dict(load_file(vibrant_ckpt_path, device="cuda"))
scheduler_vibrant = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id,
    subfolder="scheduler",
    shifted_timestep=500,
)
scheduler_vibrant.config.original_inference_steps = 4
@spaces.GPU
def process_image(model_choice, num_images, height, width, prompt, seed, inference_steps):
    """Generate images for ``prompt`` with the selected NitroSD variant.

    Args:
        model_choice: ``"NitroSD-Realism"`` or ``"NitroSD-Vibrant"``; selects
            which UNet/scheduler pair is installed on the shared pipeline.
        num_images: Number of images to generate (the prompt is repeated this
            many times). Coerced to ``int``.
        height: Output height in pixels. Coerced to ``int``.
        width: Output width in pixels. Coerced to ``int``.
        prompt: Text prompt.
        seed: Seed passed to ``torch.manual_seed``. Coerced to ``int``.
        inference_steps: Number of inference steps. Coerced to ``int``.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: If ``model_choice`` is not one of the two known names.
    """
    # Switch the shared pipeline to the selected model. ``pipe`` is only
    # mutated, never rebound, so no ``global`` declaration is needed.
    if model_choice == "NitroSD-Realism":
        pipe.unet = unet_realism
        pipe.scheduler = scheduler_realism
    elif model_choice == "NitroSD-Vibrant":
        pipe.unet = unet_vibrant
        pipe.scheduler = scheduler_vibrant
    else:
        raise ValueError("Invalid model choice.")

    # UI numeric widgets may deliver floats (e.g. 4.0 / 1.0). List repetition
    # needs a real int, and the step count is coerced the same way as the
    # seed and image dimensions for consistency.
    batch_size = int(num_images)
    steps = int(inference_steps)

    # Generate the image
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        result = pipe(
            prompt=[prompt] * batch_size,
            generator=torch.manual_seed(int(seed)),
            num_inference_steps=steps,
            guidance_scale=0.0,
            height=int(height),
            width=int(width),
        )
    return result.images
# Gradio UI
# NOTE(review): the nesting below is reconstructed from the layout calls
# (controls in the first column, gallery in the second) -- confirm it matches
# the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left column: title and generation controls.
            with gr.Column():
                gr.Markdown("\n### NitroFusion Single-Step Text-To-Image\n")
                model_choice = gr.Dropdown(
                    choices=["NitroSD-Realism", "NitroSD-Vibrant"],
                    value="NitroSD-Realism",
                    label="Choose Model",
                    interactive=True,
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=4,
                    label="Number of Images",
                    interactive=True,
                )
                height = gr.Slider(
                    minimum=768,
                    maximum=1024,
                    step=8,
                    value=1024,
                    label="Image Height",
                    interactive=True,
                )
                width = gr.Slider(
                    minimum=768,
                    maximum=1024,
                    step=8,
                    value=1024,
                    label="Image Width",
                    interactive=True,
                )
                prompt = gr.Text(value="a photo of a cat", label="Prompt", interactive=True)
                seed = gr.Number(value=2024, label="Seed", interactive=True)
                # Shown read-only: the step count is fixed at 1 in the UI.
                inference_steps = gr.Number(value=1, label="Inference Steps", interactive=False)
                btn = gr.Button(value="Generate Image")

            # Right column: generated images.
            with gr.Column():
                output = gr.Gallery(height=1024)

        # Input order must match the positional parameters of process_image.
        btn.click(
            fn=process_image,
            inputs=[model_choice, num_images, height, width, prompt, seed, inference_steps],
            outputs=[output],
        )

if __name__ == "__main__":
    demo.launch()