cleanup

Files changed:
- gradio_ui.py +492 -0
- latent_blending.py +213 -579
- movie_util.py +46 -54
- stable_diffusion_holder.py +87 -355
- utils.py +260 -0
gradio_ui.py
ADDED
@@ -0,0 +1,492 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
from tqdm.auto import tqdm
from PIL import Image
from movie_util import MovieSaver, concatenate_movies
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
import gradio as gr
from dotenv import find_dotenv, load_dotenv
import shutil
import random
from utils import get_time, add_frames_linear_interp
from huggingface_hub import hf_hub_download


class BlendingFrontend():
    def __init__(
            self,
            sdh,
            share=False):
        r"""
        Gradio helper class to collect UI data and start latent blending.
        Args:
            sdh:
                StableDiffusionHolder
            share: bool
                Set True to get a shareable gradio link (e.g. for running a remote server)
        """
        self.share = share

        # UI Defaults
        self.num_inference_steps = 30
        self.depth_strength = 0.25
        self.seed1 = 420
        self.seed2 = 420
        self.prompt1 = ""
        self.prompt2 = ""
        self.negative_prompt = ""
        self.fps = 30
        self.duration_video = 8
        self.t_compute_max_allowed = 10

        self.lb = LatentBlending(sdh)
        self.lb.sdh.num_inference_steps = self.num_inference_steps
        self.init_parameters_from_lb()
        self.init_save_dir()

        # Vars
        self.list_fp_imgs_current = []
        self.recycle_img1 = False
        self.recycle_img2 = False
        self.list_all_segments = []
        self.dp_session = ""
        self.user_id = None

    def init_parameters_from_lb(self):
        r"""
        Automatically init parameters from the LatentBlending instance.
        """
        self.height = self.lb.sdh.height
        self.width = self.lb.sdh.width
        self.guidance_scale = self.lb.guidance_scale
        self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
        self.mid_compression_scaler = self.lb.mid_compression_scaler
        self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
        self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
        self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
        self.parental_crossfeed_power = self.lb.parental_crossfeed_power
        self.parental_crossfeed_range = self.lb.parental_crossfeed_range
        self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay

    def init_save_dir(self):
        r"""
        Initializes the directory where results are saved.
        You can specify this directory in a ".env" file in your latentblending root, setting
        DIR_OUT='/path/to/saving'
        """
        load_dotenv(find_dotenv(), verbose=False)
        self.dp_out = os.getenv("DIR_OUT")
        if self.dp_out is None:
            self.dp_out = ""
        self.dp_imgs = os.path.join(self.dp_out, "imgs")
        os.makedirs(self.dp_imgs, exist_ok=True)
        self.dp_movies = os.path.join(self.dp_out, "movies")
        os.makedirs(self.dp_movies, exist_ok=True)
        self.save_empty_image()

    def save_empty_image(self):
        r"""
        Saves an empty/black dummy image.
        """
        self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
        Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)

    def randomize_seed1(self):
        r"""
        Randomizes the first seed.
        """
        seed = np.random.randint(0, 10000000)
        self.seed1 = int(seed)
        print(f"randomize_seed1: new seed = {self.seed1}")
        return seed

    def randomize_seed2(self):
        r"""
        Randomizes the second seed.
        """
        seed = np.random.randint(0, 10000000)
        self.seed2 = int(seed)
        print(f"randomize_seed2: new seed = {self.seed2}")
        return seed

    def setup_lb(self, list_ui_vals):
        r"""
        Sets all parameters from the UI. Since gradio does not support passing dictionaries,
        we instead pass keys (list_ui_keys, global) and values (list_ui_vals).
        """
        # Collect latent blending variables
        self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
        self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
        self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
        self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
        self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
        self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
        self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
        self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
        self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
        self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
        self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
        self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
        self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
        self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
        self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
        self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
        self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
        self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]

        if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
            self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
        else:
            # generate new user id
            self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
            print(f"made new user_id: {self.user_id} at {get_time('second')}")

    def save_latents(self, fp_latents, list_latents):
        r"""
        Saves a latent trajectory on disk, in npy format.
        """
        list_latents_cpu = [l.cpu().numpy() for l in list_latents]
        np.save(fp_latents, list_latents_cpu)

    def load_latents(self, fp_latents):
        r"""
        Loads a latent trajectory from disk, converts to torch tensors.
        """
        list_latents_cpu = np.load(fp_latents)
        list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
        return list_latents

    def compute_img1(self, *args):
        r"""
        Computes the first transition image and returns it for display.
        Sets all other transition images and the last image to empty (as they are obsolete with this operation).
        """
        list_ui_vals = args
        self.setup_lb(list_ui_vals)
        fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
        img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
        img1.save(fp_img1 + ".jpg")
        self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
        self.recycle_img1 = True
        self.recycle_img2 = False
        return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]

    def compute_img2(self, *args):
        r"""
        Computes the last transition image and returns it for display.
        Sets all other transition images to empty (as they are obsolete with this operation).
        """
        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
            return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
        list_ui_vals = args
        self.setup_lb(list_ui_vals)

        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
        img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
        img2.save(fp_img2 + '.jpg')
        self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
        self.recycle_img2 = True
        # fixme save seeds. change filenames?
        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]

    def compute_transition(self, *args):
        r"""
        Computes transition images and movie.
        """
        list_ui_vals = args
        self.setup_lb(list_ui_vals)
        print("STARTING TRANSITION...")
        fixed_seeds = [self.seed1, self.seed2]
        # Inject loaded latents (other user interference)
        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
        imgs_transition = self.lb.run_transition(
            recycle_img1=self.recycle_img1,
            recycle_img2=self.recycle_img2,
            num_inference_steps=self.num_inference_steps,
            depth_strength=self.depth_strength,
            t_compute_max_allowed=self.t_compute_max_allowed,
            fixed_seeds=fixed_seeds)
        print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")

        # Subselect three preview images
        idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)

        list_imgs_preview = []
        for j in idx_img_prev:
            list_imgs_preview.append(Image.fromarray(imgs_transition[j]))

        # Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
        current_timestamp = get_time('second')
        self.list_fp_imgs_current = []
        for i in range(len(list_imgs_preview)):
            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
            list_imgs_preview[i].save(fp_img)
            self.list_fp_imgs_current.append(fp_img)
        # Insert cheap frames for the movie
        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)

        # Save as movie
        self.fp_movie = self.get_fp_video_last()
        if os.path.isfile(self.fp_movie):
            os.remove(self.fp_movie)
        ms = MovieSaver(self.fp_movie, fps=self.fps)
        for img in tqdm(imgs_transition_ext):
            ms.write_frame(img)
        ms.finalize()
        print("DONE SAVING MOVIE! SENDING BACK...")

        # Assemble output, updating the preview images and the movie
        list_return = self.list_fp_imgs_current + [self.fp_movie]
        return list_return

    def stack_forward(self, prompt2, seed2):
        r"""
        Allows generating multi-segment movies. Sets last image -> first image with all
        relevant parameters.
        """
        # Save preview images, prompts and seeds into dictionary for stacking
        if len(self.list_all_segments) == 0:
            timestamp_session = get_time('second')
            self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
            os.makedirs(self.dp_session)

        idx_segment = len(self.list_all_segments)
        dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")

        self.list_all_segments.append(dp_segment)
        self.lb.write_imgs_transition(dp_segment)

        fp_movie_last = self.get_fp_video_last()
        fp_movie_next = self.get_fp_video_next()

        shutil.copyfile(fp_movie_last, fp_movie_next)

        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
        self.lb.swap_forward()

        shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        fp_multi = self.multi_concat()
        list_out = [fp_multi]

        list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
        list_out.extend([self.fp_img_empty] * 4)
        list_out.append(gr.update(interactive=False, value=prompt2))
        list_out.append(gr.update(interactive=False, value=seed2))
        list_out.append("")
        list_out.append(np.random.randint(0, 10000000))
        print(f"stack_forward: fp_multi {fp_multi}")
        return list_out

    def multi_concat(self):
        r"""
        Concatenates all stacked segments into one long movie.
        """
        list_fp_movies = self.get_fp_video_all()
        # Concatenate movies and save
        fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
        concatenate_movies(fp_final, list_fp_movies)
        return fp_final

    def get_fp_video_all(self):
        r"""
        Collects all stacked movie segments.
        """
        list_all = os.listdir(self.dp_movies)
        str_beg = f"movie_{self.user_id}_"
        list_user = [l for l in list_all if str_beg in l]
        list_user.sort()
        list_user = [os.path.join(self.dp_movies, l) for l in list_user]
        return list_user

    def get_fp_video_next(self):
        r"""
        Gets the filepath of the next movie segment.
        """
        list_videos = self.get_fp_video_all()
        if len(list_videos) == 0:
            idx_next = 0
        else:
            idx_next = len(list_videos)
        fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
        return fp_video_next

    def get_fp_video_last(self):
        r"""
        Gets the current video that was saved.
        """
        fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
        return fp_video_last


if __name__ == "__main__":
    fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
    # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
    bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
    # self = BlendingFrontend(None)

    with gr.Blocks() as demo:
        with gr.Row():
            prompt1 = gr.Textbox(label="prompt 1")
            prompt2 = gr.Textbox(label="prompt 2")

        with gr.Row():
            duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget', interactive=True)
            duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
            height = gr.Slider(256, 2048, bf.height, step=128, label='height', interactive=True)
            width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)

        with gr.Accordion("Advanced Settings (click to expand)", open=False):

            with gr.Accordion("Diffusion settings", open=True):
                with gr.Row():
                    num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
                    guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
                    negative_prompt = gr.Textbox(label="negative prompt")

            with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
                with gr.Row():
                    b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
                    seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
                    seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
                    b_newseed2 = gr.Button("randomize seed 2", variant='secondary')

            with gr.Accordion("Last image crossfeeding.", open=True):
                with gr.Row():
                    branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
                    branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
                    branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)

            with gr.Accordion("Transition settings", open=True):
                with gr.Row():
                    parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
                    parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
                    parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
                with gr.Row():
                    depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
                    guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)

        with gr.Row():
            b_compute1 = gr.Button('compute first image', variant='primary')
            b_compute_transition = gr.Button('compute transition', variant='primary')
            b_compute2 = gr.Button('compute last image', variant='primary')

        with gr.Row():
            img1 = gr.Image(label="1/5")
            img2 = gr.Image(label="2/5", show_progress=False)
            img3 = gr.Image(label="3/5", show_progress=False)
            img4 = gr.Image(label="4/5", show_progress=False)
            img5 = gr.Image(label="5/5")

        with gr.Row():
            vid_single = gr.Video(label="current single trans")
            vid_multi = gr.Video(label="concatenated multi trans")

        with gr.Row():
            b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')

        with gr.Row():
            gr.Markdown(
                """
                # Parameters
                ## Main
                - compute budget: set your waiting time for the transition. high values = better quality
                - video duration: seconds per segment
                - height/width: in pixels

                ## Diffusion settings
                - num_inference_steps: number of diffusion steps
                - guidance_scale: latent blending seems to prefer lower values here
                - negative prompt: enter negative prompt here, applied for all images

                ## Last image crossfeeding
                - branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
                - branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
                - branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.

                ## Transition settings
                - parental_crossfeed_power: Similar to branch1_crossfeed_power, however applied for the images within the transition.
                - parental_crossfeed_range: Similar to branch1_crossfeed_range, however applied for the images within the transition.
                - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied for the images within the transition.
                - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values are more inventive but can cause motion.
                - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
                """)

        with gr.Row():
            user_id = gr.Textbox(label="user id", interactive=False)

        # Collect all UI elements in a list to easily pass as inputs in gradio
        dict_ui_elem = {}
        dict_ui_elem["prompt1"] = prompt1
        dict_ui_elem["negative_prompt"] = negative_prompt
        dict_ui_elem["prompt2"] = prompt2

        dict_ui_elem["duration_compute"] = duration_compute
        dict_ui_elem["duration_video"] = duration_video
        dict_ui_elem["height"] = height
        dict_ui_elem["width"] = width

        dict_ui_elem["depth_strength"] = depth_strength
        dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
        dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
        dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay

        dict_ui_elem["num_inference_steps"] = num_inference_steps
        dict_ui_elem["guidance_scale"] = guidance_scale
        dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
        dict_ui_elem["seed1"] = seed1
        dict_ui_elem["seed2"] = seed2

        dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
        dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
        dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
        dict_ui_elem["user_id"] = user_id

        # Convert to list, as gradio doesn't seem to accept dicts
        list_ui_vals = []
        list_ui_keys = []
        for k in dict_ui_elem.keys():
            list_ui_vals.append(dict_ui_elem[k])
            list_ui_keys.append(k)
        bf.list_ui_keys = list_ui_keys

        b_newseed1.click(bf.randomize_seed1, outputs=seed1)
        b_newseed2.click(bf.randomize_seed2, outputs=seed2)
        b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
        b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
        b_compute_transition.click(bf.compute_transition,
                                   inputs=list_ui_vals,
                                   outputs=[img2, img3, img4, vid_single])

        b_stackforward.click(bf.stack_forward,
                             inputs=[prompt2, seed2],
                             outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])

    demo.launch(share=bf.share, inbrowser=True, inline=False)
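The setup_lb docstring in the file above notes that gradio cannot pass dictionaries to callbacks, so the UI keys live in a module-level list and the values arrive positionally. A minimal, self-contained sketch of that lookup pattern (the values here are made up for illustration):

```python
# Hypothetical values, illustrating the parallel key/value plumbing used in gradio_ui.py.
dict_ui_elem = {"prompt1": "a photo of a forest", "seed1": 420, "width": 512}

list_ui_keys = list(dict_ui_elem.keys())    # fixed ordering, shared as a module-level list
list_ui_vals = list(dict_ui_elem.values())  # what gradio actually hands to the callback

def setup_from_ui(list_ui_vals):
    # same lookup-by-name pattern as BlendingFrontend.setup_lb()
    width = list_ui_vals[list_ui_keys.index("width")]
    seed1 = list_ui_vals[list_ui_keys.index("seed1")]
    return width, seed1

print(setup_from_ui(list_ui_vals))  # (512, 420)
```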
latent_blending.py
CHANGED
@@ -13,48 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-dp_git = "/home/lugo/git/"
-sys.path.append('util')
-# sys.path.append('../stablediffusion/ldm')
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 import time
-import subprocess
 import warnings
-import torch
 from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
 from movie_util import MovieSaver
-import
-from typing import Callable, List, Optional, Union
-import inspect
-from threading import Thread
-torch.set_grad_enabled(False)
-from omegaconf import OmegaConf
-from torch import autocast
-from contextlib import nullcontext
-
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
-from stable_diffusion_holder import StableDiffusionHolder
-import yaml
 import lpips
-
 class LatentBlending():
     def __init__(
-            self,
             sdh: None,
             guidance_scale: float = 4,
             guidance_scale_mid_damper: float = 0.5,
-            mid_compression_scaler: float = 1.2
-            ):
         r"""
         Initializes the latent blending class.
         Args:
@@ -71,9 +54,10 @@
             Increases the sampling density in the middle (where most changes happen). Higher value
             imply more values in the middle. However the inflection point can occur outside the middle,
             thus high values can give rough transitions. Values around 2 should be fine.
-
         """
-        assert guidance_scale_mid_damper>
 
         self.sdh = sdh
         self.device = self.sdh.device
@@ -81,20 +65,20 @@
         self.height = self.sdh.height
         self.guidance_scale_mid_damper = guidance_scale_mid_damper
         self.mid_compression_scaler = mid_compression_scaler
-        self.seed1 = 0
         self.seed2 = 0
-
         # Initialize vars
         self.prompt1 = ""
         self.prompt2 = ""
         self.negative_prompt = ""
-
         self.tree_latents = [None, None]
         self.tree_fracts = None
         self.idx_injection = []
         self.tree_status = None
         self.tree_final_imgs = []
-
         self.list_nmb_branches_prev = []
         self.list_injection_idx_prev = []
         self.text_embedding1 = None
@@ -106,25 +90,23 @@
         self.noise_level_upscaling = 20
         self.list_injection_idx = None
         self.list_nmb_branches = None
-
         # Mixing parameters
         self.branch1_crossfeed_power = 0.1
         self.branch1_crossfeed_range = 0.6
         self.branch1_crossfeed_decay = 0.8
-
         self.parental_crossfeed_power = 0.1
         self.parental_crossfeed_range = 0.8
-        self.parental_crossfeed_power_decay = 0.8
-
         self.set_guidance_scale(guidance_scale)
         self.init_mode()
         self.multi_transition_img_first = None
        self.multi_transition_img_last = None
         self.dt_per_diff = 0
         self.spatial_mask = None
-
         self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
-
 
     def init_mode(self):
         r"""
@@ -138,7 +120,7 @@
             self.mode = 'inpaint'
         else:
             self.mode = 'standard'
-
     def set_guidance_scale(self, guidance_scale):
         r"""
         sets the guidance scale.
@@ -146,25 +128,24 @@
         self.guidance_scale_base = guidance_scale
         self.guidance_scale = guidance_scale
         self.sdh.guidance_scale = guidance_scale
-
     def set_negative_prompt(self, negative_prompt):
         r"""Set the negative prompt. Currenty only one negative prompt is supported
         """
         self.negative_prompt = negative_prompt
         self.sdh.set_negative_prompt(negative_prompt)
-
     def set_guidance_mid_dampening(self, fract_mixing):
         r"""
-        Tunes the guidance scale down as a linear function of fract_mixing,
         towards 0.5 the minimum will be reached.
         """
-        mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
-        max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
-        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
         self.guidance_scale = guidance_scale_effective
         self.sdh.guidance_scale = guidance_scale_effective
 
-
     def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for the first branch to the last branch.
@@ -179,14 +160,13 @@
         self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
         self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
-
-
     def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for all transition images (within the first and last branch).
         Args:
             crossfeed_power: float [0,1]
-                Controls the level of cross-feeding from the parental branches
             crossfeed_range: float [0,1]
                 Sets the duration of active crossfeed during development.
             crossfeed_decay: float [0,1]
@@ -196,7 +176,6 @@
         self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
 
-
     def set_prompt1(self, prompt: str):
         r"""
         Sets the first prompt (for the first keyframe) including text embeddings.
@@ -207,8 +186,7 @@
         prompt = prompt.replace("_", " ")
         self.prompt1 = prompt
         self.text_embedding1 = self.get_text_embeddings(self.prompt1)
-
-
     def set_prompt2(self, prompt: str):
         r"""
         Sets the second prompt (for the second keyframe) including text embeddings.
@@ -219,7 +197,7 @@
         prompt = prompt.replace("_", " ")
         self.prompt2 = prompt
         self.text_embedding2 = self.get_text_embeddings(self.prompt2)
-
     def set_image1(self, image: Image):
         r"""
         Sets the first image (keyframe), relevant for the upscaling model transitions.
@@ -227,7 +205,7 @@
             image: Image
         """
         self.image1_lowres = image
-
     def set_image2(self, image: Image):
         r"""
         Sets the second image (keyframe), relevant for the upscaling model transitions.
@@ -235,17 +213,16 @@
             image: Image
         """
         self.image2_lowres = image
-
     def run_transition(
             self,
-            recycle_img1: Optional[bool] = False,
-            recycle_img2: Optional[bool] = False,
             num_inference_steps: Optional[int] = 30,
             depth_strength: Optional[float] = 0.3,
             t_compute_max_allowed: Optional[float] = None,
             nmb_max_branches: Optional[int] = None,
-            fixed_seeds: Optional[List[int]] = None
-            ):
         r"""
         Function for computing transitions.
         Returns a list of transition images using spherical latent blending.
@@ -257,79 +234,77 @@
             num_inference_steps:
                 Number of diffusion steps. Higher values will take more compute time.
             depth_strength:
-                Determines how deep the first injection will happen.
                 Deeper injections will cause (unwanted) formation of new structures,
                 more shallow values will go into alpha-blendy land.
             t_compute_max_allowed:
-                Either provide t_compute_max_allowed or nmb_max_branches.
-                The maximum time allowed for computation. Higher values give better results but take longer.
             nmb_max_branches: int
                 Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
-                results. Use this if you want to have controllable results independent
                 of your computer.
             fixed_seeds: Optional[List[int)]:
                 You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                 Otherwise random seeds will be taken.
-
         """
-
         # Sanity checks first
         assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
         assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
-
         # Random seeds
         if fixed_seeds is not None:
             if fixed_seeds == 'randomize':
                 fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
             else:
-                assert len(fixed_seeds)==2, "Supply a list with len = 2"
-
             self.seed1 = fixed_seeds[0]
             self.seed2 = fixed_seeds[1]
-
         # Ensure correct num_inference_steps in holder
         self.num_inference_steps = num_inference_steps
         self.sdh.num_inference_steps = num_inference_steps
-
         # Compute / Recycle first image
         if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
             list_latents1 = self.compute_latents1()
         else:
             list_latents1 = self.tree_latents[0]
-
         # Compute / Recycle first image
         if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
             list_latents2 = self.compute_latents2()
         else:
             list_latents2 = self.tree_latents[-1]
-
         # Reset the tree, injecting the edge latents1/2 we just generated/recycled
-        self.tree_latents = [list_latents1, list_latents2]
         self.tree_fracts = [0.0, 1.0]
         self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
         self.tree_idx_injection = [0, 0]
-
         # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
         self.spatial_mask = None
-
         # Set up branching scheme (dependent on provided compute time)
         list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
 
-        # Run iteratively, starting with the longest trajectory.
         # Always inserting new branches where they are needed most according to image similarity
         for s_idx in tqdm(range(len(list_idx_injection))):
             nmb_stems = list_nmb_stems[s_idx]
             idx_injection = list_idx_injection[s_idx]
-
             for i in range(nmb_stems):
                 fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
                 self.set_guidance_mid_dampening(fract_mixing)
                 list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                 self.insert_into_tree(fract_mixing, idx_injection, list_latents)
                 # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
-
         return self.tree_final_imgs
-
 
     def compute_latents1(self, return_image=False):
         r"""
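For reference, a minimal sketch of how the transition API documented in the hunk above might be driven outside the gradio UI. The checkpoint download mirrors gradio_ui.py; the prompts, compute budget, and output filename are illustrative only, and the calls follow the signatures shown in this diff.

```python
# Hedged usage sketch: values are illustrative, not part of the commit.
from huggingface_hub import hf_hub_download
from stable_diffusion_holder import StableDiffusionHolder
from latent_blending import LatentBlending

fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base",
                          filename="v2-1_512-ema-pruned.ckpt")
lb = LatentBlending(StableDiffusionHolder(fp_ckpt))
lb.set_prompt1("photo of a forest at dawn")
lb.set_prompt2("photo of a forest in deep winter")

# Provide either t_compute_max_allowed or nmb_max_branches, not both.
imgs_transition = lb.run_transition(
    num_inference_steps=30,
    depth_strength=0.3,
    t_compute_max_allowed=20,      # seconds of compute budget
    fixed_seeds=[420, 420])

# Write the blended frames as a short movie (duration in seconds).
lb.write_movie_transition("movie_example.mp4", duration_transition=8, fps=30)
```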
@@ -343,18 +318,17 @@
         t0 = time.time()
         latents_start = self.get_noise(self.seed1)
         list_latents1 = self.run_diffusion(
-            list_conditionings,
-            latents_start
-            idx_start
-            )
         t1 = time.time()
-        self.dt_per_diff = (t1-t0) / self.num_inference_steps
         self.tree_latents[0] = list_latents1
         if return_image:
             return self.sdh.latent2image(list_latents1[-1])
         else:
             return list_latents1
-
     def compute_latents2(self, return_image=False):
         r"""
         Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
@@ -368,28 +342,26 @@
         # Influence from branch1
         if self.branch1_crossfeed_power > 0.0:
             # Set up the mixing_coeffs
-            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
-            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
-            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
             list_latents_mixing = self.tree_latents[0]
             list_latents2 = self.run_diffusion(
-                list_conditionings,
-                latents_start
-                idx_start
-                list_latents_mixing
-                mixing_coeffs
-                )
         else:
             list_latents2 = self.run_diffusion(list_conditionings, latents_start)
         self.tree_latents[-1] = list_latents2
-
         if return_image:
             return self.sdh.latent2image(list_latents2[-1])
         else:
-            return list_latents2
 
-
-    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
         r"""
         Runs a diffusion trajectory, using the latents from the respective parents
         Args:
@@ -403,9 +375,9 @@
                 the index in terms of diffusion steps, where the next insertion will start.
         """
         list_conditionings = self.get_mixed_conditioning(fract_mixing)
-        fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
         # idx_reversed = self.num_inference_steps - idx_injection
-
         list_latents_parental_mix = []
         for i in range(self.num_inference_steps):
             latents_p1 = self.tree_latents[b_parent1][i]
@@ -416,22 +388,19 @@
             latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
             list_latents_parental_mix.append(latents_parental)
 
-        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
-        mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
         nmb_mixing = idx_mixing_stop - idx_injection
         if nmb_mixing > 0:
-            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
-        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
-
-        latents_start = list_latents_parental_mix[idx_injection-1]
         list_latents = self.run_diffusion(
-            list_conditionings,
-            latents_start
-            idx_start
-            list_latents_mixing
-            mixing_coeffs
-            )
-
         return list_latents
 
     def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
@@ -441,48 +410,46 @@
         Either provide t_compute_max_allowed or nmb_max_branches
         Args:
             depth_strength:
-                Determines how deep the first injection will happen.
                 Deeper injections will cause (unwanted) formation of new structures,
                 more shallow values will go into alpha-blendy land.
             t_compute_max_allowed: float
                 The maximum time allowed for computation. Higher values give better results
-                but take longer. Use this if you want to fix your waiting time for the results.
             nmb_max_branches: int
                 The maximum number of branches to be computed. Higher values give better
-                results. Use this if you want to have controllable results independent
                 of your computer.
         """
-        idx_injection_base = int(round(self.num_inference_steps*depth_strength))
-        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
         list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
         t_compute = 0
-
         if nmb_max_branches is None:
             assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
             stop_criterion = "t_compute_max_allowed"
         elif t_compute_max_allowed is None:
             assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
             stop_criterion = "nmb_max_branches"
-            nmb_max_branches -= 2
         else:
             raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
-
         stop_criterion_reached = False
         is_first_iteration = True
-
         while not stop_criterion_reached:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff
             increase_done = False
-            for s_idx in range(len(list_nmb_stems)-1):
-                if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
                     list_nmb_stems[s_idx] += 1
                     increase_done = True
                     break
             if not increase_done:
                 list_nmb_stems[-1] += 1
-
             if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
                 stop_criterion_reached = True
             elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
@@ -493,7 +460,7 @@
                 list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
             else:
                 is_first_iteration = False
-
         # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
         return list_idx_injection, list_nmb_stems
 
@@ -508,13 +475,13 @@
         """
         # get_lpips_similarity
         similarities = []
-        for i in range(len(self.tree_final_imgs)-1):
-            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
         b_closest1 = np.argmax(similarities)
-        b_closest2 = b_closest1+1
         fract_closest1 = self.tree_fracts[b_closest1]
         fract_closest2 = self.tree_fracts[b_closest2]
-
         # Ensure that the parents are indeed older!
         b_parent1 = b_closest1
         while True:
@@ -522,23 +489,15 @@
                 break
             else:
                 b_parent1 -= 1
-
         b_parent2 = b_closest2
         while True:
             if self.tree_idx_injection[b_parent2] < idx_injection:
                 break
             else:
                 b_parent2 += 1
-
-        # print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
-        # print(f"b_parent: {b_parent1} {b_parent2}")
-        # print(f"similarities {similarities}")
-        # print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
-
-        fract_mixing = (fract_closest1 + fract_closest2) /2
         return fract_mixing, b_parent1, b_parent2
-
-
     def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
         r"""
         Inserts all necessary parameters into the trajectory tree.
@@ -550,31 +509,28 @@
             list_latents: list
                 list of the latents to be inserted
         """
-        b_parent1, b_parent2 = get_closest_idx(fract_mixing
-        self.tree_latents.insert(b_parent1+1, list_latents)
-        self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
-        self.tree_fracts.insert(b_parent1+1, fract_mixing)
-        self.tree_idx_injection.insert(b_parent1+1, idx_injection)
-
-
-    def get_spatial_mask_template(self):
         r"""
-        Experimental helper function to get a spatial mask template.
         """
         shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
         C, H, W = shape_latents
         return np.ones((H, W))
-
     def set_spatial_mask(self, img_mask):
         r"""
-        Experimental helper function to set a spatial mask.
         The mask forces latents to be overwritten.
         Args:
-            img_mask:
                 mask image [0,1]. You can get a template using get_spatial_mask_template
-
         """
-
         shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
         C, H, W = shape_latents
         img_mask = np.asarray(img_mask)
@@ -584,18 +540,15 @@
         assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
         spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
         spatial_mask = torch.unsqueeze(spatial_mask, 0)
-        spatial_mask = spatial_mask.repeat((C,1,1))
         spatial_mask = torch.unsqueeze(spatial_mask, 0)
-
         self.spatial_mask = spatial_mask
-
-
     def get_noise(self, seed):
         r"""
         Helper function to get noise given seed.
         Args:
             seed: int
-
         """
         generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
         if self.mode == 'standard':
@@ -606,87 +559,81 @@
             h = self.image1_lowres.size[1]
             shape_latents = [self.sdh.model.channels, h, w]
             C, H, W = shape_latents
-
         return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
 
-
     @torch.no_grad()
     def run_diffusion(
-            self,
-            list_conditionings,
-            latents_start: torch.FloatTensor = None,
-            idx_start: int = 0,
-            list_latents_mixing
-            mixing_coeffs
-            return_image: Optional[bool] = False
-            ):
-
         r"""
         Wrapper function for diffusion runners.
         Depending on the mode, the correct one will be executed.
-
         Args:
             list_conditionings: list
                 List of all conditionings for the diffusion model.
-            latents_start: torch.FloatTensor
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            list_latents_mixing: torch.FloatTensor
                 List of latents (latent trajectories) that are used for mixing
             mixing_coeffs: float or list
                 Coefficients, how strong each element of list_latents_mixing will be mixed in.
            return_image: Optional[bool]
                 Optionally return image directly
         """
-
         # Ensure correct num_inference_steps in Holder
         self.sdh.num_inference_steps = self.num_inference_steps
         assert type(list_conditionings) is list, "list_conditionings need to be a list"
-
         if self.mode == 'standard':
             text_embeddings = list_conditionings[0]
             return self.sdh.run_diffusion_standard(
-                text_embeddings
-                latents_start
-                idx_start
-                list_latents_mixing
-                mixing_coeffs
-                spatial_mask
-                return_image
-
-
         elif self.mode == 'upscale':
             cond = list_conditionings[0]
             uc_full = list_conditionings[1]
             return self.sdh.run_diffusion_upscaling(
-                cond,
-                uc_full,
-                latents_start=latents_start,
-                idx_start=idx_start,
-                list_latents_mixing
-                mixing_coeffs
                 return_image=return_image)
 
-
     def run_upscaling(
-            self,
             dp_img: str,
             depth_strength: float = 0.65,
             num_inference_steps: int = 100,
             nmb_max_branches_highres: int = 5,
             nmb_max_branches_lowres: int = 6,
-            duration_single_segment
-
-            ):
         r"""
         Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
-
         Args:
             dp_img: str
                 Path to the low-res transition path (as saved in write_imgs_transition)
             depth_strength:
-                Determines how deep the first injection will happen.
                 Deeper injections will cause (unwanted) formation of new structures,
                 more shallow values will go into alpha-blendy land.
             num_inference_steps:
@@ -699,68 +646,59 @@
                 Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
             duration_single_segment: float
                 The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
             fixed_seeds: Optional[List[int)]:
                 You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                 Otherwise random seeds will be taken.
         """
         fp_yml = os.path.join(dp_img, "lowres.yaml")
         fp_movie = os.path.join(dp_img, "movie_highres.mp4")
-        fps = 24
         ms = MovieSaver(fp_movie, fps=fps)
         assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
         dict_stuff = yml_load(fp_yml)
-
         # load lowres images
         nmb_images_lowres = dict_stuff['nmb_images']
         prompt1 = dict_stuff['prompt1']
         prompt2 = dict_stuff['prompt2']
-        idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
         imgs_lowres = []
         for i in idx_img_lowres:
             fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
             assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
             imgs_lowres.append(Image.open(fp_img_lowres))
-
 
         # set up upscaling
         text_embeddingA = self.sdh.get_text_embedding(prompt1)
         text_embeddingB = self.sdh.get_text_embedding(prompt2)
-
-
-
-        for i in range(nmb_max_branches_lowres-1):
             print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
-
             self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
-            self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
-
-
-                recycle_img1 = False
             else:
                 self.swap_forward()
-                recycle_img1 = True
-
             self.set_image1(imgs_lowres[i])
-            self.set_image2(imgs_lowres[i+1])
-
             list_imgs = self.run_transition(
-                recycle_img1
-                recycle_img2
-                num_inference_steps
-                depth_strength
-                nmb_max_branches
-                )
-
             list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
-
             # Save movie frame
             for img in list_imgs_interp:
                 ms.write_frame(img)
-
         ms.finalize()
-
 
-
     @torch.no_grad()
     def get_mixed_conditioning(self, fract_mixing):
         if self.mode == 'standard':
@@ -782,9 +720,8 @@
 
     @torch.no_grad()
     def get_text_embeddings(
-            self,
-            prompt: str
-            ):
         r"""
         Computes the text embeddings provided a string with a prompts.
         Adapted from stable diffusion repo
@@ -792,9 +729,7 @@
             prompt: str
                 ABC trending on artstation painted by Old Greg.
         """
-
         return self.sdh.get_text_embedding(prompt)
-
 
     def write_imgs_transition(self, dp_img):
         r"""
@@ -809,10 +744,9 @@
         for i, img in enumerate(imgs_transition):
             img_leaf = Image.fromarray(img)
             img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
-
-        fp_yml = os.path.join(dp_img, "lowres.yaml")
         self.save_statedict(fp_yml)
-
     def write_movie_transition(self, fp_movie, duration_transition, fps=30):
         r"""
         Writes the transition movie to fp_movie, using the given duration and fps..
@@ -824,9 +758,8 @@
                 duration of the movie in seonds
             fps: int
                 fps of the movie
-
         """
-
         # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
         imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
 
@@ -838,15 +771,13 @@
             ms.write_frame(img)
         ms.finalize()
 
-
-
     def save_statedict(self, fp_yml):
         # Dump everything relevant into yaml
         imgs_transition = self.tree_final_imgs
         state_dict = self.get_state_dict()
         state_dict['nmb_images'] = len(imgs_transition)
         yml_save(fp_yml, state_dict)
-
     def get_state_dict(self):
         state_dict = {}
         grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
@@ -860,391 +791,94 @@
                 state_dict[v] = int(getattr(self, v))
             elif v == 'guidance_scale':
                 state_dict[v] = float(getattr(self, v))
-
             else:
                 try:
                     state_dict[v] = getattr(self, v)
-                except Exception
                     pass
-
         return state_dict
-
     def randomize_seed(self):
         r"""
         Set a random seed for a fresh start.
-        """
        seed = np.random.randint(999999999)
         self.set_seed(seed)
-
     def set_seed(self, seed: int):
         r"""
         Set a the seed for a fresh start.
-        """
         self.seed = seed
         self.sdh.seed = seed
-
     def set_width(self, width):
         r"""
         Set the width of the resulting image.
-        """
         assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
         self.width = width
         self.sdh.width = width
-
     def set_height(self, height):
         r"""
         Set the height of the resulting image.
-        """
         assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
         self.height = height
         self.sdh.height = height
-
 
     def swap_forward(self):
         r"""
         Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
         as in run_multi_transition()
-        """
         # Move over all latents
         self.tree_latents[0] = self.tree_latents[-1]
-
         # Move over prompts and text embeddings
         self.prompt1 = self.prompt2
         self.text_embedding1 = self.text_embedding2
-
         # Final cleanup for extra sanity
-        self.tree_final_imgs = []
-
-
     def get_lpips_similarity(self, imgA, imgB):
         r"""
-        Computes the image similarity between two images imgA and imgB.
         Used to determine the optimal point of insertion to create smooth transitions.
         High values indicate low similarity.
-        """
         tensorA = torch.from_numpy(imgA).float().cuda(self.device)
-        tensorA = 2*tensorA/255.0 - 1
-        tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
-
         tensorB = torch.from_numpy(imgB).float().cuda(self.device)
-        tensorB = 2*tensorB/255.0 - 1
-        tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
         lploss = self.lpips(tensorA, tensorB)
         lploss = float(lploss[0][0][0][0])
-
         return lploss
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    b_parent1
-
-    return b_parent1, b_parent2
-
-@torch.no_grad()
-def interpolate_spherical(p0, p1, fract_mixing: float):
-    r"""
-    Helper function to correctly mix two random variables using spherical interpolation.
-    See https://en.wikipedia.org/wiki/Slerp
-    The function will always cast up to float64 for sake of extra 4.
-    Args:
-        p0:
-            First tensor for interpolation
-        p1:
-            Second tensor for interpolation
-        fract_mixing: float
-            Mixing coefficient of interval [0, 1].
-            0 will return in p0
-            1 will return in p1
-            0.x will return a mix between both preserving angular velocity.
-    """
-
-    if p0.dtype == torch.float16:
-        recast_to = 'fp16'
-    else:
-        recast_to = 'fp32'
-
-    p0 = p0.double()
-    p1 = p1.double()
-    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
-    epsilon = 1e-7
-    dot = torch.sum(p0 * p1) / norm
-    dot = dot.clamp(-1+epsilon, 1-epsilon)
-
-    theta_0 = torch.arccos(dot)
-    sin_theta_0 = torch.sin(theta_0)
-    theta_t = theta_0 * fract_mixing
-    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
998 |
-
s1 = torch.sin(theta_t) / sin_theta_0
|
999 |
-
interp = p0*s0 + p1*s1
|
1000 |
-
|
1001 |
-
if recast_to == 'fp16':
|
1002 |
-
interp = interp.half()
|
1003 |
-
elif recast_to == 'fp32':
|
1004 |
-
interp = interp.float()
|
1005 |
-
|
1006 |
-
return interp
|
1007 |
-
|
1008 |
-
|
1009 |
-
def interpolate_linear(p0, p1, fract_mixing):
|
1010 |
-
r"""
|
1011 |
-
Helper function to mix two variables using standard linear interpolation.
|
1012 |
-
Args:
|
1013 |
-
p0:
|
1014 |
-
First tensor / np.ndarray for interpolation
|
1015 |
-
p1:
|
1016 |
-
Second tensor / np.ndarray for interpolation
|
1017 |
-
fract_mixing: float
|
1018 |
-
Mixing coefficient of interval [0, 1].
|
1019 |
-
0 will return in p0
|
1020 |
-
1 will return in p1
|
1021 |
-
0.x will return a linear mix between both.
|
1022 |
-
"""
|
1023 |
-
reconvert_uint8 = False
|
1024 |
-
if type(p0) is np.ndarray and p0.dtype == 'uint8':
|
1025 |
-
reconvert_uint8 = True
|
1026 |
-
p0 = p0.astype(np.float64)
|
1027 |
-
|
1028 |
-
if type(p1) is np.ndarray and p1.dtype == 'uint8':
|
1029 |
-
reconvert_uint8 = True
|
1030 |
-
p1 = p1.astype(np.float64)
|
1031 |
-
|
1032 |
-
interp = (1-fract_mixing) * p0 + fract_mixing * p1
|
1033 |
-
|
1034 |
-
if reconvert_uint8:
|
1035 |
-
interp = np.clip(interp, 0, 255).astype(np.uint8)
|
1036 |
-
|
1037 |
-
return interp
|
1038 |
-
|
1039 |
-
|
1040 |
-
def add_frames_linear_interp(
|
1041 |
-
list_imgs: List[np.ndarray],
|
1042 |
-
fps_target: Union[float, int] = None,
|
1043 |
-
duration_target: Union[float, int] = None,
|
1044 |
-
nmb_frames_target: int=None,
|
1045 |
-
):
|
1046 |
-
r"""
|
1047 |
-
Helper function to cheaply increase the number of frames given a list of images,
|
1048 |
-
by virtue of standard linear interpolation.
|
1049 |
-
The number of inserted frames will be automatically adjusted so that the total of number
|
1050 |
-
of frames can be fixed precisely, using a random shuffling technique.
|
1051 |
-
The function allows 1:1 comparisons between transitions as videos.
|
1052 |
-
|
1053 |
-
Args:
|
1054 |
-
list_imgs: List[np.ndarray)
|
1055 |
-
List of images, between each image new frames will be inserted via linear interpolation.
|
1056 |
-
fps_target:
|
1057 |
-
OptionA: specify here the desired frames per second.
|
1058 |
-
duration_target:
|
1059 |
-
OptionA: specify here the desired duration of the transition in seconds.
|
1060 |
-
nmb_frames_target:
|
1061 |
-
OptionB: directly fix the total number of frames of the output.
|
1062 |
-
"""
|
1063 |
-
|
1064 |
-
# Sanity
|
1065 |
-
if nmb_frames_target is not None and fps_target is not None:
|
1066 |
-
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
|
1067 |
-
if fps_target is None:
|
1068 |
-
assert nmb_frames_target is not None, "Either specify nmb_frames_target or nmb_frames_target"
|
1069 |
-
if nmb_frames_target is None:
|
1070 |
-
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
1071 |
-
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
1072 |
-
nmb_frames_target = fps_target*duration_target
|
1073 |
-
|
1074 |
-
# Get number of frames that are missing
|
1075 |
-
nmb_frames_diff = len(list_imgs)-1
|
1076 |
-
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
|
1077 |
-
|
1078 |
-
if nmb_frames_missing < 1:
|
1079 |
-
return list_imgs
|
1080 |
-
|
1081 |
-
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
|
1082 |
-
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
|
1083 |
-
mean_nmb_frames_insert = nmb_frames_missing/nmb_frames_diff
|
1084 |
-
constfact = np.floor(mean_nmb_frames_insert)
|
1085 |
-
remainder_x = 1-(mean_nmb_frames_insert - constfact)
|
1086 |
-
|
1087 |
-
nmb_iter = 0
|
1088 |
-
while True:
|
1089 |
-
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
|
1090 |
-
nmb_frames_to_insert[nmb_frames_to_insert<=remainder_x] = 0
|
1091 |
-
nmb_frames_to_insert[nmb_frames_to_insert>remainder_x] = 1
|
1092 |
-
nmb_frames_to_insert += constfact
|
1093 |
-
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
|
1094 |
-
break
|
1095 |
-
nmb_iter += 1
|
1096 |
-
if nmb_iter > 100000:
|
1097 |
-
print("add_frames_linear_interp: issue with inserting the right number of frames")
|
1098 |
-
break
|
1099 |
-
|
1100 |
-
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
1101 |
-
list_imgs_interp = []
|
1102 |
-
for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
|
1103 |
-
img0 = list_imgs_float[i]
|
1104 |
-
img1 = list_imgs_float[i+1]
|
1105 |
-
list_imgs_interp.append(img0.astype(np.uint8))
|
1106 |
-
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i]+2)[1:-1]
|
1107 |
-
for fract_linblend in list_fracts_linblend:
|
1108 |
-
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
|
1109 |
-
list_imgs_interp.append(img_blend.astype(np.uint8))
|
1110 |
-
|
1111 |
-
if i==len(list_imgs_float)-2:
|
1112 |
-
list_imgs_interp.append(img1.astype(np.uint8))
|
1113 |
-
|
1114 |
-
return list_imgs_interp
|
1115 |
-
|
1116 |
-
|
1117 |
-
def get_spacing(nmb_points: int, scaling: float):
|
1118 |
-
"""
|
1119 |
-
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
|
1120 |
-
Args:
|
1121 |
-
nmb_points: int
|
1122 |
-
Number of points between [0, 1]
|
1123 |
-
scaling: float
|
1124 |
-
Higher values will return higher sampling density around 0.5
|
1125 |
-
|
1126 |
-
"""
|
1127 |
-
if scaling < 1.7:
|
1128 |
-
return np.linspace(0, 1, nmb_points)
|
1129 |
-
nmb_points_per_side = nmb_points//2 + 1
|
1130 |
-
if np.mod(nmb_points, 2) != 0: # uneven case
|
1131 |
-
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
|
1132 |
-
right_side = 1-left_side[::-1][1:]
|
1133 |
-
else:
|
1134 |
-
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
|
1135 |
-
right_side = 1-left_side[::-1]
|
1136 |
-
all_fracts = np.hstack([left_side, right_side])
|
1137 |
-
return all_fracts
|
1138 |
-
|
1139 |
-
|
1140 |
-
def get_time(resolution=None):
|
1141 |
-
"""
|
1142 |
-
Helper function returning an nicely formatted time string, e.g. 221117_1620
|
1143 |
-
"""
|
1144 |
-
if resolution==None:
|
1145 |
-
resolution="second"
|
1146 |
-
if resolution == "day":
|
1147 |
-
t = time.strftime('%y%m%d', time.localtime())
|
1148 |
-
elif resolution == "minute":
|
1149 |
-
t = time.strftime('%y%m%d_%H%M', time.localtime())
|
1150 |
-
elif resolution == "second":
|
1151 |
-
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
1152 |
-
elif resolution == "millisecond":
|
1153 |
-
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
1154 |
-
t += "_"
|
1155 |
-
t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f'))/1000)))
|
1156 |
-
else:
|
1157 |
-
raise ValueError("bad resolution provided: %s" %resolution)
|
1158 |
-
return t
|
1159 |
-
|
1160 |
-
def compare_dicts(a, b):
|
1161 |
-
"""
|
1162 |
-
Compares two dictionaries a and b and returns a dictionary c, with all
|
1163 |
-
keys,values that have shared keys in a and b but same values in a and b.
|
1164 |
-
The values of a and b are stacked together in the output.
|
1165 |
-
Example:
|
1166 |
-
a = {}; a['bobo'] = 4
|
1167 |
-
b = {}; b['bobo'] = 5
|
1168 |
-
c = dict_compare(a,b)
|
1169 |
-
c = {"bobo",[4,5]}
|
1170 |
-
"""
|
1171 |
-
c = {}
|
1172 |
-
for key in a.keys():
|
1173 |
-
if key in b.keys():
|
1174 |
-
val_a = a[key]
|
1175 |
-
val_b = b[key]
|
1176 |
-
if val_a != val_b:
|
1177 |
-
c[key] = [val_a, val_b]
|
1178 |
-
return c
|
1179 |
-
|
1180 |
-
def yml_load(fp_yml, print_fields=False):
|
1181 |
-
"""
|
1182 |
-
Helper function for loading yaml files
|
1183 |
-
"""
|
1184 |
-
with open(fp_yml) as f:
|
1185 |
-
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
|
1186 |
-
dict_data = dict(data)
|
1187 |
-
print("load: loaded {}".format(fp_yml))
|
1188 |
-
return dict_data
|
1189 |
-
|
1190 |
-
def yml_save(fp_yml, dict_stuff):
|
1191 |
-
"""
|
1192 |
-
Helper function for saving yaml files
|
1193 |
-
"""
|
1194 |
-
with open(fp_yml, 'w') as f:
|
1195 |
-
data = yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
|
1196 |
-
print("yml_save: saved {}".format(fp_yml))
|
1197 |
-
|
1198 |
-
|
1199 |
-
#%% le main
|
1200 |
-
if __name__ == "__main__":
|
1201 |
-
# xxxx
|
1202 |
-
|
1203 |
-
#%% First let us spawn a stable diffusion holder
|
1204 |
-
device = "cuda"
|
1205 |
-
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
1206 |
-
|
1207 |
-
sdh = StableDiffusionHolder(fp_ckpt)
|
1208 |
-
|
1209 |
-
xxx
|
1210 |
-
|
1211 |
-
|
1212 |
-
#%% Next let's set up all parameters
|
1213 |
-
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
|
1214 |
-
fixed_seeds = [697164, 430214]
|
1215 |
-
|
1216 |
-
prompt1 = "photo of a desert and a sky"
|
1217 |
-
prompt2 = "photo of a tree with a lake"
|
1218 |
-
|
1219 |
-
duration_transition = 12 # In seconds
|
1220 |
-
fps = 30
|
1221 |
-
|
1222 |
-
# Spawn latent blending
|
1223 |
-
self = LatentBlending(sdh)
|
1224 |
-
|
1225 |
-
self.set_prompt1(prompt1)
|
1226 |
-
self.set_prompt2(prompt2)
|
1227 |
-
|
1228 |
-
# Run latent blending
|
1229 |
-
self.branch1_crossfeed_power = 0.3
|
1230 |
-
self.branch1_crossfeed_range = 0.4
|
1231 |
-
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
|
1232 |
-
self.seed1=21312
|
1233 |
-
img1 =self.compute_latents1(True)
|
1234 |
-
#%
|
1235 |
-
self.seed2=1234121
|
1236 |
-
self.branch1_crossfeed_power = 0.7
|
1237 |
-
self.branch1_crossfeed_range = 0.3
|
1238 |
-
self.branch1_crossfeed_decay = 0.3
|
1239 |
-
img2 =self.compute_latents2(True)
|
1240 |
-
# Image.fromarray(np.concatenate((img1, img2), axis=1))
|
1241 |
-
|
1242 |
-
#%%
|
1243 |
-
t0 = time.time()
|
1244 |
-
self.t_compute_max_allowed = 30
|
1245 |
-
self.parental_crossfeed_range = 1.0
|
1246 |
-
self.parental_crossfeed_power = 0.0
|
1247 |
-
self.parental_crossfeed_power_decay = 1.0
|
1248 |
-
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
|
1249 |
-
t1 = time.time()
|
1250 |
-
print(f"took: {t1-t0}s")
|
|
|
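The interpolate_spherical helper removed above (the commit moves it to utils.py) is standard slerp: with theta_0 = arccos(<p0, p1> / (||p0|| ||p1||)) and theta_t = fract_mixing * theta_0, the mix is

    interp = sin(theta_0 - theta_t) / sin(theta_0) * p0 + sin(theta_t) / sin(theta_0) * p1

so fract_mixing = 0 returns p0, fract_mixing = 1 returns p1, and intermediate values preserve the angular velocity between the two latent tensors.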
latent_blending.py (updated file, added lines marked with +):

# See the License for the specific language governing permissions and
# limitations under the License.

+ import os
import torch
torch.backends.cudnn.benchmark = False
+ torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import warnings
from tqdm.auto import tqdm
from PIL import Image
from movie_util import MovieSaver
+ from typing import List, Optional
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
import lpips
+ from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
+
+
class LatentBlending():
    def __init__(
+           self,
            sdh: None,
            guidance_scale: float = 4,
            guidance_scale_mid_damper: float = 0.5,
+           mid_compression_scaler: float = 1.2):
        r"""
        Initializes the latent blending class.
        Args:
...
            Increases the sampling density in the middle (where most changes happen). Higher value
            imply more values in the middle. However the inflection point can occur outside the middle,
            thus high values can give rough transitions. Values around 2 should be fine.
        """
+       assert guidance_scale_mid_damper > 0 \
+           and guidance_scale_mid_damper <= 1.0, \
+           f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"

        self.sdh = sdh
        self.device = self.sdh.device
...
        self.height = self.sdh.height
        self.guidance_scale_mid_damper = guidance_scale_mid_damper
        self.mid_compression_scaler = mid_compression_scaler
+       self.seed1 = 0
        self.seed2 = 0
+
        # Initialize vars
        self.prompt1 = ""
        self.prompt2 = ""
        self.negative_prompt = ""
+
        self.tree_latents = [None, None]
        self.tree_fracts = None
        self.idx_injection = []
        self.tree_status = None
        self.tree_final_imgs = []
+
        self.list_nmb_branches_prev = []
        self.list_injection_idx_prev = []
        self.text_embedding1 = None
...
        self.noise_level_upscaling = 20
        self.list_injection_idx = None
        self.list_nmb_branches = None
+
        # Mixing parameters
        self.branch1_crossfeed_power = 0.1
        self.branch1_crossfeed_range = 0.6
        self.branch1_crossfeed_decay = 0.8
+
        self.parental_crossfeed_power = 0.1
        self.parental_crossfeed_range = 0.8
+       self.parental_crossfeed_power_decay = 0.8
+
        self.set_guidance_scale(guidance_scale)
        self.init_mode()
        self.multi_transition_img_first = None
        self.multi_transition_img_last = None
        self.dt_per_diff = 0
        self.spatial_mask = None
        self.lpips = lpips.LPIPS(net='alex').cuda(self.device)

    def init_mode(self):
        r"""
...
            self.mode = 'inpaint'
        else:
            self.mode = 'standard'
+
    def set_guidance_scale(self, guidance_scale):
        r"""
        sets the guidance scale.
...
        self.guidance_scale_base = guidance_scale
        self.guidance_scale = guidance_scale
        self.sdh.guidance_scale = guidance_scale
+
    def set_negative_prompt(self, negative_prompt):
        r"""Set the negative prompt. Currenty only one negative prompt is supported
        """
        self.negative_prompt = negative_prompt
        self.sdh.set_negative_prompt(negative_prompt)
+
    def set_guidance_mid_dampening(self, fract_mixing):
        r"""
+       Tunes the guidance scale down as a linear function of fract_mixing,
        towards 0.5 the minimum will be reached.
        """
+       mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
+       max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
+       guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
        self.guidance_scale = guidance_scale_effective
        self.sdh.guidance_scale = guidance_scale_effective

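Written out, set_guidance_mid_dampening applies

    g_eff(f) = g_base - (g_base * (1 - guidance_scale_mid_damper) - 1) * (1 - |f - 0.5| / 0.5)

so the full guidance scale g_base is used at f = 0 and f = 1, and the reduction is strongest at f = 0.5, where it bottoms out at guidance_scale_mid_damper * g_base + 1.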
    def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for the first branch to the last branch.
...
        self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
        self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
+
    def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for all transition images (within the first and last branch).
        Args:
            crossfeed_power: float [0,1]
+               Controls the level of cross-feeding from the parental branches
            crossfeed_range: float [0,1]
                Sets the duration of active crossfeed during development.
            crossfeed_decay: float [0,1]
...
        self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)

    def set_prompt1(self, prompt: str):
        r"""
        Sets the first prompt (for the first keyframe) including text embeddings.
...
        prompt = prompt.replace("_", " ")
        self.prompt1 = prompt
        self.text_embedding1 = self.get_text_embeddings(self.prompt1)
+
    def set_prompt2(self, prompt: str):
        r"""
        Sets the second prompt (for the second keyframe) including text embeddings.
...
        prompt = prompt.replace("_", " ")
        self.prompt2 = prompt
        self.text_embedding2 = self.get_text_embeddings(self.prompt2)
+
    def set_image1(self, image: Image):
        r"""
        Sets the first image (keyframe), relevant for the upscaling model transitions.
...
            image: Image
        """
        self.image1_lowres = image
+
    def set_image2(self, image: Image):
        r"""
        Sets the second image (keyframe), relevant for the upscaling model transitions.
...
            image: Image
        """
        self.image2_lowres = image
+
    def run_transition(
            self,
+           recycle_img1: Optional[bool] = False,
+           recycle_img2: Optional[bool] = False,
            num_inference_steps: Optional[int] = 30,
            depth_strength: Optional[float] = 0.3,
            t_compute_max_allowed: Optional[float] = None,
            nmb_max_branches: Optional[int] = None,
+           fixed_seeds: Optional[List[int]] = None):
        r"""
        Function for computing transitions.
        Returns a list of transition images using spherical latent blending.
...
            num_inference_steps:
                Number of diffusion steps. Higher values will take more compute time.
            depth_strength:
+               Determines how deep the first injection will happen.
                Deeper injections will cause (unwanted) formation of new structures,
                more shallow values will go into alpha-blendy land.
            t_compute_max_allowed:
+               Either provide t_compute_max_allowed or nmb_max_branches.
+               The maximum time allowed for computation. Higher values give better results but take longer.
            nmb_max_branches: int
                Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
+               results. Use this if you want to have controllable results independent
                of your computer.
            fixed_seeds: Optional[List[int)]:
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
        """
+
        # Sanity checks first
        assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
        assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
+
        # Random seeds
        if fixed_seeds is not None:
            if fixed_seeds == 'randomize':
                fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
            else:
+               assert len(fixed_seeds) == 2, "Supply a list with len = 2"
+
            self.seed1 = fixed_seeds[0]
            self.seed2 = fixed_seeds[1]
+
        # Ensure correct num_inference_steps in holder
        self.num_inference_steps = num_inference_steps
        self.sdh.num_inference_steps = num_inference_steps
+
        # Compute / Recycle first image
        if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
            list_latents1 = self.compute_latents1()
        else:
            list_latents1 = self.tree_latents[0]
+
        # Compute / Recycle first image
        if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
            list_latents2 = self.compute_latents2()
        else:
            list_latents2 = self.tree_latents[-1]
+
        # Reset the tree, injecting the edge latents1/2 we just generated/recycled
+       self.tree_latents = [list_latents1, list_latents2]
        self.tree_fracts = [0.0, 1.0]
        self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
        self.tree_idx_injection = [0, 0]
+
        # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
        self.spatial_mask = None
+
        # Set up branching scheme (dependent on provided compute time)
        list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)

+       # Run iteratively, starting with the longest trajectory.
        # Always inserting new branches where they are needed most according to image similarity
        for s_idx in tqdm(range(len(list_idx_injection))):
            nmb_stems = list_nmb_stems[s_idx]
            idx_injection = list_idx_injection[s_idx]
+
            for i in range(nmb_stems):
                fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
                self.set_guidance_mid_dampening(fract_mixing)
                list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                self.insert_into_tree(fract_mixing, idx_injection, list_latents)
                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
+
        return self.tree_final_imgs
|
309 |
def compute_latents1(self, return_image=False):
|
310 |
r"""
|
|
|
318 |
t0 = time.time()
|
319 |
latents_start = self.get_noise(self.seed1)
|
320 |
list_latents1 = self.run_diffusion(
|
321 |
+
list_conditionings,
|
322 |
+
latents_start=latents_start,
|
323 |
+
idx_start=0)
|
|
|
324 |
t1 = time.time()
|
325 |
+
self.dt_per_diff = (t1 - t0) / self.num_inference_steps
|
326 |
self.tree_latents[0] = list_latents1
|
327 |
if return_image:
|
328 |
return self.sdh.latent2image(list_latents1[-1])
|
329 |
else:
|
330 |
return list_latents1
|
331 |
+
|
332 |
def compute_latents2(self, return_image=False):
|
333 |
r"""
|
334 |
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
|
|
|
342 |
# Influence from branch1
|
343 |
if self.branch1_crossfeed_power > 0.0:
|
344 |
# Set up the mixing_coeffs
|
345 |
+
idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
|
346 |
+
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
|
347 |
+
mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
|
348 |
list_latents_mixing = self.tree_latents[0]
|
349 |
list_latents2 = self.run_diffusion(
|
350 |
+
list_conditionings,
|
351 |
+
latents_start=latents_start,
|
352 |
+
idx_start=0,
|
353 |
+
list_latents_mixing=list_latents_mixing,
|
354 |
+
mixing_coeffs=mixing_coeffs)
|
|
|
355 |
else:
|
356 |
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
357 |
self.tree_latents[-1] = list_latents2
|
358 |
+
|
359 |
if return_image:
|
360 |
return self.sdh.latent2image(list_latents2[-1])
|
361 |
else:
|
362 |
+
return list_latents2
|
363 |
|
364 |
+
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
|
|
365 |
r"""
|
366 |
Runs a diffusion trajectory, using the latents from the respective parents
|
367 |
Args:
|
|
|
375 |
the index in terms of diffusion steps, where the next insertion will start.
|
376 |
"""
|
377 |
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
378 |
+
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
379 |
# idx_reversed = self.num_inference_steps - idx_injection
|
380 |
+
|
381 |
list_latents_parental_mix = []
|
382 |
for i in range(self.num_inference_steps):
|
383 |
latents_p1 = self.tree_latents[b_parent1][i]
|
|
|
388 |
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
389 |
list_latents_parental_mix.append(latents_parental)
|
390 |
|
391 |
+
idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
|
392 |
+
mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
|
393 |
nmb_mixing = idx_mixing_stop - idx_injection
|
394 |
if nmb_mixing > 0:
|
395 |
+
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
|
396 |
+
mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
|
397 |
+
latents_start = list_latents_parental_mix[idx_injection - 1]
|
|
|
398 |
list_latents = self.run_diffusion(
|
399 |
+
list_conditionings,
|
400 |
+
latents_start=latents_start,
|
401 |
+
idx_start=idx_injection,
|
402 |
+
list_latents_mixing=list_latents_parental_mix,
|
403 |
+
mixing_coeffs=mixing_coeffs)
|
|
|
|
|
404 |
return list_latents
|
405 |
|
406 |
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
|
|
410 |
Either provide t_compute_max_allowed or nmb_max_branches
|
411 |
Args:
|
412 |
depth_strength:
|
413 |
+
Determines how deep the first injection will happen.
|
414 |
Deeper injections will cause (unwanted) formation of new structures,
|
415 |
more shallow values will go into alpha-blendy land.
|
416 |
t_compute_max_allowed: float
|
417 |
The maximum time allowed for computation. Higher values give better results
|
418 |
+
but take longer. Use this if you want to fix your waiting time for the results.
|
419 |
nmb_max_branches: int
|
420 |
The maximum number of branches to be computed. Higher values give better
|
421 |
+
results. Use this if you want to have controllable results independent
|
422 |
of your computer.
|
423 |
"""
|
424 |
+
idx_injection_base = int(round(self.num_inference_steps * depth_strength))
|
425 |
+
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
|
426 |
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
427 |
t_compute = 0
|
428 |
+
|
429 |
if nmb_max_branches is None:
|
430 |
assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
431 |
stop_criterion = "t_compute_max_allowed"
|
432 |
elif t_compute_max_allowed is None:
|
433 |
assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
434 |
stop_criterion = "nmb_max_branches"
|
435 |
+
nmb_max_branches -= 2 # Discounting the outer frames
|
436 |
else:
|
437 |
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
|
|
|
438 |
stop_criterion_reached = False
|
439 |
is_first_iteration = True
|
|
|
440 |
while not stop_criterion_reached:
|
441 |
list_compute_steps = self.num_inference_steps - list_idx_injection
|
442 |
list_compute_steps *= list_nmb_stems
|
443 |
+
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
|
444 |
increase_done = False
|
445 |
+
for s_idx in range(len(list_nmb_stems) - 1):
|
446 |
+
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
|
447 |
list_nmb_stems[s_idx] += 1
|
448 |
increase_done = True
|
449 |
break
|
450 |
if not increase_done:
|
451 |
list_nmb_stems[-1] += 1
|
452 |
+
|
453 |
if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
|
454 |
stop_criterion_reached = True
|
455 |
elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
|
|
|
460 |
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
461 |
else:
|
462 |
is_first_iteration = False
|
463 |
+
|
464 |
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
465 |
return list_idx_injection, list_nmb_stems
|
466 |
|
|
|
475 |
"""
|
476 |
# get_lpips_similarity
|
477 |
similarities = []
|
478 |
+
for i in range(len(self.tree_final_imgs) - 1):
|
479 |
+
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
|
480 |
b_closest1 = np.argmax(similarities)
|
481 |
+
b_closest2 = b_closest1 + 1
|
482 |
fract_closest1 = self.tree_fracts[b_closest1]
|
483 |
fract_closest2 = self.tree_fracts[b_closest2]
|
484 |
+
|
485 |
# Ensure that the parents are indeed older!
|
486 |
b_parent1 = b_closest1
|
487 |
while True:
|
|
|
489 |
break
|
490 |
else:
|
491 |
b_parent1 -= 1
|
|
|
492 |
b_parent2 = b_closest2
|
493 |
while True:
|
494 |
if self.tree_idx_injection[b_parent2] < idx_injection:
|
495 |
break
|
496 |
else:
|
497 |
b_parent2 += 1
|
498 |
+
fract_mixing = (fract_closest1 + fract_closest2) / 2
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
return fract_mixing, b_parent1, b_parent2
|
500 |
+
|
|
|
501 |
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
|
502 |
r"""
|
503 |
Inserts all necessary parameters into the trajectory tree.
|
|
|
509 |
list_latents: list
|
510 |
list of the latents to be inserted
|
511 |
"""
|
512 |
+
b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
|
513 |
+
self.tree_latents.insert(b_parent1 + 1, list_latents)
|
514 |
+
self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
|
515 |
+
self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
|
516 |
+
self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
|
517 |
+
|
518 |
+
def get_spatial_mask_template(self):
|
|
|
519 |
r"""
|
520 |
+
Experimental helper function to get a spatial mask template.
|
521 |
"""
|
522 |
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
523 |
C, H, W = shape_latents
|
524 |
return np.ones((H, W))
|
525 |
+
|
526 |
def set_spatial_mask(self, img_mask):
|
527 |
r"""
|
528 |
+
Experimental helper function to set a spatial mask.
|
529 |
The mask forces latents to be overwritten.
|
530 |
Args:
|
531 |
+
img_mask:
|
532 |
mask image [0,1]. You can get a template using get_spatial_mask_template
|
|
|
533 |
"""
|
|
|
534 |
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
535 |
C, H, W = shape_latents
|
536 |
img_mask = np.asarray(img_mask)
|
|
|
540 |
assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
|
541 |
spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
|
542 |
spatial_mask = torch.unsqueeze(spatial_mask, 0)
|
543 |
+
spatial_mask = spatial_mask.repeat((C, 1, 1))
|
544 |
spatial_mask = torch.unsqueeze(spatial_mask, 0)
|
|
|
545 |
self.spatial_mask = spatial_mask
|
546 |
+
|
|
|
547 |
def get_noise(self, seed):
|
548 |
r"""
|
549 |
Helper function to get noise given seed.
|
550 |
Args:
|
551 |
seed: int
|
|
|
552 |
"""
|
553 |
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
|
554 |
if self.mode == 'standard':
|
|
|
559 |
h = self.image1_lowres.size[1]
|
560 |
shape_latents = [self.sdh.model.channels, h, w]
|
561 |
C, H, W = shape_latents
|
|
|
562 |
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
|
563 |
|
|
|
564 |
@torch.no_grad()
|
565 |
def run_diffusion(
|
566 |
+
self,
|
567 |
+
list_conditionings,
|
568 |
+
latents_start: torch.FloatTensor = None,
|
569 |
+
idx_start: int = 0,
|
570 |
+
list_latents_mixing=None,
|
571 |
+
mixing_coeffs=0.0,
|
572 |
+
return_image: Optional[bool] = False):
|
|
|
|
|
573 |
r"""
|
574 |
Wrapper function for diffusion runners.
|
575 |
Depending on the mode, the correct one will be executed.
|
576 |
+
|
577 |
Args:
|
578 |
list_conditionings: list
|
579 |
List of all conditionings for the diffusion model.
|
580 |
+
latents_start: torch.FloatTensor
|
581 |
Latents that are used for injection
|
582 |
idx_start: int
|
583 |
Index of the diffusion process start and where the latents_for_injection are injected
|
584 |
+
list_latents_mixing: torch.FloatTensor
|
585 |
List of latents (latent trajectories) that are used for mixing
|
586 |
mixing_coeffs: float or list
|
587 |
Coefficients, how strong each element of list_latents_mixing will be mixed in.
|
588 |
return_image: Optional[bool]
|
589 |
Optionally return image directly
|
590 |
"""
|
591 |
+
|
592 |
# Ensure correct num_inference_steps in Holder
|
593 |
self.sdh.num_inference_steps = self.num_inference_steps
|
594 |
assert type(list_conditionings) is list, "list_conditionings need to be a list"
|
595 |
+
|
596 |
if self.mode == 'standard':
|
597 |
text_embeddings = list_conditionings[0]
|
598 |
return self.sdh.run_diffusion_standard(
|
599 |
+
text_embeddings=text_embeddings,
|
600 |
+
latents_start=latents_start,
|
601 |
+
idx_start=idx_start,
|
602 |
+
list_latents_mixing=list_latents_mixing,
|
603 |
+
mixing_coeffs=mixing_coeffs,
|
604 |
+
spatial_mask=self.spatial_mask,
|
605 |
+
return_image=return_image)
|
606 |
+
|
|
|
607 |
elif self.mode == 'upscale':
|
608 |
cond = list_conditionings[0]
|
609 |
uc_full = list_conditionings[1]
|
610 |
return self.sdh.run_diffusion_upscaling(
|
611 |
+
cond,
|
612 |
+
uc_full,
|
613 |
+
latents_start=latents_start,
|
614 |
+
idx_start=idx_start,
|
615 |
+
list_latents_mixing=list_latents_mixing,
|
616 |
+
mixing_coeffs=mixing_coeffs,
|
617 |
return_image=return_image)
|
618 |
|
|
|
619 |
def run_upscaling(
|
620 |
+
self,
|
621 |
dp_img: str,
|
622 |
depth_strength: float = 0.65,
|
623 |
num_inference_steps: int = 100,
|
624 |
nmb_max_branches_highres: int = 5,
|
625 |
nmb_max_branches_lowres: int = 6,
|
626 |
+
duration_single_segment=3,
|
627 |
+
fps=24,
|
628 |
+
fixed_seeds: Optional[List[int]] = None):
|
629 |
r"""
|
630 |
Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
|
631 |
+
|
632 |
Args:
|
633 |
dp_img: str
|
634 |
Path to the low-res transition path (as saved in write_imgs_transition)
|
635 |
depth_strength:
|
636 |
+
Determines how deep the first injection will happen.
|
637 |
Deeper injections will cause (unwanted) formation of new structures,
|
638 |
more shallow values will go into alpha-blendy land.
|
639 |
num_inference_steps:
|
|
|
646 |
Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
|
647 |
duration_single_segment: float
|
648 |
The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
|
649 |
+
fps: float
|
650 |
+
frames per second of movie
|
651 |
fixed_seeds: Optional[List[int)]:
|
652 |
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
653 |
Otherwise random seeds will be taken.
|
654 |
"""
|
655 |
fp_yml = os.path.join(dp_img, "lowres.yaml")
|
656 |
fp_movie = os.path.join(dp_img, "movie_highres.mp4")
|
|
|
657 |
ms = MovieSaver(fp_movie, fps=fps)
|
658 |
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
|
659 |
dict_stuff = yml_load(fp_yml)
|
660 |
+
|
661 |
# load lowres images
|
662 |
nmb_images_lowres = dict_stuff['nmb_images']
|
663 |
prompt1 = dict_stuff['prompt1']
|
664 |
prompt2 = dict_stuff['prompt2']
|
665 |
+
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
|
666 |
imgs_lowres = []
|
667 |
for i in idx_img_lowres:
|
668 |
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
|
669 |
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
|
670 |
imgs_lowres.append(Image.open(fp_img_lowres))
|
|
|
671 |
|
672 |
# set up upscaling
|
673 |
text_embeddingA = self.sdh.get_text_embedding(prompt1)
|
674 |
text_embeddingB = self.sdh.get_text_embedding(prompt2)
|
675 |
+
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
|
676 |
+
for i in range(nmb_max_branches_lowres - 1):
|
|
|
|
|
677 |
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
|
|
|
678 |
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
|
679 |
+
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
|
680 |
+
if i == 0:
|
681 |
+
recycle_img1 = False
|
|
|
682 |
else:
|
683 |
self.swap_forward()
|
684 |
+
recycle_img1 = True
|
685 |
+
|
686 |
self.set_image1(imgs_lowres[i])
|
687 |
+
self.set_image2(imgs_lowres[i + 1])
|
688 |
+
|
689 |
list_imgs = self.run_transition(
|
690 |
+
recycle_img1=recycle_img1,
|
691 |
+
recycle_img2=False,
|
692 |
+
num_inference_steps=num_inference_steps,
|
693 |
+
depth_strength=depth_strength,
|
694 |
+
nmb_max_branches=nmb_max_branches_highres)
|
|
|
|
|
695 |
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
|
696 |
+
|
697 |
# Save movie frame
|
698 |
for img in list_imgs_interp:
|
699 |
ms.write_frame(img)
|
|
|
700 |
ms.finalize()
|
|
|
701 |
|
|
|
702 |
@torch.no_grad()
|
703 |
def get_mixed_conditioning(self, fract_mixing):
|
704 |
if self.mode == 'standard':
|
|
|
720 |
|
721 |
@torch.no_grad()
|
722 |
def get_text_embeddings(
|
723 |
+
self,
|
724 |
+
prompt: str):
|
|
|
725 |
r"""
|
726 |
Computes the text embeddings provided a string with a prompts.
|
727 |
Adapted from stable diffusion repo
|
|
|
729 |
prompt: str
|
730 |
ABC trending on artstation painted by Old Greg.
|
731 |
"""
|
|
|
732 |
return self.sdh.get_text_embedding(prompt)
|
|
|
733 |
|
734 |
def write_imgs_transition(self, dp_img):
|
735 |
r"""
|
|
|
744 |
for i, img in enumerate(imgs_transition):
|
745 |
img_leaf = Image.fromarray(img)
|
746 |
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
|
747 |
+
fp_yml = os.path.join(dp_img, "lowres.yaml")
|
|
|
748 |
self.save_statedict(fp_yml)
|
749 |
+
|
750 |
def write_movie_transition(self, fp_movie, duration_transition, fps=30):
|
751 |
r"""
|
752 |
Writes the transition movie to fp_movie, using the given duration and fps..
|
|
|
758 |
duration of the movie in seonds
|
759 |
fps: int
|
760 |
fps of the movie
|
|
|
761 |
"""
|
762 |
+
|
763 |
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
|
764 |
imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
|
765 |
|
|
|
771 |
ms.write_frame(img)
|
772 |
ms.finalize()
|
773 |
|
|
|
|
|
774 |
def save_statedict(self, fp_yml):
|
775 |
# Dump everything relevant into yaml
|
776 |
imgs_transition = self.tree_final_imgs
|
777 |
state_dict = self.get_state_dict()
|
778 |
state_dict['nmb_images'] = len(imgs_transition)
|
779 |
yml_save(fp_yml, state_dict)
|
780 |
+
|
781 |
def get_state_dict(self):
|
782 |
state_dict = {}
|
783 |
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
|
|
791 |
state_dict[v] = int(getattr(self, v))
|
792 |
elif v == 'guidance_scale':
|
793 |
state_dict[v] = float(getattr(self, v))
|
794 |
+
|
795 |
else:
|
796 |
try:
|
797 |
state_dict[v] = getattr(self, v)
|
798 |
+
except Exception:
|
799 |
pass
|
|
|
800 |
return state_dict
|
801 |
+
|
802 |
def randomize_seed(self):
|
803 |
r"""
|
804 |
Set a random seed for a fresh start.
|
805 |
+
"""
|
806 |
seed = np.random.randint(999999999)
|
807 |
self.set_seed(seed)
|
808 |
+
|
809 |
def set_seed(self, seed: int):
|
810 |
r"""
|
811 |
Set a the seed for a fresh start.
|
812 |
+
"""
|
813 |
self.seed = seed
|
814 |
self.sdh.seed = seed
|
815 |
+
|
816 |
def set_width(self, width):
|
817 |
r"""
|
818 |
Set the width of the resulting image.
|
819 |
+
"""
|
820 |
assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
|
821 |
self.width = width
|
822 |
self.sdh.width = width
|
823 |
+
|
824 |
def set_height(self, height):
|
825 |
r"""
|
826 |
Set the height of the resulting image.
|
827 |
+
"""
|
828 |
assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
|
829 |
self.height = height
|
830 |
self.sdh.height = height
|
|
|
831 |
|
832 |
def swap_forward(self):
|
833 |
r"""
|
834 |
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
|
835 |
as in run_multi_transition()
|
836 |
+
"""
|
837 |
# Move over all latents
|
838 |
self.tree_latents[0] = self.tree_latents[-1]
|
|
|
839 |
# Move over prompts and text embeddings
|
840 |
self.prompt1 = self.prompt2
|
841 |
self.text_embedding1 = self.text_embedding2
|
|
|
842 |
# Final cleanup for extra sanity
|
843 |
+
self.tree_final_imgs = []
|
844 |
+
|
|
|
845 |
def get_lpips_similarity(self, imgA, imgB):
|
846 |
r"""
|
847 |
+
Computes the image similarity between two images imgA and imgB.
|
848 |
Used to determine the optimal point of insertion to create smooth transitions.
|
849 |
High values indicate low similarity.
|
850 |
+
"""
|
851 |
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
|
852 |
+
tensorA = 2 * tensorA / 255.0 - 1
|
853 |
+
tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
|
|
|
854 |
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
|
855 |
+
tensorB = 2 * tensorB / 255.0 - 1
|
856 |
+
tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
|
857 |
lploss = self.lpips(tensorA, tensorB)
|
858 |
lploss = float(lploss[0][0][0][0])
|
|
|
859 |
return lploss
|
860 |
+
|
861 |
+
# Auxiliary functions
|
862 |
+
def get_closest_idx(
|
863 |
+
self,
|
864 |
+
fract_mixing: float):
|
865 |
+
r"""
|
866 |
+
Helper function to retrieve the parents for any given mixing.
|
867 |
+
Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
|
868 |
+
Will return the two closest values here, i.e. [1, 2]
|
869 |
+
"""
|
870 |
+
|
871 |
+
pdist = fract_mixing - np.asarray(self.tree_fracts)
|
872 |
+
pdist_pos = pdist.copy()
|
873 |
+
pdist_pos[pdist_pos < 0] = np.inf
|
874 |
+
b_parent1 = np.argmin(pdist_pos)
|
875 |
+
pdist_neg = -pdist.copy()
|
876 |
+
pdist_neg[pdist_neg <= 0] = np.inf
|
877 |
+
b_parent2 = np.argmin(pdist_neg)
|
878 |
+
|
879 |
+
if b_parent1 > b_parent2:
|
880 |
+
tmp = b_parent2
|
881 |
+
b_parent2 = b_parent1
|
882 |
+
b_parent1 = tmp
|
883 |
+
|
884 |
+
return b_parent1, b_parent2
|
movie_util.py
CHANGED
@@ -1,5 +1,6 @@
 # Copyright 2022 Lunar Ring. All rights reserved.
-#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -17,26 +18,24 @@ import os
import numpy as np
from tqdm import tqdm
import cv2
-from typing import
-import ffmpeg

-#%%
-
class MovieSaver():
    def __init__(
-           self,
-           fp_out: str,
-           fps: int = 24,
            shape_hw: List[int] = None,
            crf: int = 24,
            codec: str = 'libx264',
-           preset: str ='fast',
-           pix_fmt: str = 'yuv420p',
-           silent_ffmpeg: bool = True
-           ):
        r"""
        Initializes movie saver class - a human friendly ffmpeg wrapper.
-       After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
        Don't forget toi finalize movie file with moviesaver.finalize().
        Args:
            fp_out: str
@@ -47,22 +46,22 @@ class MovieSaver():
                Output shape, optional argument. Can be initialized automatically when first frame is written.
            crf: int
                ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
-               (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
-               A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
-               Consider 17 or 18 to be visually lossless or nearly so;
-               it should look the same or nearly the same as the input but it isn't technically lossless.
-               The range is exponential, so increasing the CRF value +6 results in
-               roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
            codec: int
                Number of diffusion steps. Larger values will take more compute time.
            preset: str
                Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
-               ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
-               to compression ratio. A slower preset will provide better compression
-               (compression is quality per filesize).
-               This means that, for example, if you target a certain file size or constant bit rate,
                you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
-               you will simply save bitrate by choosing a slower preset.
            pix_fmt: str
                Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
            silent_ffmpeg: bool
@@ -70,7 +69,7 @@ class MovieSaver():
        """
        if len(os.path.split(fp_out)[0]) > 0:
            assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
-
        self.fp_out = fp_out
        self.fps = fps
        self.crf = crf
@@ -78,10 +77,10 @@ class MovieSaver():
        self.codec = codec
        self.preset = preset
        self.silent_ffmpeg = silent_ffmpeg
-
        if os.path.isfile(fp_out):
            os.remove(fp_out)
-
        self.init_done = False
        self.nmb_frames = 0
        if shape_hw is None:
@@ -91,11 +90,9 @@ class MovieSaver():
            shape_hw.append(3)
        self.shape_hw = shape_hw
        self.initialize()
-
-
        print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
-
-
    def initialize(self):
        args = (
            ffmpeg
@@ -111,8 +108,7 @@ class MovieSaver():
        self.init_done = True
        self.shape_hw = tuple(self.shape_hw)
        print(f"Initialization done. Movie shape: {self.shape_hw}")
-
-
    def write_frame(self, out_frame: np.ndarray):
        r"""
        Function to dump a numpy array as frame of a movie.
@@ -123,18 +119,17 @@ class MovieSaver():
            Dim 1: x
            Dim 2: RGB
        """
-
        assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
        assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
        assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
-
        if not self.init_done:
            self.shape_hw = out_frame.shape
            self.initialize()
-
        assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"

-       # write frame
        self.ffmpg_process.stdin.write(
            out_frame
            .astype(np.uint8)
@@ -142,8 +137,7 @@ class MovieSaver():
        )

        self.nmb_frames += 1
-
-
    def finalize(self):
        r"""
        Call this function to finalize the movie. If you forget to call it your movie will be garbage.
@@ -157,7 +151,6 @@ class MovieSaver():
        print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")


-
def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
    r"""
    Concatenate multiple movie segments into one long movie, using ffmpeg.
@@ -167,13 +160,13 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
        fp_final : str
            Full path of the final movie file. Should end with .mp4
        list_fp_movies : list[str]
-           List of full paths of movie segments.
    """
    assert fp_final[-4] == ".", "fp_final seems to miss file extension: {fp_final}"
    for fp in list_fp_movies:
        assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
        assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
-
    if os.path.isfile(fp_final):
        os.remove(fp_final)

@@ -181,32 +174,32 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
    list_concat = []
    for fp_part in list_fp_movies:
        list_concat.append(f"""file '{fp_part}'""")
-
    # save this list
    fp_list = "tmp_move.txt"
    with open(fp_list, "w") as fa:
        for item in list_concat:
            fa.write("%s\n" % item)
-
    cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
-   dp_movie = os.path.split(fp_final)[0]
    subprocess.call(cmd, shell=True)
    os.remove(fp_list)
    if os.path.isfile(fp_final):
        print(f"concatenate_movies: success! Watch here: {fp_final}")

-
class MovieReader():
    r"""
    Class to read in a movie.
    """
    def __init__(self, fp_movie):
        self.video_player_object = cv2.VideoCapture(fp_movie)
        self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
-       self.shape = [100,100,3]
        self.shape_is_set = False
-
    def get_next_frame(self):
        success, image = self.video_player_object.read()
        if success:
@@ -217,19 +210,18 @@ class MovieReader():
        else:
            return np.zeros(self.shape)

-
-if __name__ == "__main__":
-   fps=2
    list_fp_movies = []
    for k in range(4):
        fp_movie = f"/tmp/my_random_movie_{k}.mp4"
        list_fp_movies.append(fp_movie)
        ms = MovieSaver(fp_movie, fps=fps)
        for fn in tqdm(range(30)):
-           img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
            ms.write_frame(img)
        ms.finalize()
-
    fp_final = "/tmp/my_concatenated_movie.mp4"
    concatenate_movies(fp_final, list_fp_movies)
-
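A minimal usage sketch for MovieSaver and concatenate_movies, following the __main__ example in this file; the output paths are just illustrative.

    import numpy as np
    from movie_util import MovieSaver, concatenate_movies

    ms = MovieSaver("/tmp/my_random_movie_0.mp4", fps=24)
    for _ in range(30):
        frame = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)  # H x W x RGB, uint8
        ms.write_frame(frame)
    ms.finalize()  # without finalize() the movie file is not playable

    concatenate_movies("/tmp/my_concatenated_movie.mp4", ["/tmp/my_random_movie_0.mp4"])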
1 |
# Copyright 2022 Lunar Ring. All rights reserved.
|
2 |
+
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
3 |
+
|
4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
# you may not use this file except in compliance with the License.
|
6 |
# You may obtain a copy of the License at
|
|
|
18 |
import numpy as np
|
19 |
from tqdm import tqdm
|
20 |
import cv2
|
21 |
+
from typing import List
|
22 |
+
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
|
23 |
+
|
24 |
|
|
|
|
|
 class MovieSaver():
     def __init__(
+            self,
+            fp_out: str,
+            fps: int = 24,
             shape_hw: List[int] = None,
             crf: int = 24,
             codec: str = 'libx264',
+            preset: str = 'fast',
+            pix_fmt: str = 'yuv420p',
+            silent_ffmpeg: bool = True):
         r"""
         Initializes movie saver class - a human friendly ffmpeg wrapper.
+        After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
         Don't forget to finalize the movie file with moviesaver.finalize().
         Args:
             fp_out: str

             Output shape, optional argument. Can be initialized automatically when first frame is written.
         crf: int
             ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
+            (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
+            A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
+            Consider 17 or 18 to be visually lossless or nearly so;
+            it should look the same or nearly the same as the input but it isn't technically lossless.
+            The range is exponential, so increasing the CRF value +6 results in
+            roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
         codec: str
             FFmpeg video codec to use, e.g. 'libx264'.
         preset: str
             Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
+            ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
+            to compression ratio. A slower preset will provide better compression
+            (compression is quality per filesize).
+            This means that, for example, if you target a certain file size or constant bit rate,
             you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
+            you will simply save bitrate by choosing a slower preset.
         pix_fmt: str
             Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
         silent_ffmpeg: bool

         """
         if len(os.path.split(fp_out)[0]) > 0:
             assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
+
         self.fp_out = fp_out
         self.fps = fps
         self.crf = crf

         self.codec = codec
         self.preset = preset
         self.silent_ffmpeg = silent_ffmpeg
+
         if os.path.isfile(fp_out):
             os.remove(fp_out)
+
         self.init_done = False
         self.nmb_frames = 0
         if shape_hw is None:

             shape_hw.append(3)
         self.shape_hw = shape_hw
         self.initialize()
+
         print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
+
     def initialize(self):
         args = (
             ffmpeg

         self.init_done = True
         self.shape_hw = tuple(self.shape_hw)
         print(f"Initialization done. Movie shape: {self.shape_hw}")
+

     def write_frame(self, out_frame: np.ndarray):
         r"""
         Function to dump a numpy array as frame of a movie.

             Dim 1: x
             Dim 2: RGB
         """
         assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
         assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
         assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
+
         if not self.init_done:
             self.shape_hw = out_frame.shape
             self.initialize()
+
         assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"

+        # write frame
         self.ffmpg_process.stdin.write(
             out_frame
             .astype(np.uint8)

         )

         self.nmb_frames += 1
+

     def finalize(self):
         r"""
         Call this function to finalize the movie. If you forget to call it your movie will be garbage.

         print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
 def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
     r"""
     Concatenate multiple movie segments into one long movie, using ffmpeg.

     fp_final : str
         Full path of the final movie file. Should end with .mp4
     list_fp_movies : list[str]
+        List of full paths of movie segments.
     """
     assert fp_final[-4] == ".", f"fp_final seems to miss file extension: {fp_final}"
     for fp in list_fp_movies:
         assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
         assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
+
     if os.path.isfile(fp_final):
         os.remove(fp_final)

     list_concat = []
     for fp_part in list_fp_movies:
         list_concat.append(f"""file '{fp_part}'""")
+
     # save this list
     fp_list = "tmp_move.txt"
     with open(fp_list, "w") as fa:
         for item in list_concat:
             fa.write("%s\n" % item)
+
     cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'

     subprocess.call(cmd, shell=True)
     os.remove(fp_list)
     if os.path.isfile(fp_final):
         print(f"concatenate_movies: success! Watch here: {fp_final}")

+
 class MovieReader():
     r"""
     Class to read in a movie.
     """
+
     def __init__(self, fp_movie):
         self.video_player_object = cv2.VideoCapture(fp_movie)
         self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
         self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
+        self.shape = [100, 100, 3]
         self.shape_is_set = False
+
     def get_next_frame(self):
         success, image = self.video_player_object.read()
         if success:

         else:
             return np.zeros(self.shape)

+
+if __name__ == "__main__":
+    fps = 2
     list_fp_movies = []
     for k in range(4):
         fp_movie = f"/tmp/my_random_movie_{k}.mp4"
         list_fp_movies.append(fp_movie)
         ms = MovieSaver(fp_movie, fps=fps)
         for fn in tqdm(range(30)):
+            img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
             ms.write_frame(img)
         ms.finalize()
+
     fp_final = "/tmp/my_concatenated_movie.mp4"
     concatenate_movies(fp_final, list_fp_movies)
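Not part of the diff, just a minimal usage sketch for reading a rendered clip back in with MovieReader; the file path matches the example above and get_next_frame is assumed to yield uint8 RGB frames while the stream lasts:

# Editor's sketch, assumptions noted above.
mr = MovieReader("/tmp/my_concatenated_movie.mp4")
print(f"{mr.nmb_frames} frames at {mr.fps_movie} fps")
for _ in range(mr.nmb_frames):
    frame = mr.get_next_frame()  # np.ndarray; returns zeros once reading fails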
stable_diffusion_holder.py
CHANGED
@@ -13,36 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-dp_git = "/home/lugo/git/"
-sys.path.append(os.path.join(dp_git,'garden4'))
-sys.path.append('util')
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
-import time
-import subprocess
 import warnings
 import torch
-from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
 import torch
-from
-import datetime
-from typing import Callable, List, Optional, Union
-import inspect
-from threading import Thread
-torch.set_grad_enabled(False)
 from omegaconf import OmegaConf
 from torch import autocast
 from contextlib import nullcontext
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from einops import repeat, rearrange
-

@@ -53,41 +42,11 @@ def pad_image(input_image):
     return im_padded

-
-def make_batch_inpaint(
-        image,
-        mask,
-        txt,
-        device,
-        num_samples=1):
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
-    mask = np.array(mask.convert("L"))
-    mask = mask.astype(np.float32) / 255.0
-    mask = mask[None, None]
-    mask[mask < 0.5] = 0
-    mask[mask >= 0.5] = 1
-    mask = torch.from_numpy(mask)
-
-    masked_image = image * (mask < 0.5)
-
-    batch = {
-        "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
-        "txt": num_samples * [txt],
-        "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
-        "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
-    }
-    return batch
-
-
 def make_batch_superres(
     image,
     txt,
     device,
-    num_samples=1
-    ):
     image = np.array(image.convert("RGB"))
     image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
     batch = {
@@ -107,14 +66,14 @@ def make_noise_augmentation(model, batch, noise_level=None):
 class StableDiffusionHolder:
-    def __init__(self,
-                 fp_ckpt: str = None,
                  fp_config: str = None,
-                 num_inference_steps: int = 30,
                  height: Optional[int] = None,
                  width: Optional[int] = None,
                  device: str = None,
-                 precision: str='autocast',
                  ):
@@ -122,26 +81,26 @@ class StableDiffusionHolder:
             fp_ckpt: File pointer to the .ckpt model file
             fp_config: File pointer to the .yaml config file
             num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
-            height: Height of the resulting image.
-            width: Width of the resulting image.
             device: Device to run the model on.
             precision: Precision to run the model on.
         """
         self.seed = 42
         self.guidance_scale = 5.0
-
         if device is None:
             self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         else:
             self.device = device
         self.precision = precision
         self.init_model(fp_ckpt, fp_config)
-
-        self.f = 8
         self.C = 4
         self.ddim_eta = 0
         self.num_inference_steps = num_inference_steps
-
         if height is None and width is None:
             self.init_auto_res()
         else:
@@ -149,53 +108,44 @@ class StableDiffusionHolder:
             assert width is not None, "specify both width and height"
             self.height = height
             self.width = width
-
-        # Inpainting inits
-        self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
-        self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
-
         self.negative_prompt = [""]
-
-
     def init_model(self, fp_ckpt, fp_config):
         assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
         self.fp_ckpt = fp_ckpt
-
         if fp_config is None:
             fn_ckpt = os.path.basename(fp_ckpt)
             if 'depth' in fn_ckpt:
                 fp_config = 'configs/v2-midas-inference.yaml'
-            elif 'inpain' in fn_ckpt:
-                fp_config = 'configs/v2-inpainting-inference.yaml'
             elif 'upscaler' in fn_ckpt:
-                fp_config = 'configs/x4-upscaling.yaml'
             elif '512' in fn_ckpt:
-                fp_config = 'configs/v2-inference.yaml'
-            elif '768'in fn_ckpt:
-                fp_config = 'configs/v2-inference-v.yaml'
             elif 'v1-5' in fn_ckpt:
-                fp_config = 'configs/v1-inference.yaml'
             else:
                 raise ValueError("auto detect of config failed. please specify fp_config manually!")
-
             assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
-
         assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
-
         config = OmegaConf.load(fp_config)
-
         self.model = instantiate_from_config(config.model)
         self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
         self.model = self.model.to(self.device)
         self.sampler = DDIMSampler(self.model)
-
-
@@ -205,7 +155,7 @@ class StableDiffusionHolder:
         else:
             self.height = 512
             self.width = 512
-
     def set_negative_prompt(self, negative_prompt):
@@ -214,51 +164,46 @@ class StableDiffusionHolder:
             self.negative_prompt = [negative_prompt]
         else:
             self.negative_prompt = negative_prompt
-
         if len(self.negative_prompt) > 1:
             self.negative_prompt = [self.negative_prompt[0]]
-
     def get_text_embedding(self, prompt):
         c = self.model.get_learned_conditioning(prompt)
         return c
-
     @torch.no_grad()
     def get_cond_upscaling(self, image, text_embedding, noise_level):
-
         image = pad_image(image)  # resize to integer multiple of 32
         w, h = image.size
         noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
         batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
         x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
-
         cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
         uc_cross = self.model.get_unconditional_conditioning(1, "")
         uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
-
         return cond, uc_full

     @torch.no_grad()
     def run_diffusion_standard(
-            self,
-            text_embeddings: torch.FloatTensor,
             latents_start: torch.FloatTensor,
-            idx_start: int = 0,
-            list_latents_mixing
-            mixing_coeffs
-            spatial_mask
-            return_image: Optional[bool] = False
-            ):
         r"""
-        Diffusion standard version.
-
         Args:
-            text_embeddings: torch.FloatTensor
                 Text embeddings used for diffusion
             latents_for_injection: torch.FloatTensor or list
                 Latents that are used for injection
@@ -270,41 +215,32 @@ class StableDiffusionHolder:
                 experimental feature for enforcing pixels from list_latents_mixing
             return_image: Optional[bool]
                 Optionally return image directly
-
         """
-
         # Asserts
         if type(mixing_coeffs) == float:
-            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
         elif type(mixing_coeffs) == list:
             assert len(mixing_coeffs) == self.num_inference_steps
             list_mixing_coeffs = mixing_coeffs
         else:
             raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
-
         if np.sum(list_mixing_coeffs) > 0:
             assert len(list_latents_mixing) == self.num_inference_steps
-
-
         precision_scope = autocast if self.precision == "autocast" else nullcontext
-
         with precision_scope("cuda"):
             with self.model.ema_scope():
                 if self.guidance_scale != 1.0:
                     uc = self.model.get_learned_conditioning(self.negative_prompt)
                 else:
                     uc = None
-
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-
                 latents = latents_start.clone()
-
                 timesteps = self.sampler.ddim_timesteps
-
                 time_range = np.flip(timesteps)
                 total_steps = timesteps.shape[0]
-
-                # collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):
                     # Set the right starting latents
@@ -313,83 +249,71 @@ class StableDiffusionHolder:
                         continue
                     elif i == idx_start:
                         latents = latents_start.clone()
-
-
-
-                        latents_mixtarget = list_latents_mixing[i-1].clone()
                         latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
-
                     if spatial_mask is not None and list_latents_mixing is not None:
-                        latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
-
-
                     index = total_steps - i - 1
                     ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
-
-
-
-
-
-
                     latents, pred_x0 = outs
                     list_latents_out.append(latents.clone())
-
-        if return_image:
             return self.latent2image(latents)
         else:
             return list_latents_out
-
-
     @torch.no_grad()
     def run_diffusion_upscaling(
-            self,
             cond,
             uc_full,
-            latents_start: torch.FloatTensor,
-            idx_start: int = -1,
-            list_latents_mixing = None,
-            mixing_coeffs = 0.0,
-            return_image: Optional[bool] = False
-            ):
         r"""
-        Diffusion upscaling version.
         """
-
         # Asserts
         if type(mixing_coeffs) == float:
-            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
         elif type(mixing_coeffs) == list:
             assert len(mixing_coeffs) == self.num_inference_steps
             list_mixing_coeffs = mixing_coeffs
         else:
             raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
-
         if np.sum(list_mixing_coeffs) > 0:
             assert len(list_latents_mixing) == self.num_inference_steps
-
         precision_scope = autocast if self.precision == "autocast" else nullcontext
-
-
-        w = uc_full['c_concat'][0].shape[3]
-
         with precision_scope("cuda"):
             with self.model.ema_scope():

                 shape_latents = [self.model.channels, h, w]
-
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
                 C, H, W = shape_latents
                 size = (1, C, H, W)
                 b = size[0]
-
                 latents = latents_start.clone()
-
                 timesteps = self.sampler.ddim_timesteps
-
                 time_range = np.flip(timesteps)
                 total_steps = timesteps.shape[0]
-
                 # collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):
@@ -399,232 +323,40 @@ class StableDiffusionHolder:
                         continue
                     elif i == idx_start:
                         latents = latents_start.clone()
-
-
-
-                        latents_mixtarget = list_latents_mixing[i-1].clone()
                         latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
-
                     # print(f"diffusion iter {i}")
                     index = total_steps - i - 1
                     ts = torch.full((b,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-
-
-
-
-
-
                     latents, pred_x0 = outs
                     list_latents_out.append(latents.clone())
-
-        if return_image:
-            return self.latent2image(latents)
-        else:
-            return list_latents_out
-
-    @torch.no_grad()
-    def run_diffusion_inpaint(
-            self,
-            text_embeddings: torch.FloatTensor,
-            latents_for_injection: torch.FloatTensor = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            return_image: Optional[bool] = False
-            ):
-        r"""
-        Runs inpaint-based diffusion. Returns a list of latents that were computed.
-        Adaptations allow to supply
-        a) starting index for diffusion
-        b) stopping index for diffusion
-        c) latent representations that are injected at the starting index
-        Furthermore the intermittent latents are collected and returned.
-
-        Adapted from diffusers (https://github.com/huggingface/diffusers)
-        Args:
-            text_embeddings: torch.FloatTensor
-                Text embeddings used for diffusion
-            latents_for_injection: torch.FloatTensor
-                Latents that are used for injection
-            idx_start: int
-                Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
-            return_image: Optional[bool]
-                Optionally return image directly
-        """
-
-        if latents_for_injection is None:
-            do_inject_latents = False
-        else:
-            do_inject_latents = True
-
-        precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))

-
-        with self.model.ema_scope():
-
-            batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-            c = text_embeddings
-            c_cat = list()
-            for ck in self.model.concat_keys:
-                cc = batch[ck].float()
-                if ck != self.model.masked_image_key:
-                    bchw = [1, 4, self.height // 8, self.width // 8]
-                    cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-                else:
-                    cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                c_cat.append(cc)
-            c_cat = torch.cat(c_cat, dim=1)
-
-            # cond
-            cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-            # uncond cond
-            uc_cross = self.model.get_unconditional_conditioning(1, "")
-            uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-            shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-            # sampling
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-
-            device = self.model.betas.device
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
-                    return list_latents_out
-
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-                outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-        if return_image:
             return self.latent2image(latents)
         else:
             return list_latents_out

     @torch.no_grad()
     def latent2image(
-            self,
-            latents: torch.FloatTensor
-            ):
         r"""
         Returns an image provided a latent representation from diffusion.
         Args:
             latents: torch.FloatTensor
-                Result of the diffusion process.
         """
         x_sample = self.model.decode_first_stage(latents)
         x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
-        x_sample = 255 * x_sample[0
         image = x_sample.astype(np.uint8)
         return image
-
-
-@torch.no_grad()
-def interpolate_spherical(p0, p1, fract_mixing: float):
-    r"""
-    Helper function to correctly mix two random variables using spherical interpolation.
-    See https://en.wikipedia.org/wiki/Slerp
-    The function will always cast up to float64 for sake of extra 4.
-    Args:
-        p0:
-            First tensor for interpolation
-        p1:
-            Second tensor for interpolation
-        fract_mixing: float
-            Mixing coefficient of interval [0, 1].
-            0 will return in p0
-            1 will return in p1
-            0.x will return a mix between both preserving angular velocity.
-    """
-
-    if p0.dtype == torch.float16:
-        recast_to = 'fp16'
-    else:
-        recast_to = 'fp32'
-
-    p0 = p0.double()
-    p1 = p1.double()
-    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
-    epsilon = 1e-7
-    dot = torch.sum(p0 * p1) / norm
-    dot = dot.clamp(-1+epsilon, 1-epsilon)
-
-    theta_0 = torch.arccos(dot)
-    sin_theta_0 = torch.sin(theta_0)
-    theta_t = theta_0 * fract_mixing
-    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
-    s1 = torch.sin(theta_t) / sin_theta_0
-    interp = p0*s0 + p1*s1
-
-    if recast_to == 'fp16':
-        interp = interp.half()
-    elif recast_to == 'fp32':
-        interp = interp.float()
-
-    return interp
-
-
-if __name__ == "__main__":
-
-    num_inference_steps = 20  # Number of diffusion interations
-
-    # fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
-    # fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
-
-    # fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
-    # fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
-
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    # fp_config = 'configs/v2-inference-v.yaml'
-
-    self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
-
-    xxx
-
-    #%%
-    self.width = 1536
-    self.height = 768
-    prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
-    self.set_negative_prompt("out of frame, faces, rendering, blurry")
-    te = self.get_text_embedding(prompt)
-
-    img = self.run_diffusion_standard(te, return_image=True)
-    Image.fromarray(img).show()
-
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch
 torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
 import torch
 from PIL import Image
 import torch
+from typing import Optional
 from omegaconf import OmegaConf
 from torch import autocast
 from contextlib import nullcontext
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from einops import repeat, rearrange
+from utils import interpolate_spherical


 def pad_image(input_image):

     return im_padded

 def make_batch_superres(
     image,
     txt,
     device,
+        num_samples=1):
     image = np.array(image.convert("RGB"))
     image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
     batch = {


 class StableDiffusionHolder:
+    def __init__(self,
+                 fp_ckpt: str = None,
                  fp_config: str = None,
+                 num_inference_steps: int = 30,
                  height: Optional[int] = None,
                  width: Optional[int] = None,
                  device: str = None,
+                 precision: str = 'autocast',
                  ):
         r"""
         Initializes the stable diffusion holder, which contains the models and sampler.

             fp_ckpt: File pointer to the .ckpt model file
             fp_config: File pointer to the .yaml config file
             num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
+            height: Height of the resulting image.
+            width: Width of the resulting image.
             device: Device to run the model on.
             precision: Precision to run the model on.
         """
         self.seed = 42
         self.guidance_scale = 5.0
+
         if device is None:
             self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         else:
             self.device = device
         self.precision = precision
         self.init_model(fp_ckpt, fp_config)
+
+        self.f = 8  # downsampling factor, most often 8 or 16
         self.C = 4
         self.ddim_eta = 0
         self.num_inference_steps = num_inference_steps
+
         if height is None and width is None:
             self.init_auto_res()
         else:

             assert width is not None, "specify both width and height"
             self.height = height
             self.width = width
+
         self.negative_prompt = [""]
+

     def init_model(self, fp_ckpt, fp_config):
         r"""Loads the models and sampler.
         """

         assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
         self.fp_ckpt = fp_ckpt
+
         # Auto init the config?
         if fp_config is None:
             fn_ckpt = os.path.basename(fp_ckpt)
             if 'depth' in fn_ckpt:
                 fp_config = 'configs/v2-midas-inference.yaml'
             elif 'upscaler' in fn_ckpt:
+                fp_config = 'configs/x4-upscaling.yaml'
             elif '512' in fn_ckpt:
+                fp_config = 'configs/v2-inference.yaml'
+            elif '768' in fn_ckpt:
+                fp_config = 'configs/v2-inference-v.yaml'
             elif 'v1-5' in fn_ckpt:
+                fp_config = 'configs/v1-inference.yaml'
             else:
                 raise ValueError("auto detect of config failed. please specify fp_config manually!")
+
             assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
+
         assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"

         config = OmegaConf.load(fp_config)
+
         self.model = instantiate_from_config(config.model)
         self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)

         self.model = self.model.to(self.device)
         self.sampler = DDIMSampler(self.model)
+

     def init_auto_res(self):
         r"""Automatically set the resolution to the one used in training.
         """

         else:
             self.height = 512
             self.width = 512
+
     def set_negative_prompt(self, negative_prompt):
         r"""Set the negative prompt. Currently only one negative prompt is supported
         """

             self.negative_prompt = [negative_prompt]
         else:
             self.negative_prompt = negative_prompt
+
         if len(self.negative_prompt) > 1:
             self.negative_prompt = [self.negative_prompt[0]]

     def get_text_embedding(self, prompt):
         c = self.model.get_learned_conditioning(prompt)
         return c
+
     @torch.no_grad()
     def get_cond_upscaling(self, image, text_embedding, noise_level):
         r"""
         Initializes the conditioning for the x4 upscaling model.
         """
         image = pad_image(image)  # resize to integer multiple of 32
         w, h = image.size
         noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
         batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)

         x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
+
         cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
         # uncond cond
         uc_cross = self.model.get_unconditional_conditioning(1, "")
         uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}

         return cond, uc_full

     @torch.no_grad()
     def run_diffusion_standard(
+            self,
+            text_embeddings: torch.FloatTensor,
             latents_start: torch.FloatTensor,
+            idx_start: int = 0,
+            list_latents_mixing=None,
+            mixing_coeffs=0.0,
+            spatial_mask=None,
+            return_image: Optional[bool] = False):
         r"""
+        Diffusion standard version.
         Args:
+            text_embeddings: torch.FloatTensor
                 Text embeddings used for diffusion
             latents_for_injection: torch.FloatTensor or list
                 Latents that are used for injection

                 experimental feature for enforcing pixels from list_latents_mixing
             return_image: Optional[bool]
                 Optionally return image directly
         """
         # Asserts
         if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
         elif type(mixing_coeffs) == list:
             assert len(mixing_coeffs) == self.num_inference_steps
             list_mixing_coeffs = mixing_coeffs
         else:
             raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+
         if np.sum(list_mixing_coeffs) > 0:
             assert len(list_latents_mixing) == self.num_inference_steps
+
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         with precision_scope("cuda"):
             with self.model.ema_scope():
                 if self.guidance_scale != 1.0:
                     uc = self.model.get_learned_conditioning(self.negative_prompt)
                 else:
                     uc = None
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                 latents = latents_start.clone()
                 timesteps = self.sampler.ddim_timesteps
                 time_range = np.flip(timesteps)
                 total_steps = timesteps.shape[0]
+                # Collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):
                     # Set the right starting latents

                         continue
                     elif i == idx_start:
                         latents = latents_start.clone()
+                    # Mix latents
+                    if i > 0 and list_mixing_coeffs[i] > 0:
+                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                         latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
+
                     if spatial_mask is not None and list_latents_mixing is not None:
+                        latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
+
                     index = total_steps - i - 1
                     ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      dynamic_threshold=None)
                     latents, pred_x0 = outs
                     list_latents_out.append(latents.clone())
+        if return_image:
             return self.latent2image(latents)
         else:
             return list_latents_out
+
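As a sketch of how the new mixing_coeffs argument can be driven (editor's example, not from the commit): a single float applies the same blend at every step, while a list gives one coefficient per diffusion step, e.g. following the companion latents strongly early on and letting the run finish freely.

# Hypothetical per-step schedule; num_steps is assumed to equal num_inference_steps.
num_steps = 30
mixing_coeffs = [0.8 if i < num_steps // 3 else 0.0 for i in range(num_steps)]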
     @torch.no_grad()
     def run_diffusion_upscaling(
+            self,
             cond,
             uc_full,
+            latents_start: torch.FloatTensor,
+            idx_start: int = -1,
+            list_latents_mixing: list = None,
+            mixing_coeffs: float = 0.0,
+            return_image: Optional[bool] = False):
         r"""
+        Diffusion upscaling version.
         """
+
         # Asserts
         if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
         elif type(mixing_coeffs) == list:
             assert len(mixing_coeffs) == self.num_inference_steps
             list_mixing_coeffs = mixing_coeffs
         else:
             raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+
         if np.sum(list_mixing_coeffs) > 0:
             assert len(list_latents_mixing) == self.num_inference_steps
+
         precision_scope = autocast if self.precision == "autocast" else nullcontext
+        h = uc_full['c_concat'][0].shape[2]
+        w = uc_full['c_concat'][0].shape[3]

         with precision_scope("cuda"):
             with self.model.ema_scope():

                 shape_latents = [self.model.channels, h, w]
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                 C, H, W = shape_latents
                 size = (1, C, H, W)
                 b = size[0]
                 latents = latents_start.clone()
                 timesteps = self.sampler.ddim_timesteps
                 time_range = np.flip(timesteps)
                 total_steps = timesteps.shape[0]
                 # collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):

                         continue
                     elif i == idx_start:
                         latents = latents_start.clone()
+                    # Mix the latents.
+                    if i > 0 and list_mixing_coeffs[i] > 0:
+                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                         latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
                     # print(f"diffusion iter {i}")
                     index = total_steps - i - 1
                     ts = torch.full((b,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc_full,
+                                                      dynamic_threshold=None)
                     latents, pred_x0 = outs
                     list_latents_out.append(latents.clone())

+        if return_image:
             return self.latent2image(latents)
         else:
             return list_latents_out

     @torch.no_grad()
     def latent2image(
+            self,
+            latents: torch.FloatTensor):
         r"""
         Returns an image provided a latent representation from diffusion.
         Args:
             latents: torch.FloatTensor
+                Result of the diffusion process.
         """
         x_sample = self.model.decode_first_stage(latents)
         x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+        x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
         image = x_sample.astype(np.uint8)
         return image
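For orientation, a hedged usage sketch of the refactored holder (editor's example: the checkpoint path and prompt are placeholders, and the start latent shape is assumed to follow self.C and self.f):

sdh = StableDiffusionHolder("v2-1_768-ema-pruned.ckpt", num_inference_steps=20)
te = sdh.get_text_embedding("a photo of a lake at sunrise")
latents_start = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f, device=sdh.device)
img = sdh.run_diffusion_standard(te, latents_start, return_image=True)
Image.fromarray(img).show()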
utils.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright 2022 Lunar Ring. All rights reserved.
+# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+torch.backends.cudnn.benchmark = False
+import numpy as np
+import warnings
+warnings.filterwarnings('ignore')
+import time
+import warnings
+import datetime
+from typing import List, Union
+torch.set_grad_enabled(False)
+import yaml
+
+
+@torch.no_grad()
+def interpolate_spherical(p0, p1, fract_mixing: float):
+    r"""
+    Helper function to correctly mix two random variables using spherical interpolation.
+    See https://en.wikipedia.org/wiki/Slerp
+    The function will always cast up to float64 for the sake of extra precision.
+    Args:
+        p0:
+            First tensor for interpolation
+        p1:
+            Second tensor for interpolation
+        fract_mixing: float
+            Mixing coefficient of interval [0, 1].
+            0 will return in p0
+            1 will return in p1
+            0.x will return a mix between both preserving angular velocity.
+    """
+
+    if p0.dtype == torch.float16:
+        recast_to = 'fp16'
+    else:
+        recast_to = 'fp32'
+
+    p0 = p0.double()
+    p1 = p1.double()
+    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
+    epsilon = 1e-7
+    dot = torch.sum(p0 * p1) / norm
+    dot = dot.clamp(-1 + epsilon, 1 - epsilon)
+
+    theta_0 = torch.arccos(dot)
+    sin_theta_0 = torch.sin(theta_0)
+    theta_t = theta_0 * fract_mixing
+    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
+    s1 = torch.sin(theta_t) / sin_theta_0
+    interp = p0 * s0 + p1 * s1
+
+    if recast_to == 'fp16':
+        interp = interp.half()
+    elif recast_to == 'fp32':
+        interp = interp.float()
+
+    return interp
+
+
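A quick sanity check for interpolate_spherical (editor's sketch with made-up tensors): the endpoints are recovered at fract_mixing 0 and 1, and the midpoint keeps a comparable norm instead of shrinking the way a plain average of two noise tensors would.

p0 = torch.randn(4, 64, 64)
p1 = torch.randn(4, 64, 64)
assert torch.allclose(interpolate_spherical(p0, p1, 0.0), p0, atol=1e-5)
assert torch.allclose(interpolate_spherical(p0, p1, 1.0), p1, atol=1e-5)
mid = interpolate_spherical(p0, p1, 0.5)
print(torch.linalg.norm(mid), torch.linalg.norm(p0))  # similar magnitudes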
+def interpolate_linear(p0, p1, fract_mixing):
+    r"""
+    Helper function to mix two variables using standard linear interpolation.
+    Args:
+        p0:
+            First tensor / np.ndarray for interpolation
+        p1:
+            Second tensor / np.ndarray for interpolation
+        fract_mixing: float
+            Mixing coefficient of interval [0, 1].
+            0 will return in p0
+            1 will return in p1
+            0.x will return a linear mix between both.
+    """
+    reconvert_uint8 = False
+    if type(p0) is np.ndarray and p0.dtype == 'uint8':
+        reconvert_uint8 = True
+        p0 = p0.astype(np.float64)
+
+    if type(p1) is np.ndarray and p1.dtype == 'uint8':
+        reconvert_uint8 = True
+        p1 = p1.astype(np.float64)
+
+    interp = (1 - fract_mixing) * p0 + fract_mixing * p1
+
+    if reconvert_uint8:
+        interp = np.clip(interp, 0, 255).astype(np.uint8)
+
+    return interp
+
+
+def add_frames_linear_interp(
+        list_imgs: List[np.ndarray],
+        fps_target: Union[float, int] = None,
+        duration_target: Union[float, int] = None,
+        nmb_frames_target: int = None):
+    r"""
+    Helper function to cheaply increase the number of frames given a list of images,
+    by virtue of standard linear interpolation.
+    The number of inserted frames will be automatically adjusted so that the total number
+    of frames can be fixed precisely, using a random shuffling technique.
+    The function allows 1:1 comparisons between transitions as videos.
+
+    Args:
+        list_imgs: List[np.ndarray]
+            List of images, between each image new frames will be inserted via linear interpolation.
+        fps_target:
+            OptionA: specify here the desired frames per second.
+        duration_target:
+            OptionA: specify here the desired duration of the transition in seconds.
+        nmb_frames_target:
+            OptionB: directly fix the total number of frames of the output.
+    """
+
+    # Sanity
+    if nmb_frames_target is not None and fps_target is not None:
+        raise ValueError("You cannot specify both fps_target and nmb_frames_target")
+    if fps_target is None:
+        assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
+    if nmb_frames_target is None:
+        assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
+        assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
+        nmb_frames_target = fps_target * duration_target
+
+    # Get number of frames that are missing
+    nmb_frames_diff = len(list_imgs) - 1
+    nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
+
+    if nmb_frames_missing < 1:
+        return list_imgs
+
+    list_imgs_float = [img.astype(np.float32) for img in list_imgs]
+    # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
+    mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
+    constfact = np.floor(mean_nmb_frames_insert)
+    remainder_x = 1 - (mean_nmb_frames_insert - constfact)
+    nmb_iter = 0
+    while True:
+        nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
+        nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
+        nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
+        nmb_frames_to_insert += constfact
+        if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
+            break
+        nmb_iter += 1
+        if nmb_iter > 100000:
+            print("add_frames_linear_interp: issue with inserting the right number of frames")
+            break
+
+    nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
+    list_imgs_interp = []
+    for i in range(len(list_imgs_float) - 1):
+        img0 = list_imgs_float[i]
+        img1 = list_imgs_float[i + 1]
+        list_imgs_interp.append(img0.astype(np.uint8))
+        list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
+        for fract_linblend in list_fracts_linblend:
+            img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
+            list_imgs_interp.append(img_blend.astype(np.uint8))
+        if i == len(list_imgs_float) - 2:
+            list_imgs_interp.append(img1.astype(np.uint8))
+
+    return list_imgs_interp
+
+
+
def get_spacing(nmb_points: int, scaling: float):
|
180 |
+
"""
|
181 |
+
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
|
182 |
+
Args:
|
183 |
+
nmb_points: int
|
184 |
+
Number of points between [0, 1]
|
185 |
+
scaling: float
|
186 |
+
Higher values will return higher sampling density around 0.5
|
187 |
+
"""
|
188 |
+
if scaling < 1.7:
|
189 |
+
return np.linspace(0, 1, nmb_points)
|
190 |
+
nmb_points_per_side = nmb_points // 2 + 1
|
191 |
+
if np.mod(nmb_points, 2) != 0: # Uneven case
|
192 |
+
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
|
193 |
+
right_side = 1 - left_side[::-1][1:]
|
194 |
+
else:
|
195 |
+
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
|
196 |
+
right_side = 1 - left_side[::-1]
|
197 |
+
all_fracts = np.hstack([left_side, right_side])
|
198 |
+
return all_fracts
|
199 |
+
|
200 |
+
|
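For illustration (editor's example): below the 1.7 threshold the spacing is uniform, while higher scaling packs the fractions around 0.5.

print(np.round(get_spacing(5, 1.0), 2))  # [0.   0.25 0.5  0.75 1.  ]
print(np.round(get_spacing(5, 3.0), 2))  # [0.   0.44 0.5  0.56 1.  ]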
+def get_time(resolution=None):
+    """
+    Helper function returning a nicely formatted time string, e.g. 221117_1620
+    """
+    if resolution is None:
+        resolution = "second"
+    if resolution == "day":
+        t = time.strftime('%y%m%d', time.localtime())
+    elif resolution == "minute":
+        t = time.strftime('%y%m%d_%H%M', time.localtime())
+    elif resolution == "second":
+        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
+    elif resolution == "millisecond":
+        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
+        t += "_"
+        t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
+    else:
+        raise ValueError("bad resolution provided: %s" % resolution)
+    return t
+
+
+def compare_dicts(a, b):
+    """
+    Compares two dictionaries a and b and returns a dictionary c containing all
+    keys that are shared between a and b but whose values differ.
+    The differing values of a and b are stacked together in the output.
+    Example:
+        a = {}; a['bobo'] = 4
+        b = {}; b['bobo'] = 5
+        c = compare_dicts(a, b)
+        c = {"bobo": [4, 5]}
+    """
+    c = {}
+    for key in a.keys():
+        if key in b.keys():
+            val_a = a[key]
+            val_b = b[key]
+            if val_a != val_b:
+                c[key] = [val_a, val_b]
+    return c
+
+
+def yml_load(fp_yml, print_fields=False):
+    """
+    Helper function for loading yaml files
+    """
+    with open(fp_yml) as f:
+        data = yaml.load(f, Loader=yaml.loader.SafeLoader)
+    dict_data = dict(data)
+    print("load: loaded {}".format(fp_yml))
+    return dict_data
+
+
+def yml_save(fp_yml, dict_stuff):
+    """
+    Helper function for saving yaml files
+    """
+    with open(fp_yml, 'w') as f:
+        yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
+    print("yml_save: saved {}".format(fp_yml))
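A short round-trip sketch for the two yaml helpers (editor's example; the file path is an assumption):

settings = {"prompt": "a lake at sunrise", "num_inference_steps": 30}
yml_save("/tmp/lb_settings.yaml", settings)
assert yml_load("/tmp/lb_settings.yaml") == settings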