Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +132 -0
- rf_models.py +249 -0
app.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from rf_models import RF_model
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline
|
10 |
+
import time
|
11 |
+
import copy
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
15 |
+
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
16 |
+
)
|
17 |
+
pipe = pipe.to("cuda")
|
18 |
+
|
19 |
+
global model
|
20 |
+
global img
|
21 |
+
|
22 |
+
def set_model(model_id):
|
23 |
+
global model
|
24 |
+
if model_id == "InstaFlow-0.9B":
|
25 |
+
model = RF_model("./instaflow_09b.pt")
|
26 |
+
elif model_id == "InstaFlow-1.7B":
|
27 |
+
model = RF_model("./instaflow_17b.pt")
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
print('Finished Loading Model!')
|
31 |
+
|
32 |
+
|
33 |
+
def set_new_latent_and_generate_new_image(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
|
34 |
+
print('Generate with input seed')
|
35 |
+
global model
|
36 |
+
global img
|
37 |
+
seed = int(seed)
|
38 |
+
num_inference_steps = int(num_inference_steps)
|
39 |
+
guidance_scale = float(guidance_scale)
|
40 |
+
print(seed, num_inference_steps, guidance_scale)
|
41 |
+
|
42 |
+
t_s = time.time()
|
43 |
+
new_image = model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, int(num_inference_steps), guidance_scale)
|
44 |
+
print('time consumption:', time.time() - t_s)
|
45 |
+
|
46 |
+
img = copy.copy(new_image[0])
|
47 |
+
|
48 |
+
return new_image[0]
|
49 |
+
|
50 |
+
def set_new_latent_and_generate_new_image_and_random_seed(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
|
51 |
+
print('Generate with a random seed')
|
52 |
+
global model
|
53 |
+
global img
|
54 |
+
seed = np.random.randint(0, 2**32)
|
55 |
+
num_inference_steps = int(num_inference_steps)
|
56 |
+
guidance_scale = float(guidance_scale)
|
57 |
+
print(seed, num_inference_steps, guidance_scale)
|
58 |
+
|
59 |
+
t_s = time.time()
|
60 |
+
new_image = model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, int(num_inference_steps), guidance_scale)
|
61 |
+
print('time consumption:', time.time() - t_s)
|
62 |
+
|
63 |
+
img = copy.copy(new_image[0])
|
64 |
+
|
65 |
+
return new_image[0], seed
|
66 |
+
|
67 |
+
|
68 |
+
def refine_image_512(prompt):
|
69 |
+
print('Refine with SDXL-Refiner (512)')
|
70 |
+
global img
|
71 |
+
|
72 |
+
t_s = time.time()
|
73 |
+
img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
|
74 |
+
img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
|
75 |
+
new_image = pipe(prompt, image=img).images[0]
|
76 |
+
print('time consumption:', time.time() - t_s)
|
77 |
+
new_image = np.array(new_image) * 1.0 / 255.
|
78 |
+
|
79 |
+
img = new_image
|
80 |
+
|
81 |
+
return new_image
|
82 |
+
|
83 |
+
def refine_image_1024(prompt):
|
84 |
+
print('Refine with SDXL-Refiner (1024)')
|
85 |
+
global img
|
86 |
+
|
87 |
+
t_s = time.time()
|
88 |
+
img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
|
89 |
+
img = torch.nn.functional.interpolate(img, size=1024, mode='bilinear')
|
90 |
+
img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
|
91 |
+
new_image = pipe(prompt, image=img).images[0]
|
92 |
+
print('time consumption:', time.time() - t_s)
|
93 |
+
new_image = np.array(new_image) * 1.0 / 255.
|
94 |
+
|
95 |
+
img = new_image
|
96 |
+
|
97 |
+
return new_image
|
98 |
+
|
99 |
+
set_model('InstaFlow-0.9B')
|
100 |
+
|
101 |
+
with gr.Blocks() as gradio_gui:
|
102 |
+
|
103 |
+
with gr.Row():
|
104 |
+
with gr.Column(scale=0.5):
|
105 |
+
im = gr.Image()
|
106 |
+
|
107 |
+
with gr.Column():
|
108 |
+
#model_id = gr.Dropdown(["InstaFlow-0.9B", "InstaFlow-1.7B"], label="Model ID", info="Choose Your Model")
|
109 |
+
|
110 |
+
#set_model_button = gr.Button(value="Set New Model")
|
111 |
+
#set_model_button.click(set_model, inputs=[model_id])
|
112 |
+
|
113 |
+
model_id = gr.Textbox(value='InstaFlow-0.9B', label="Model ID")
|
114 |
+
|
115 |
+
seed_input = gr.Textbox(value='101098274', label="Random Seed")
|
116 |
+
prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")
|
117 |
+
|
118 |
+
new_image_button = gr.Button(value="Generate Image with the Input Seed")
|
119 |
+
new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input], outputs=[im])
|
120 |
+
|
121 |
+
next_image_button = gr.Button(value="Generate Image with a Random Seed")
|
122 |
+
next_image_button.click(set_new_latent_and_generate_new_image_and_random_seed, inputs=[seed_input, prompt_input], outputs=[im, seed_input])
|
123 |
+
|
124 |
+
|
125 |
+
refine_button_512 = gr.Button(value="Refine with Refiner (Resolution: 512)")
|
126 |
+
refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])
|
127 |
+
|
128 |
+
refine_button_1024 = gr.Button(value="Refine with Refiner (Resolution: 1024)")
|
129 |
+
refine_button_1024.click(refine_image_1024, inputs=[prompt_input], outputs=[im])
|
130 |
+
|
131 |
+
|
132 |
+
gradio_gui.launch()
|
rf_models.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
from pathlib import Path
|
22 |
+
from typing import Optional, Union, List, Callable
|
23 |
+
|
24 |
+
import datasets
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
import torch.utils.checkpoint
|
29 |
+
import transformers
|
30 |
+
from datasets import load_dataset
|
31 |
+
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
32 |
+
from packaging import version
|
33 |
+
from torchvision import transforms
|
34 |
+
from tqdm.auto import tqdm
|
35 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
36 |
+
|
37 |
+
import diffusers
|
38 |
+
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
39 |
+
from diffusers.optimization import get_scheduler
|
40 |
+
from diffusers.training_utils import EMAModel
|
41 |
+
from diffusers.utils import check_min_version, deprecate
|
42 |
+
from diffusers.utils.import_utils import is_xformers_available
|
43 |
+
|
44 |
+
import time
|
45 |
+
|
46 |
+
from torch.distributions import Normal, Categorical
|
47 |
+
from torch.distributions.multivariate_normal import MultivariateNormal
|
48 |
+
from torch.distributions.mixture_same_family import MixtureSameFamily
|
49 |
+
|
50 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
51 |
+
import torchvision
|
52 |
+
|
53 |
+
import cv2
|
54 |
+
import copy
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def inference_latent_euler(
|
58 |
+
pipeline,
|
59 |
+
prompt: Union[str, List[str]],
|
60 |
+
height: Optional[int] = None,
|
61 |
+
width: Optional[int] = None,
|
62 |
+
num_inference_steps: int = 50,
|
63 |
+
guidance_scale: float = 7.5,
|
64 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
65 |
+
num_images_per_prompt: Optional[int] = 1,
|
66 |
+
eta: float = 0.0,
|
67 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
68 |
+
latents: Optional[torch.FloatTensor] = None,
|
69 |
+
output_type: Optional[str] = "pil",
|
70 |
+
return_dict: bool = True,
|
71 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
72 |
+
callback_steps: Optional[int] = 1,
|
73 |
+
):
|
74 |
+
# 0. Default height and width to unet
|
75 |
+
height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
76 |
+
width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
77 |
+
|
78 |
+
# 1. Check inputs. Raise error if not correct
|
79 |
+
pipeline.check_inputs(prompt, height, width, callback_steps)
|
80 |
+
|
81 |
+
# 2. Define call parameters
|
82 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
83 |
+
device = pipeline._execution_device
|
84 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
85 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
86 |
+
# corresponds to doing no classifier free guidance.
|
87 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
88 |
+
|
89 |
+
# 3. Encode input prompt
|
90 |
+
t_s = time.time()
|
91 |
+
text_embeddings = pipeline._encode_prompt(
|
92 |
+
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
93 |
+
)
|
94 |
+
t_e = time.time()
|
95 |
+
print('Text Embedding Time:', t_e - t_s)
|
96 |
+
|
97 |
+
# 5. Prepare latent variables
|
98 |
+
num_channels_latents = pipeline.unet.in_channels
|
99 |
+
latents = pipeline.prepare_latents(
|
100 |
+
batch_size * num_images_per_prompt,
|
101 |
+
num_channels_latents,
|
102 |
+
height,
|
103 |
+
width,
|
104 |
+
text_embeddings.dtype,
|
105 |
+
device,
|
106 |
+
generator,
|
107 |
+
latents,
|
108 |
+
)
|
109 |
+
|
110 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
111 |
+
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
|
112 |
+
|
113 |
+
# 7. Denoising loop
|
114 |
+
dt = 1./ num_inference_steps
|
115 |
+
init_latents = latents.detach().clone()
|
116 |
+
|
117 |
+
for i in range(num_inference_steps):
|
118 |
+
# expand the latents if we are doing classifier free guidance
|
119 |
+
latent_model_input = torch.cat(
|
120 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
121 |
+
|
122 |
+
vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * (i / num_inference_steps * 1.0)
|
123 |
+
|
124 |
+
|
125 |
+
v_pred = pipeline.unet(
|
126 |
+
latent_model_input, (1.-vec_t) * 1000., encoder_hidden_states=text_embeddings).sample
|
127 |
+
|
128 |
+
# perform guidance
|
129 |
+
if do_classifier_free_guidance:
|
130 |
+
v_pred_uncond, v_pred_text = v_pred.chunk(2)
|
131 |
+
v_pred = v_pred_uncond + guidance_scale * \
|
132 |
+
(v_pred_text - v_pred_uncond)
|
133 |
+
|
134 |
+
latents = latents + dt * v_pred
|
135 |
+
|
136 |
+
example = {
|
137 |
+
'latent': latents.detach(),
|
138 |
+
'init_latent': init_latents.detach().clone(),
|
139 |
+
'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
|
140 |
+
}
|
141 |
+
|
142 |
+
return example
|
143 |
+
|
144 |
+
def setup_seed(seed):
|
145 |
+
import random
|
146 |
+
torch.manual_seed(seed)
|
147 |
+
torch.cuda.manual_seed_all(seed)
|
148 |
+
np.random.seed(seed)
|
149 |
+
random.seed(seed)
|
150 |
+
torch.backends.cudnn.benchmark = False
|
151 |
+
torch.backends.cudnn.deterministic = True
|
152 |
+
torch.cuda.empty_cache()
|
153 |
+
|
154 |
+
|
155 |
+
class RF_model():
|
156 |
+
|
157 |
+
def __init__(self, model_id):
|
158 |
+
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
|
159 |
+
self.pretrained_model_name_or_path = pretrained_model_name_or_path
|
160 |
+
|
161 |
+
# Load scheduler, tokenizer and models.
|
162 |
+
noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
|
163 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
164 |
+
self.pretrained_model_name_or_path, subfolder="tokenizer"#, revision=args.revision
|
165 |
+
)
|
166 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
167 |
+
self.pretrained_model_name_or_path, subfolder="text_encoder"#, revision=args.revision
|
168 |
+
)
|
169 |
+
vae = AutoencoderKL.from_pretrained(
|
170 |
+
self.pretrained_model_name_or_path, subfolder="vae"#, revision=args.revision
|
171 |
+
)
|
172 |
+
unet = UNet2DConditionModel.from_pretrained(
|
173 |
+
self.pretrained_model_name_or_path, subfolder="unet"#, revision=args.non_ema_revision
|
174 |
+
)
|
175 |
+
|
176 |
+
print('Loading: Stacked U-Net 0.9B')
|
177 |
+
unet = UNet2DConditionModel.from_config(unet.config)
|
178 |
+
unet.load_state_dict(torch.load(model_id, map_location='cpu'))
|
179 |
+
|
180 |
+
unet.eval()
|
181 |
+
vae.eval()
|
182 |
+
text_encoder.eval()
|
183 |
+
|
184 |
+
# Freeze vae and text_encoder
|
185 |
+
vae.requires_grad_(False)
|
186 |
+
text_encoder.requires_grad_(False)
|
187 |
+
unet.requires_grad_(False)
|
188 |
+
|
189 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
190 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
191 |
+
weight_dtype = torch.float16
|
192 |
+
self.weight_dtype = weight_dtype
|
193 |
+
device = 'cuda'
|
194 |
+
self.device = device
|
195 |
+
|
196 |
+
# Move text_encode and vae to gpu and cast to weight_dtype
|
197 |
+
text_encoder.to(device, dtype=weight_dtype)
|
198 |
+
vae.to(device, dtype=weight_dtype)
|
199 |
+
unet.to(device, dtype=weight_dtype)
|
200 |
+
|
201 |
+
# Create the pipeline using the trained modules and save it.
|
202 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
203 |
+
self.pretrained_model_name_or_path,
|
204 |
+
text_encoder=text_encoder,
|
205 |
+
vae=vae,
|
206 |
+
unet=unet,
|
207 |
+
torch_dtype=weight_dtype,
|
208 |
+
)
|
209 |
+
self.pipeline = pipeline.to(device)
|
210 |
+
|
211 |
+
def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=50, guidance_scale=4.0, verbose=True):
|
212 |
+
if seed is None:
|
213 |
+
assert False, "Must have a pre-defined random seed"
|
214 |
+
|
215 |
+
if prompt is None:
|
216 |
+
assert False, "Must have a user-specified text prompt"
|
217 |
+
|
218 |
+
setup_seed(seed)
|
219 |
+
self.latents = torch.randn((1, 4, 64, 64), device=self.device).to(dtype=self.weight_dtype)
|
220 |
+
self.prompt = prompt
|
221 |
+
self.negative_prompt = negative_prompt
|
222 |
+
self.guidance_scale = guidance_scale
|
223 |
+
self.num_inference_steps = num_inference_steps
|
224 |
+
|
225 |
+
prompts = [prompt]
|
226 |
+
negative_prompts = [negative_prompt]
|
227 |
+
if verbose:
|
228 |
+
print(prompts)
|
229 |
+
print(negative_prompts)
|
230 |
+
|
231 |
+
output = inference_latent_euler(
|
232 |
+
self.pipeline,
|
233 |
+
prompt=prompts,
|
234 |
+
negative_prompt=negative_prompts,
|
235 |
+
num_inference_steps=num_inference_steps,
|
236 |
+
guidance_scale=self.guidance_scale,
|
237 |
+
latents=self.latents.detach().clone(),
|
238 |
+
)
|
239 |
+
|
240 |
+
t_s = time.time()
|
241 |
+
image = self.pipeline.decode_latents(output['latent'])
|
242 |
+
t_e = time.time()
|
243 |
+
print('Decoding Time:', t_e - t_s)
|
244 |
+
|
245 |
+
self.org_image = image
|
246 |
+
|
247 |
+
return image
|
248 |
+
|
249 |
+
|