import gradio as gr

from rf_models import RF_model
from sd_models import SD_model

import torch
import torch.nn.functional as F

from diffusers import StableDiffusionXLImg2ImgPipeline
import time
import copy
import numpy as np

# Load the SDXL img2img refiner once at start-up and move it to the GPU; both refine buttons reuse it.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

# Module-level state shared by the Gradio callbacks below.
model = None       # one-step InstaFlow model
base_model = None  # multi-step Stable Diffusion baseline
img = None         # most recently generated image

def set_model(model_id):
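    # Load the requested one-step InstaFlow checkpoint (a local .pt file) into the global model.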
    global model 
    if model_id == "InstaFlow-0.9B":
        model = RF_model("./instaflow_09b.pt")
    elif model_id == "InstaFlow-1.7B":
        model = RF_model("./instaflow_17b.pt")
    else:
        raise NotImplementedError
    print('Finished Loading Model!')

def set_base_model(model_id):
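    # Load the multi-step Stable Diffusion baseline into the global base_model.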
    global base_model 
    if model_id == "runwayml/stable-diffusion-v1-5":
        base_model = SD_model("runwayml/stable-diffusion-v1-5")
    else:
        raise NotImplementedError
    print('Finished Loading Base Model!')

def set_new_latent_and_generate_new_image(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
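    # One-step text-to-image generation with the InstaFlow model, using the seed typed into the UI.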
    print('Generate with input seed')
    global model
    global img
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], inf_time

def set_new_latent_and_generate_new_image_with_base_model(seed, prompt, num_inference_steps=1, guidance_scale=0.0):
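    # Multi-step generation with the Stable Diffusion baseline, using the seed typed into the UI.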
    print('Generate with input seed')
    global base_model
    global img
    negative_prompt = ""
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = base_model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], inf_time


def set_new_latent_and_generate_new_image_and_random_seed(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
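    # One-step generation with a freshly drawn random seed; the seed is returned so the UI can display it.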
    print('Generate with a random seed')
    global model
    global img
    seed = int(np.random.randint(0, 2**32, dtype=np.int64))  # fresh 32-bit seed; explicit int64 dtype keeps the upper bound in range on every platform
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], seed, inf_time


def refine_image_512(prompt):
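    # Refine the most recently generated image with the SDXL refiner at 512 resolution.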
    print('Refine with SDXL-Refiner (512)')
    global img

    t_s = time.time()
    # Refine at the image's native 512 resolution; img is passed to the pipeline as-is.
    new_image = pipe(prompt, image=img).images[0]
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image) / 255.0  # PIL output back to a float array in [0, 1] so it can be refined again

    img = new_image

    return new_image

def refine_image_1024(prompt):
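    # Upsample the most recently generated image to 1024x1024, then refine it with the SDXL refiner.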
    print('Refine with SDXL-Refiner (1024)')
    global img

    t_s = time.time()
    # Bilinearly upsample to 1024x1024 before refining: HWC array -> 1xCxHxW tensor -> resize -> back to HWC array.
    img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
    img = F.interpolate(img, size=1024, mode='bilinear')
    img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
    new_image = pipe(prompt, image=img).images[0] 
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image) / 255.0  # PIL output back to a float array in [0, 1] so it can be refined again

    img = new_image

    return new_image

# Load the default models at start-up so the demo is ready before the UI launches.
set_model('InstaFlow-0.9B')
set_base_model("runwayml/stable-diffusion-v1-5")

with gr.Blocks() as gradio_gui:
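    # Two-column demo: one-step InstaFlow (with optional SDXL refinement) on the left, multi-step Stable Diffusion 1.5 on the right.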
    gr.Markdown("Set the Input Seed and Text Prompt Here")
    with gr.Row():
        with gr.Column(scale=0.4):
            seed_input = gr.Textbox(value='101098274', label="Random Seed") 
        with gr.Column(scale=0.4):
            prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Group():
                gr.Markdown("Generation from InstaFlow-0.9B")
                im = gr.Image()
            
            gr.Markdown("Model ID: One-Step InstaFlow-0.9B")
            inference_time_output = gr.Textbox(value='0.0', label='Inference Time with One-Step Model (Seconds)')
            new_image_button = gr.Button(value="One-Step Generation with the Input Seed")
            new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input], outputs=[im, inference_time_output])

            next_image_button = gr.Button(value="One-Step Generation with a New Random Seed")
            next_image_button.click(set_new_latent_and_generate_new_image_and_random_seed, inputs=[seed_input, prompt_input], outputs=[im, seed_input, inference_time_output])

            refine_button_512 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 512)")
            refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])

            refine_button_1024 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 1024)")
            refine_button_1024.click(refine_image_1024, inputs=[prompt_input], outputs=[im])

        with gr.Column(scale=0.4):
            with gr.Group():
                gr.Markdown("Generation from Stable Diffusion 1.5") 
                im_base = gr.Image()

            gr.Markdown("Model ID: Multi-Step Stable Diffusion 1.5")
            base_model_inference_time_output = gr.Textbox(value='0.0', label='Inference Time with Multi-Step Stable Diffusion (Seconds)')

            num_inference_steps = gr.Textbox(value='25', label="Number of Inference Steps for Stable Diffusion")
            guidance_scale = gr.Textbox(value='5.0', label="Guidance Scale for Stable Diffusion")
            
            base_new_image_button = gr.Button(value="Multi-Step Generation with Stable Diffusion and the Input Seed") 
            base_new_image_button.click(set_new_latent_and_generate_new_image_with_base_model, inputs=[seed_input, prompt_input,  num_inference_steps, guidance_scale], outputs=[im_base, base_model_inference_time_output])

gradio_gui.launch()