XCLiu committed
Commit d5646a3 · 1 Parent(s): ea2f787

Upload 2 files

Files changed (2)
  1. app.py +132 -0
  2. rf_models.py +249 -0
app.py ADDED
import copy
import time

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F

from diffusers import StableDiffusionXLImg2ImgPipeline

from rf_models import RF_model

# SDXL refiner, used to optionally sharpen the one-step InstaFlow output.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

# Module-level state shared by the Gradio callbacks.
model = None
img = None


def set_model(model_id):
    global model
    if model_id == "InstaFlow-0.9B":
        model = RF_model("./instaflow_09b.pt")
    elif model_id == "InstaFlow-1.7B":
        model = RF_model("./instaflow_17b.pt")
    else:
        raise NotImplementedError(f"Unknown model: {model_id}")
    print('Finished Loading Model!')


def set_new_latent_and_generate_new_image(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
    print('Generate with input seed')
    global img
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
    print('time consumption:', time.time() - t_s)

    # Keep a copy so the refiner buttons can operate on the latest image.
    img = copy.copy(new_image[0])

    return new_image[0]


def set_new_latent_and_generate_new_image_and_random_seed(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
    print('Generate with a random seed')
    global img
    seed = np.random.randint(0, 2**32)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(seed, prompt, negative_prompt, num_inference_steps, guidance_scale)
    print('time consumption:', time.time() - t_s)

    img = copy.copy(new_image[0])

    return new_image[0], seed


def refine_image_512(prompt):
    print('Refine with SDXL-Refiner (512)')
    global img

    t_s = time.time()
    # img is an HxWxC float array in [0, 1]; the refiner pipeline accepts it directly at 512x512.
    new_image = pipe(prompt, image=img).images[0]
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image).astype(np.float32) / 255.0

    img = new_image

    return new_image


def refine_image_1024(prompt):
    print('Refine with SDXL-Refiner (1024)')
    global img

    t_s = time.time()
    # Upsample to 1024x1024 (NCHW for interpolate, back to HWC for the pipeline) before refining.
    img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
    img = F.interpolate(img, size=1024, mode='bilinear')
    img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
    new_image = pipe(prompt, image=img).images[0]
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image).astype(np.float32) / 255.0

    img = new_image

    return new_image


set_model('InstaFlow-0.9B')

with gr.Blocks() as gradio_gui:
    with gr.Row():
        with gr.Column(scale=0.5):
            im = gr.Image()

        with gr.Column():
            #model_id = gr.Dropdown(["InstaFlow-0.9B", "InstaFlow-1.7B"], label="Model ID", info="Choose Your Model")
            #set_model_button = gr.Button(value="Set New Model")
            #set_model_button.click(set_model, inputs=[model_id])

            model_id = gr.Textbox(value='InstaFlow-0.9B', label="Model ID")

            seed_input = gr.Textbox(value='101098274', label="Random Seed")
            prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")

            new_image_button = gr.Button(value="Generate Image with the Input Seed")
            new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input], outputs=[im])

            next_image_button = gr.Button(value="Generate Image with a Random Seed")
            next_image_button.click(set_new_latent_and_generate_new_image_and_random_seed, inputs=[seed_input, prompt_input], outputs=[im, seed_input])

            refine_button_512 = gr.Button(value="Refine with Refiner (Resolution: 512)")
            refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])

            refine_button_1024 = gr.Button(value="Refine with Refiner (Resolution: 1024)")
            refine_button_1024.click(refine_image_1024, inputs=[prompt_input], outputs=[im])

gradio_gui.launch()
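
For reference, a minimal sketch of driving the same functions without the Gradio UI, assuming the checkpoint ./instaflow_09b.pt is present locally (the PIL round-trip and the output filename are illustrative assumptions, not part of the app):

from PIL import Image

set_model('InstaFlow-0.9B')
out = set_new_latent_and_generate_new_image(101098274, 'A high-resolution photograph of a waterfall in autumn; muted tone')
# out is an HxWxC float array in [0, 1]; convert to uint8 before saving.
Image.fromarray((out * 255).round().astype('uint8')).save('waterfall.png')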
rf_models.py ADDED
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import time
from typing import Optional, Union, List, Callable

import numpy as np
import torch

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel


@torch.no_grad()
def inference_latent_euler(
    pipeline,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
):
    """Euler sampler for rectified flow: integrates dz/dt = v(z, t) from noise (t=0) to data (t=1)."""
    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    pipeline.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = pipeline._execution_device
    # `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    t_s = time.time()
    text_embeddings = pipeline._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )
    t_e = time.time()
    print('Text Embedding Time:', t_e - t_s)

    # 4. Prepare latent variables
    num_channels_latents = pipeline.unet.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 6. Denoising loop: fixed-step Euler integration of the flow ODE
    dt = 1. / num_inference_steps
    init_latents = latents.detach().clone()

    for i in range(num_inference_steps):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = torch.cat(
            [latents] * 2) if do_classifier_free_guidance else latents

        vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * (i / num_inference_steps)

        # The U-Net predicts the velocity field; its time conditioning is (1 - t) * 1000,
        # matching the reversed, scaled timestep convention the model was trained with.
        v_pred = pipeline.unet(
            latent_model_input, (1. - vec_t) * 1000., encoder_hidden_states=text_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            v_pred_uncond, v_pred_text = v_pred.chunk(2)
            v_pred = v_pred_uncond + guidance_scale * (v_pred_text - v_pred_uncond)

        latents = latents + dt * v_pred

    example = {
        'latent': latents.detach(),
        'init_latent': init_latents.detach().clone(),
        'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
    }

    return example


def setup_seed(seed):
    # Make generation reproducible across torch, CUDA, numpy, and random.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.empty_cache()


class RF_model:

    def __init__(self, model_id):
        pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Load scheduler, tokenizer and models.
        noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="text_encoder"
        )
        vae = AutoencoderKL.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="vae"
        )
        unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="unet"
        )

        # Re-instantiate the U-Net from its config and load the InstaFlow checkpoint into it.
        print('Loading InstaFlow U-Net weights from', model_id)
        unet = UNet2DConditionModel.from_config(unet.config)
        unet.load_state_dict(torch.load(model_id, map_location='cpu'))

        unet.eval()
        vae.eval()
        text_encoder.eval()

        # Freeze vae, text_encoder and unet; they are used for inference only.
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)

        # The models are only used for inference, so half precision is sufficient.
        weight_dtype = torch.float16
        self.weight_dtype = weight_dtype
        device = 'cuda'
        self.device = device

        # Move text_encoder, vae and unet to the GPU and cast to weight_dtype.
        text_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)

        # Create the pipeline from the loaded modules.
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            torch_dtype=weight_dtype,
        )
        self.pipeline = pipeline.to(device)

    def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=50, guidance_scale=4.0, verbose=True):
        if seed is None:
            raise ValueError("Must have a pre-defined random seed")

        if prompt is None:
            raise ValueError("Must have a user-specified text prompt")

        setup_seed(seed)
        # Fresh Gaussian latent at SD's 4x64x64 latent resolution (512x512 pixels).
        self.latents = torch.randn((1, 4, 64, 64), device=self.device).to(dtype=self.weight_dtype)
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps

        prompts = [prompt]
        negative_prompts = [negative_prompt]
        if verbose:
            print(prompts)
            print(negative_prompts)

        output = inference_latent_euler(
            self.pipeline,
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=self.guidance_scale,
            latents=self.latents.detach().clone(),
        )

        t_s = time.time()
        image = self.pipeline.decode_latents(output['latent'])
        t_e = time.time()
        print('Decoding Time:', t_e - t_s)

        self.org_image = image

        return image
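
The sampler in inference_latent_euler is a plain fixed-step Euler discretization of the rectified-flow ODE dz/dt = v(z, t) from noise at t = 0 to data at t = 1; because rectified flows are trained toward near-straight paths, a single step already lands close to the target, which is what makes one-step InstaFlow generation possible. A self-contained sketch of the same integrator, with a toy constant velocity field standing in for the trained U-Net (the function names and shapes here are illustrative assumptions):

import torch

def euler_sample(v_fn, z0, num_steps):
    # Integrate dz/dt = v(z, t) from t = 0 to t = 1 with fixed step size.
    z, dt = z0.clone(), 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((z.shape[0],), i / num_steps)
        z = z + dt * v_fn(z, t)
    return z

# For a perfectly straight path z_t = (1 - t) * z0 + t * x1, the true velocity
# is the constant x1 - z0, so a single Euler step lands exactly on x1.
z0, x1 = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
one_step = euler_sample(lambda z, t: x1 - z0, z0, num_steps=1)
assert torch.allclose(one_step, x1)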