XCLiu committed
Commit aeed1bf · 1 Parent(s): 28e87bc

Upload 2 files

Files changed (2)
  1. app.py +175 -0
  2. sd_models.py +239 -0
app.py ADDED
@@ -0,0 +1,175 @@
+ import gradio as gr
+
+ from rf_models import RF_model
+ from sd_models import SD_model
+
+ import torch
+ import torch.nn.functional as F
+
+ from diffusers import StableDiffusionXLImg2ImgPipeline
+ import time
+ import copy
+ import numpy as np
+
+ # SDXL img2img refiner, used to post-process the one-step generations.
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ )
+ pipe = pipe.to("cuda")
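+ # Loading the refiner eagerly assumes a CUDA GPU with enough memory for fp16
+ # SDXL. If memory is tight, diffusers' pipe.enable_model_cpu_offload() (which
+ # requires accelerate) is a plausible substitute for the .to("cuda") move.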
+
+ # Module-level state shared across the Gradio callbacks below: `model` is the
+ # one-step InstaFlow model, `base_model` the multi-step SD 1.5 baseline, and
+ # `img` the latest generated image as an HWC float array in [0, 1].
+ # (The original module-level `global` statements were no-ops in Python.)
+ model = None
+ base_model = None
+ img = None
+
+ def set_model(model_id):
+     global model
+     if model_id == "InstaFlow-0.9B":
+         model = RF_model("./instaflow_09b.pt")
+     elif model_id == "InstaFlow-1.7B":
+         model = RF_model("./instaflow_17b.pt")
+     else:
+         raise NotImplementedError
+     print('Finished Loading Model!')
+
+ def set_base_model(model_id):
+     global base_model
+     if model_id == "runwayml/stable-diffusion-v1-5":
+         base_model = SD_model("runwayml/stable-diffusion-v1-5")
+     else:
+         raise NotImplementedError
+     print('Finished Loading Base Model!')
+
+ def set_new_latent_and_generate_new_image(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
+     print('Generate with input seed')
+     global model
+     global img
+     seed = int(seed)
+     num_inference_steps = int(num_inference_steps)
+     guidance_scale = float(guidance_scale)
+     print(seed, num_inference_steps, guidance_scale)
+
+     t_s = time.time()
+     new_image = model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
+     inf_time = time.time() - t_s
+
+     img = copy.copy(new_image[0])
+
+     return new_image[0], inf_time
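+ # The UI invokes this with its defaults (num_inference_steps=1,
+ # guidance_scale=0.0): the distilled one-step model is meant to be sampled in
+ # a single step, without classifier-free guidance.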
+
+ def set_new_latent_and_generate_new_image_with_base_model(seed, prompt, num_inference_steps=1, guidance_scale=0.0):
+     print('Generate with input seed')
+     global base_model
+     global img
+     negative_prompt = ""
+     seed = int(seed)
+     num_inference_steps = int(num_inference_steps)
+     guidance_scale = float(guidance_scale)
+     print(seed, num_inference_steps, guidance_scale)
+
+     t_s = time.time()
+     new_image = base_model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
+     inf_time = time.time() - t_s
+
+     img = copy.copy(new_image[0])
+
+     return new_image[0], inf_time
+
+
+ def set_new_latent_and_generate_new_image_and_random_seed(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
+     print('Generate with a random seed')
+     global model
+     global img
+     # Ignore the incoming seed and draw a fresh one so the UI can display it.
+     seed = np.random.randint(0, 2**32)
+     num_inference_steps = int(num_inference_steps)
+     guidance_scale = float(guidance_scale)
+     print(seed, num_inference_steps, guidance_scale)
+
+     t_s = time.time()
+     new_image = model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, num_inference_steps, guidance_scale)
+     inf_time = time.time() - t_s
+
+     img = copy.copy(new_image[0])
+
+     return new_image[0], seed, inf_time
+
+
+ def refine_image_512(prompt):
+     print('Refine with SDXL-Refiner (512)')
+     global img
+
+     t_s = time.time()
+     # `img` is already a 512x512 HWC float array in [0, 1]; the refiner
+     # pipeline accepts such arrays directly. (The original's tensor round-trip
+     # here was an identity operation and has been dropped.)
+     new_image = pipe(prompt, image=img).images[0]
+     print('time consumption:', time.time() - t_s)
+     new_image = np.array(new_image) / 255.0
+
+     img = new_image
+
+     return new_image
+
+ def refine_image_1024(prompt):
+     print('Refine with SDXL-Refiner (1024)')
+     global img
+
+     t_s = time.time()
+     # Bilinearly upsample the 512x512 generation to 1024x1024 before refining.
+     img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
+     img = F.interpolate(img, size=1024, mode='bilinear')
+     img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
+     new_image = pipe(prompt, image=img).images[0]
+     print('time consumption:', time.time() - t_s)
+     new_image = np.array(new_image) / 255.0
+
+     img = new_image
+
+     return new_image
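+ # Both refine helpers rely on diffusers' defaults for the img2img call; the
+ # pipeline also exposes knobs such as `strength` (how far the refiner may
+ # deviate from the input image) if finer control were wanted.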
+
+ set_model('InstaFlow-0.9B')
+ set_base_model("runwayml/stable-diffusion-v1-5")
+
+ with gr.Blocks() as gradio_gui:
+     gr.Markdown("Set Input Seed and Text Prompts Here")
+     with gr.Row():
+         with gr.Column(scale=0.4):
+             seed_input = gr.Textbox(value='101098274', label="Random Seed")
+         with gr.Column(scale=0.4):
+             prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")
+
+     with gr.Row():
+         with gr.Column(scale=0.4):
+             with gr.Group():
+                 gr.Markdown("Generation from InstaFlow-0.9B")
+                 im = gr.Image()
+
+             gr.Markdown("Model ID: One-Step InstaFlow-0.9B")
+             inference_time_output = gr.Textbox(value='0.0', label='Inference Time with One-Step Model (Seconds)')
+             new_image_button = gr.Button(value="One-Step Generation with the Input Seed")
+             new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input], outputs=[im, inference_time_output])
+
+             next_image_button = gr.Button(value="One-Step Generation with a New Random Seed")
+             next_image_button.click(set_new_latent_and_generate_new_image_and_random_seed, inputs=[seed_input, prompt_input], outputs=[im, seed_input, inference_time_output])
+
+             refine_button_512 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 512)")
+             refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])
+
+             refine_button_1024 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 1024)")
+             refine_button_1024.click(refine_image_1024, inputs=[prompt_input], outputs=[im])
+
+         with gr.Column(scale=0.4):
+             with gr.Group():
+                 gr.Markdown("Generation from Stable Diffusion 1.5")
+                 im_base = gr.Image()
+
+             gr.Markdown("Model ID: Multi-Step Stable Diffusion 1.5")
+             base_model_inference_time_output = gr.Textbox(value='0.0', label='Inference Time with Multi-Step Stable Diffusion (Seconds)')
+
+             num_inference_steps = gr.Textbox(value='25', label="Number of Inference Steps for Stable Diffusion")
+             guidance_scale = gr.Textbox(value='5.0', label="Guidance Scale for Stable Diffusion")
+
+             base_new_image_button = gr.Button(value="Multi-Step Generation with Stable Diffusion and the Input Seed")
+             base_new_image_button.click(set_new_latent_and_generate_new_image_with_base_model, inputs=[seed_input, prompt_input, num_inference_steps, guidance_scale], outputs=[im_base, base_model_inference_time_output])
+
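+ # If the demo is shared publicly, enabling Gradio's request queue (e.g.
+ # gradio_gui.queue().launch()) would serialize access to the single GPU.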
+ gradio_gui.launch()
sd_models.py ADDED
@@ -0,0 +1,239 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union, List, Callable
+
+ import numpy as np
+ import torch
+
+ from transformers import CLIPTextModel
+
+ from diffusers import (
+     AutoencoderKL,
+     DPMSolverMultistepScheduler,
+     StableDiffusionPipeline,
+     UNet2DConditionModel,
+ )
+
+ def inference_latent(
+     pipeline,
+     prompt: Union[str, List[str]],
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     num_inference_steps: int = 50,
+     guidance_scale: float = 7.5,
+     negative_prompt: Optional[Union[str, List[str]]] = None,
+     num_images_per_prompt: Optional[int] = 1,
+     eta: float = 0.0,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+     callback_steps: Optional[int] = 1,
+ ):
+     # Run the denoising loop of a StableDiffusionPipeline on caller-provided
+     # latents and return the final latent rather than a decoded image.
+
+     # 0. Default height and width to unet
+     height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
+     width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
+
+     # 1. Input checks are skipped here; `latents` must be supplied by the caller.
+
+     # 2. Define call parameters
+     device = pipeline._execution_device
+     # here `guidance_scale` is defined analogously to the guidance weight `w` of
+     # equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
+     # `guidance_scale = 1` corresponds to doing no classifier-free guidance.
+     do_classifier_free_guidance = guidance_scale > 1.0
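+     # With guidance enabled, each step below evaluates the UNet on both an
+     # unconditional and a text-conditional branch and combines them as
+     #   eps = eps_uncond + guidance_scale * (eps_text - eps_uncond),
+     # which is exactly what the guidance block inside the loop computes.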
+
+     # 3. Encode input prompt
+     text_embeddings = pipeline._encode_prompt(
+         prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+     )
+
+     # 4. Prepare timesteps
+     pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+     timesteps = pipeline.scheduler.timesteps
+
+     # 5. Prepare latent variables: reshape the flat caller-provided latent to
+     # the UNet's (batch, channels, 64, 64) layout used for 512x512 images.
+     num_channels_latents = pipeline.unet.in_channels
+     latents = latents.reshape(1, num_channels_latents, 64, 64)
+
+     # 6. Prepare extra step kwargs (e.g. `eta` for DDIM-style schedulers).
+     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
+
+     # 7. Denoising loop
+     with torch.no_grad():
+         for i, t in enumerate(timesteps):
+             # expand the latents if we are doing classifier-free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
+
+             noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_(t-1)
+             outputs = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
+             latents = outputs.prev_sample
+
+     # Return the final latent plus the (conditional) text embeddings instead of
+     # a decoded image; the caller decides how to decode.
+     example = {
+         'latent': latents.detach().clone(),
+         'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
+     }
+     return example
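+
+ # Sketch of how inference_latent is driven; this mirrors the actual call made
+ # by SD_model.set_new_latent_and_generate_new_image further below:
+ #   latents = torch.randn((1, 4 * 64 * 64), device="cuda", dtype=torch.float16)
+ #   out = inference_latent(pipe, prompt=["a cat"], negative_prompt=[""],
+ #                          num_inference_steps=25, guidance_scale=5.0,
+ #                          latents=latents)
+ #   image = pipe.decode_latents(out['latent'])  # HWC float array in [0, 1]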
+
+
+ def setup_seed(seed):
+     import random
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+     torch.cuda.empty_cache()
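+ # Note: seeding every RNG and forcing deterministic cuDNN kernels makes
+ # generations reproducible for a given seed, at the cost of disabling cuDNN
+ # autotuning (cudnn.benchmark).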
+
+
+ class SD_model():
+
+     def __init__(self, pretrained_model_name_or_path):
+         self.pretrained_model_name_or_path = pretrained_model_name_or_path
+
+         # Load the individual models.
+         text_encoder = CLIPTextModel.from_pretrained(
+             self.pretrained_model_name_or_path, subfolder="text_encoder"
+         )
+         vae = AutoencoderKL.from_pretrained(
+             self.pretrained_model_name_or_path, subfolder="vae"
+         )
+         unet = UNet2DConditionModel.from_pretrained(
+             self.pretrained_model_name_or_path, subfolder="unet"
+         )
+
+         unet.eval()
+         vae.eval()
+         text_encoder.eval()
+
+         # Freeze vae, text_encoder and unet; this class only runs inference.
+         vae.requires_grad_(False)
+         text_encoder.requires_grad_(False)
+         unet.requires_grad_(False)
+
+         # The models are only used for inference, so half precision suffices.
+         weight_dtype = torch.float16
+         self.weight_dtype = weight_dtype
+         device = 'cuda'
+         self.device = device
+
+         # Move text_encoder, vae and unet to the GPU and cast to weight_dtype.
+         text_encoder.to(device, dtype=weight_dtype)
+         vae.to(device, dtype=weight_dtype)
+         unet.to(device, dtype=weight_dtype)
+
+         # Assemble the pipeline from the loaded modules and swap in a
+         # DPM-Solver multistep scheduler for fast multi-step sampling.
+         pipeline = StableDiffusionPipeline.from_pretrained(
+             self.pretrained_model_name_or_path,
+             text_encoder=text_encoder,
+             vae=vae,
+             unet=unet,
+             torch_dtype=weight_dtype,
+         )
+         pipeline = pipeline.to(device)
+         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+         self.pipeline = pipeline
+
+     def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=25, guidance_scale=5.0):
+         if seed is None:
+             raise ValueError("Must have a pre-defined random seed")
+
+         if prompt is None:
+             raise ValueError("Must have a user-specified text prompt")
+
+         # Seed every RNG first so the same (seed, prompt) pair always yields
+         # the same initial latent and therefore the same image.
+         setup_seed(seed)
+         self.latents = torch.randn((1, 4 * 64 * 64), device=self.device).to(dtype=self.weight_dtype)
+         self.prompt = prompt
+         self.negative_prompt = negative_prompt
+         self.guidance_scale = guidance_scale
+         self.num_inference_steps = num_inference_steps
+
+         prompts = [prompt]
+         negative_prompts = [negative_prompt]
+
+         output = inference_latent(
+             self.pipeline,
+             prompt=prompts,
+             negative_prompt=negative_prompts,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=self.guidance_scale,
+             latents=self.latents.detach().clone(),
+         )
+
+         image = self.pipeline.decode_latents(output['latent'])
+
+         self.org_image = image
+
+         return image
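+
+ # A minimal usage sketch (assumes a CUDA GPU and access to the SD 1.5 weights):
+ #   model = SD_model("runwayml/stable-diffusion-v1-5")
+ #   image = model.set_new_latent_and_generate_new_image(
+ #       seed=101098274, prompt="A high-resolution photograph of a waterfall in autumn")
+ #   # `image` is a (1, 512, 512, 3) float array with values in [0, 1].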