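"""Gradio demo for NitroFusion one-step text-to-image generation.

Loads two distilled SDXL UNets (NitroSD-Realism and NitroSD-Vibrant) from the
ChenDY/NitroFusion repository and samples each in a single inference step
using a timestep-shifted LCM scheduler.
"""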
import copy
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


class TimestepShiftLCMScheduler(LCMScheduler):
    """LCMScheduler variant whose exposed sampling timesteps are rescaled.

    The UNet is conditioned on shifted timesteps
    (``t * shifted_timestep / num_train_timesteps``), while the scheduler's
    denoising update still runs on the original timestep schedule.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        super().set_timesteps(*args, **kwargs)
        # Keep the original schedule around and expose the shifted copy,
        # which is what the pipeline will pass to the UNet.
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (self.timesteps * self.config.shifted_timestep / self.config.num_train_timesteps).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        # Resolve the step index against the shifted timesteps, swap in the
        # original schedule for the parent update, then swap back for the
        # next UNet call.
        if self.step_index is None:
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        output = super().step(model_output, timestep, sample, generator, return_dict)
        self.timesteps = self.shifted_timesteps
        return output


# The fp16-fix VAE avoids NaN / black-image outputs when decoding SDXL latents in float16.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

repo = "ChenDY/NitroFusion"

# NitroSD-Realism: load the distilled weights into the pipeline's own UNet.
unet_realism = pipe.unet
unet_realism.load_state_dict(load_file(hf_hub_download(repo, "nitrosd-realism_unet.safetensors"), device="cuda"))
scheduler_realism = TimestepShiftLCMScheduler.from_pretrained(base_model_id, subfolder="scheduler", shifted_timestep=250)
scheduler_realism.config.original_inference_steps = 4

# NitroSD-Vibrant: deep-copy the UNet so both variants stay resident, then
# overwrite the copy's weights. Note the larger timestep shift (500 vs. 250).
unet_vibrant = copy.deepcopy(pipe.unet)
unet_vibrant.load_state_dict(load_file(hf_hub_download(repo, "nitrosd-vibrant_unet.safetensors"), device="cuda"))
scheduler_vibrant = TimestepShiftLCMScheduler.from_pretrained(base_model_id, subfolder="scheduler", shifted_timestep=500)
scheduler_vibrant.config.original_inference_steps = 4


# `spaces.GPU` requests GPU time per call when hosted on Hugging Face ZeroGPU.
@spaces.GPU
def process_image(model_choice, num_images, height, width, prompt, seed):
    # Switch to the selected model
    if model_choice == "NitroSD-Realism":
        pipe.unet = unet_realism
        pipe.scheduler = scheduler_realism
    elif model_choice == "NitroSD-Vibrant":
        pipe.unet = unet_vibrant
        pipe.scheduler = scheduler_vibrant
    else:
        raise ValueError("Invalid model choice.")
    # Generate the image
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        return pipe(
            prompt=[prompt] * int(num_images),
            generator=torch.manual_seed(int(seed)),
            num_inference_steps=1,
            guidance_scale=0.0,
            height=int(height),
            width=int(width),
        ).images


# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                model_choice = gr.Dropdown(
                    label="Choose Model",
                    choices=["NitroSD-Realism", "NitroSD-Vibrant"],
                    value="NitroSD-Realism",
                    interactive=True,
                )
                num_images = gr.Slider(
                    label="Number of Images", minimum=1, maximum=4, step=1, value=4, interactive=True
                )
                height = gr.Slider(
                    label="Image Height", minimum=768, maximum=1024, step=8, value=1024, interactive=True
                )
                width = gr.Slider(
                    label="Image Width", minimum=768, maximum=1024, step=8, value=1024, interactive=True
                )
                prompt = gr.Text(label="Prompt", value="a photo of a cat", interactive=True)
                seed = gr.Number(label="Seed", value=2024, interactive=True)
                btn = gr.Button(value="Generate Image")
            with gr.Column():
                output = gr.Gallery(height=1024)

            btn.click(process_image, inputs=[model_choice, num_images, height, width, prompt, seed], outputs=[output])

if __name__ == "__main__":
    demo.launch()
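
# A minimal sketch of one-step generation without the Gradio UI, assuming the
# objects defined above (`pipe`, `unet_realism`, `scheduler_realism`):
#
#     pipe.unet, pipe.scheduler = unet_realism, scheduler_realism
#     images = pipe(
#         prompt="a photo of a cat",
#         num_inference_steps=1,
#         guidance_scale=0.0,
#         generator=torch.manual_seed(2024),
#     ).images
#     images[0].save("cat.png")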