File size: 4,537 Bytes
f31a574
 
967a18e
 
f31a574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e71610b
f31a574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e71610b
f31a574
 
 
 
 
 
 
 
 
967a18e
f31a574
e71610b
 
 
f31a574
 
 
 
 
967a18e
f31a574
 
967a18e
f31a574
 
967a18e
f31a574
 
967a18e
f31a574
 
e71610b
f31a574
 
 
967a18e
e71610b
 
 
 
 
967a18e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import copy
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler whose public ``timesteps`` are a rescaled ("shifted") schedule.

    ``set_timesteps`` builds the regular LCM schedule, keeps a copy of it in
    ``origin_timesteps``, and publishes
    ``timesteps * shifted_timestep / num_train_timesteps`` (truncated to
    integers) as ``timesteps``. Anything that iterates ``timesteps`` therefore
    sees the shifted values. ``step`` temporarily swaps the original schedule
    back in while the parent LCM update runs.

    Args:
        *args, **kwargs: Forwarded unchanged to ``LCMScheduler.__init__``.
        shifted_timestep: Numerator of the shift ratio; recorded in the
            scheduler config via ``register_to_config``.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the shift in the config so it is readable as self.config.shifted_timestep.
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        """Build the parent LCM schedule, then expose a shifted copy as ``timesteps``.

        Sets ``origin_timesteps`` (unshifted clone), ``shifted_timesteps``
        (integer-truncated rescale) and points ``timesteps`` at the latter.
        """
        super().set_timesteps(*args, **kwargs)
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep / self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run one parent LCM step with the unshifted schedule temporarily restored.

        Args mirror ``LCMScheduler.step``; ``timestep`` is expected to be a
        value taken from the shifted ``timesteps`` schedule.

        Returns:
            Whatever ``LCMScheduler.step`` returns.
        """
        if self.step_index is None:
            # Done before the swap, while self.timesteps is still the shifted
            # schedule the incoming timestep came from (assumes the parent's
            # index lookup searches self.timesteps — TODO confirm per diffusers version).
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        try:
            return super().step(model_output, timestep, sample, generator, return_dict)
        finally:
            # Always restore the shifted schedule, even if the parent step
            # raises, so this shared scheduler is never left holding the
            # unshifted timesteps.
            self.timesteps = self.shifted_timesteps


# fp16-safe SDXL VAE used in place of the base pipeline's VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

repo = "ChenDY/NitroFusion"


def _load_nitro_variant(unet, weights_filename, shifted_timestep):
    """Load one NitroFusion UNet checkpoint into *unet* and build its scheduler.

    Args:
        unet: UNet module whose weights are replaced in place.
        weights_filename: safetensors file name inside ``repo``.
        shifted_timestep: Shift passed to ``TimestepShiftLCMScheduler``.

    Returns:
        ``(unet, scheduler)`` — the same *unet* object, now holding the
        downloaded weights, and a freshly configured scheduler.
    """
    unet.load_state_dict(load_file(hf_hub_download(repo, weights_filename), device="cuda"))
    scheduler = TimestepShiftLCMScheduler.from_pretrained(
        base_model_id, subfolder="scheduler", shifted_timestep=shifted_timestep
    )
    # Use the config API (as the scheduler itself does for shifted_timestep)
    # instead of assigning an attribute on the frozen config object, so the
    # value is recorded in the config mapping and readable as
    # scheduler.config.original_inference_steps.
    scheduler.register_to_config(original_inference_steps=4)
    return unet, scheduler


# Realism weights go into the pipeline's own UNet.
unet_realism, scheduler_realism = _load_nitro_variant(
    pipe.unet, "nitrosd-realism_unet.safetensors", 250
)
# The vibrant UNet is a deep copy of the pipeline UNet (taken after the
# realism load, as before) whose weights are then replaced by load_state_dict.
unet_vibrant, scheduler_vibrant = _load_nitro_variant(
    copy.deepcopy(pipe.unet), "nitrosd-vibrant_unet.safetensors", 500
)


@spaces.GPU
def process_image(model_choice, num_images, height, width, prompt, seed, inference_steps):
    """Generate images with the selected NitroFusion UNet/scheduler pair.

    Args:
        model_choice: ``"NitroSD-Realism"`` or ``"NitroSD-Vibrant"``.
        num_images: How many images to generate (the prompt is repeated this many times).
        height: Output height in pixels.
        width: Output width in pixels.
        prompt: Text prompt.
        seed: Seed passed to ``torch.manual_seed`` to build the generator.
        inference_steps: Number of denoising steps.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        ValueError: If *model_choice* is not one of the known models.
    """
    models = {
        "NitroSD-Realism": (unet_realism, scheduler_realism),
        "NitroSD-Vibrant": (unet_vibrant, scheduler_vibrant),
    }
    try:
        unet, scheduler = models[model_choice]
    except KeyError:
        raise ValueError("Invalid model choice.") from None

    # `pipe` is only mutated here (attribute swap), never rebound, so no
    # `global` declaration is needed.
    pipe.unet = unet
    pipe.scheduler = scheduler

    # UI numeric widgets may hand over floats (TODO confirm for the installed
    # Gradio version); list repetition and the step count require ints, so
    # every numeric input is coerced, not just height/width/seed.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        return pipe(
            prompt=[prompt] * int(num_images),
            generator=torch.manual_seed(int(seed)),
            num_inference_steps=int(inference_steps),
            guidance_scale=0.0,
            height=int(height),
            width=int(width),
        ).images


# Gradio UI: settings on the left, generated images on the right.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left column: title, generation settings and the trigger button.
            with gr.Column():
                gr.Markdown("""
                    ### NitroFusion Single-Step Text-To-Image
                """)
                model_choice = gr.Dropdown(
                    choices=["NitroSD-Realism", "NitroSD-Vibrant"],
                    value="NitroSD-Realism",
                    label="Choose Model",
                    interactive=True,
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=4,
                    label="Number of Images",
                    interactive=True,
                )
                height = gr.Slider(
                    minimum=768,
                    maximum=1024,
                    step=8,
                    value=1024,
                    label="Image Height",
                    interactive=True,
                )
                width = gr.Slider(
                    minimum=768,
                    maximum=1024,
                    step=8,
                    value=1024,
                    label="Image Width",
                    interactive=True,
                )
                prompt = gr.Text(value="a photo of a cat", label="Prompt", interactive=True)
                seed = gr.Number(value=2024, label="Seed", interactive=True)
                # Fixed at a single step; shown read-only.
                inference_steps = gr.Number(value=1, label="Inference Steps", interactive=False)
                btn = gr.Button(value="Generate Image")
            # Right column: result gallery.
            with gr.Column():
                output = gr.Gallery(height=1024)

    # Hook the button up to the generator, feeding every setting in signature order.
    btn.click(
        fn=process_image,
        inputs=[model_choice, num_images, height, width, prompt, seed, inference_steps],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()