lunarring committed
Commit 91c7095 · 1 Parent(s): 75a96dc

Update app.py

Files changed (1):
  1. app.py +23 -401
app.py CHANGED
@@ -1,404 +1,26 @@
-from huggingface_hub import hf_hub_download
-hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
-
-# Copyright 2022 Lunar Ring. All rights reserved.
-# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os, sys
-import torch
-torch.backends.cudnn.benchmark = False
+import random
+import tempfile
+import time
+import gradio as gr
 import numpy as np
-import warnings
-warnings.filterwarnings('ignore')
-import warnings
 import torch
-from tqdm.auto import tqdm
+import math
+import re
+
+from gradio import inputs
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    UNet2DConditionModel,
+)
+from modules.model import (
+    CrossAttnProcessor,
+    StableDiffusionPipeline,
+)
+from torchvision import transforms
+from transformers import CLIPTokenizer, CLIPTextModel
 from PIL import Image
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
-from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
-import gradio as gr
-import copy
-from dotenv import find_dotenv, load_dotenv
-import shutil
-
-
-#%%
-
-class BlendingFrontend():
-    def __init__(self, sdh=None):
-        self.num_inference_steps = 30
-        if sdh is None:
-            self.use_debug = True
-            self.height = 768
-            self.width = 768
-        else:
-            self.use_debug = False
-            self.lb = LatentBlending(sdh)
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
-
-        self.init_save_dir()
-        self.save_empty_image()
-        self.share = True
-        self.transition_can_be_computed = False
-        self.depth_strength = 0.25
-        self.seed1 = 420
-        self.seed2 = 420
-        self.guidance_scale = 4.0
-        self.guidance_scale_mid_damper = 0.5
-        self.mid_compression_scaler = 1.2
-        self.prompt1 = ""
-        self.prompt2 = ""
-        self.negative_prompt = ""
-        self.state_current = {}
-        self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
-        self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
-        self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
-        self.parental_crossfeed_power = self.lb.parental_crossfeed_power
-        self.parental_crossfeed_range = self.lb.parental_crossfeed_range
-        self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
-        self.fps = 30
-        self.duration_video = 10
-        self.t_compute_max_allowed = 10
-        self.list_fp_imgs_current = []
-        self.current_timestamp = None
-        self.recycle_img1 = False
-        self.recycle_img2 = False
-        self.fp_img1 = None
-        self.fp_img2 = None
-        self.multi_idx_current = -1
-        self.list_imgs_shown_last = 5*[self.fp_img_empty]
-        self.list_all_segments = []
-        self.dp_session = ""
-
-
-    def init_save_dir(self):
-        load_dotenv(find_dotenv(), verbose=False)
-        self.dp_out = os.getenv("DIR_OUT")
-        if self.dp_out is None:
-            self.dp_out = ""
-        self.dp_imgs = os.path.join(self.dp_out, "imgs")
-        os.makedirs(self.dp_imgs, exist_ok=True)
-        self.dp_movies = os.path.join(self.dp_out, "movies")
-        os.makedirs(self.dp_movies, exist_ok=True)
-
-
-    # make dummy image
-    def save_empty_image(self):
-        self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
-        Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
-
-
-    def randomize_seed1(self):
-        # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
-        if len(self.list_all_segments) > 0:
-            seed = self.seed1
-        else:
-            seed = np.random.randint(0, 10000000)
-        self.seed1 = int(seed)
-        print(f"randomize_seed1: new seed = {self.seed1}")
-        return seed
-
-    def randomize_seed2(self):
-        seed = np.random.randint(0, 10000000)
-        self.seed2 = int(seed)
-        print(f"randomize_seed2: new seed = {self.seed2}")
-        return seed
-
-
-    def setup_lb(self, list_ui_elem):
-        # Collect latent blending variables
-        self.state_current = self.get_state_dict()
-        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
-        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
-        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
-        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
-        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
-        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
-        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
-        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
-        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
-        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
-
-        self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
-        self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
-        self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
-        self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
-        self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
-        self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
-        self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
-
-
-    def compute_img1(self, *args):
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
-        self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg")
-        img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
-        img1.save(self.fp_img1)
-        self.recycle_img1 = True
-        self.recycle_img2 = False
-        return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
-
-    def compute_img2(self, *args):
-        if self.fp_img1 is None: # don't do anything
-            return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
-        self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg")
-        img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
-        img2.save(self.fp_img2)
-        self.recycle_img2 = True
-        self.transition_can_be_computed = True
-        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2]
-
-    def compute_transition(self, *args):
-
-        if not self.transition_can_be_computed:
-            list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
-            return list_return
-
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
-        print("STARTING TRANSITION...")
-        if self.use_debug:
-            list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
-            list_imgs = [Image.fromarray(l) for l in list_imgs]
-            print("DONE! SENDING BACK RESULTS")
-            return list_imgs
-
-        fixed_seeds = [self.seed1, self.seed2]
-
-        # Run Latent Blending
-        imgs_transition = self.lb.run_transition(
-            recycle_img1=self.recycle_img1,
-            recycle_img2=self.recycle_img2,
-            num_inference_steps=self.num_inference_steps,
-            depth_strength=self.depth_strength,
-            t_compute_max_allowed=self.t_compute_max_allowed,
-            fixed_seeds=fixed_seeds
-            )
-        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
-
-        # Subselect three preview images
-        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
-        list_imgs_preview = []
-        for j in idx_img_prev:
-            list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
-
-        # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
-        self.current_timestamp = get_time('second')
-        self.list_fp_imgs_current = []
-        for i in range(len(list_imgs_preview)):
-            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
-            list_imgs_preview[i].save(fp_img)
-            self.list_fp_imgs_current.append(fp_img)
-
-        # Insert cheap frames for the movie
-        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
-
-        # Save as movie
-        self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4")
-        if os.path.isfile(self.fp_movie):
-            os.remove(self.fp_movie)
-        ms = MovieSaver(self.fp_movie, fps=self.fps)
-        for img in tqdm(imgs_transition_ext):
-            ms.write_frame(img)
-        ms.finalize()
-        print("DONE SAVING MOVIE! SENDING BACK...")
-
-        # Assemble Output, updating the preview images and le movie
-        list_return = self.list_fp_imgs_current + [self.fp_movie]
-        return list_return
-
-
-    def stack_forward(self, prompt2, seed2):
-        # Save preview images, prompts and seeds into dictionary for stacking
-        if len(self.list_all_segments) == 0:
-            timestamp_session = get_time('second')
-            self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
-            os.makedirs(self.dp_session)
-
-        self.transition_can_be_computed = False
-
-        idx_segment = len(self.list_all_segments)
-        dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
-
-        self.list_all_segments.append(dp_segment)
-        self.lb.write_imgs_transition(dp_segment)
-        shutil.copyfile(self.fp_movie, os.path.join(dp_segment, "movie.mp4"))
-
-        self.lb.swap_forward()
-        fp_multi = self.multi_concat()
-        list_out = [fp_multi]
-        list_out.extend([self.fp_img2])
-        list_out.extend([self.fp_img_empty]*4)
-        list_out.append(gr.update(interactive=False, value=prompt2))
-        list_out.append(gr.update(interactive=False, value=seed2))
-        list_out.append("")
-        list_out.append(np.random.randint(0, 10000000))
-        print(f"stack_forward: fp_multi {fp_multi}")
-        return list_out
-
-
-    def multi_concat(self):
-        list_fp_movies = []
-        for dp_segment in self.list_all_segments:
-            list_fp_movies.append(os.path.join(dp_segment, "movie.mp4"))
-
-        # Concatenate movies and save
-        fp_final = os.path.join(self.dp_session, "movie.mp4")
-        concatenate_movies(fp_final, list_fp_movies)
-        return fp_final
-
-    def get_state_dict(self):
-        state_dict = {}
-        grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
-                     'num_inference_steps', 'depth_strength', 'guidance_scale',
-                     'guidance_scale_mid_damper', 'mid_compression_scaler']
-
-        for v in grab_vars:
-            state_dict[v] = getattr(self, v)
-        return state_dict
-
-
-
-if __name__ == "__main__":
-
-    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    fp_ckpt = "v2-1_512-ema-pruned.ckpt"
-    bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
-    # self = BlendingFrontend(None)
-
-    with gr.Blocks() as demo:
-        with gr.Tab("Single Transition"):
-            with gr.Row():
-                prompt1 = gr.Textbox(label="prompt 1")
-                prompt2 = gr.Textbox(label="prompt 2")
-
-            with gr.Row():
-                duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
-                duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
-                height = gr.Slider(256, 2048, bf.height, step=128, label='height', interactive=True)
-                width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)
-
-            with gr.Accordion("Advanced Settings (click to expand)", open=False):
-
-                with gr.Accordion("Diffusion settings", open=True):
-                    with gr.Row():
-                        num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
-                        guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
-                        negative_prompt = gr.Textbox(label="negative prompt")
-
-                with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
-                    with gr.Row():
-                        b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
-                        seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
-                        seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
-                        b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
-
-                with gr.Accordion("Last image crossfeeding.", open=True):
-                    with gr.Row():
-                        branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
-                        branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
-                        branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
-
-                with gr.Accordion("Transition settings", open=True):
-                    with gr.Row():
-                        parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
-                        parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
-                        parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
-                    with gr.Row():
-                        depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
-                        guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
-
-
-            with gr.Row():
-                b_compute1 = gr.Button('compute first image', variant='primary')
-                b_compute_transition = gr.Button('compute transition', variant='primary')
-                b_compute2 = gr.Button('compute last image', variant='primary')
-
-            with gr.Row():
-                img1 = gr.Image(label="1/5")
-                img2 = gr.Image(label="2/5", show_progress=False)
-                img3 = gr.Image(label="3/5", show_progress=False)
-                img4 = gr.Image(label="4/5", show_progress=False)
-                img5 = gr.Image(label="5/5")
-
-            with gr.Row():
-                vid_single = gr.Video(label="single trans")
-                vid_multi = gr.Video(label="multi trans")
-
-            with gr.Row():
-                # b_restart = gr.Button("RESTART EVERYTHING")
-                b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
-
-
-        # Collect all UI elemts in list to easily pass as inputs in gradio
-        dict_ui_elem = {}
-        dict_ui_elem["prompt1"] = prompt1
-        dict_ui_elem["negative_prompt"] = negative_prompt
-        dict_ui_elem["prompt2"] = prompt2
-
-        dict_ui_elem["duration_compute"] = duration_compute
-        dict_ui_elem["duration_video"] = duration_video
-        dict_ui_elem["height"] = height
-        dict_ui_elem["width"] = width
-
-        dict_ui_elem["depth_strength"] = depth_strength
-        dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
-        dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
-        dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay
-
-        dict_ui_elem["num_inference_steps"] = num_inference_steps
-        dict_ui_elem["guidance_scale"] = guidance_scale
-        dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
-        dict_ui_elem["seed1"] = seed1
-        dict_ui_elem["seed2"] = seed2
-
-        dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
-        dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
-        dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
-
-        # Convert to list, as gradio doesn't seem to accept dicts
-        list_ui_elem = []
-        list_ui_keys = []
-        for k in dict_ui_elem.keys():
-            list_ui_elem.append(dict_ui_elem[k])
-            list_ui_keys.append(k)
-        bf.list_ui_keys = list_ui_keys
-
-        b_newseed1.click(bf.randomize_seed1, outputs=seed1)
-        b_newseed2.click(bf.randomize_seed2, outputs=seed2)
-        b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
-        b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
-        b_compute_transition.click(bf.compute_transition,
-                                   inputs=list_ui_elem,
-                                   outputs=[img2, img3, img4, vid_single])
-
-        b_stackforward.click(bf.stack_forward,
-                             inputs=[prompt2, seed2],
-                             outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
-
-
-    demo.launch(share=bf.share, inbrowser=True, inline=False)
+from pathlib import Path
+from safetensors.torch import load_file
+import modules.safe as _
+from modules.lora import LoRANetwork
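For orientation, the replacement app.py no longer drives the LatentBlending/StableDiffusionHolder stack; its new header pulls in individual diffusers components (AutoencoderKL, DDIMScheduler, UNet2DConditionModel), a custom StableDiffusionPipeline and CrossAttnProcessor from modules.model, and LoRA support from modules.lora. The sketch below shows one plausible way these imports fit together; the base model id, the custom pipeline constructor, and the LoRA file name are assumptions, since the rest of the new app.py is not visible in this diff.

import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file

# Assumed base checkpoint; the diff does not show which repo the new app.py actually loads.
model_id = "stabilityai/stable-diffusion-2-1-base"

# Load the components referenced by the new import block, using standard
# diffusers/transformers APIs (each subfolder exists in a diffusers-format repo).
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# modules.model.StableDiffusionPipeline and modules.lora.LoRANetwork are not part of this
# diff; the commented calls below are guesses modeled on diffusers' own StableDiffusionPipeline
# and a typical LoRA wrapper, so treat them as placeholders rather than the app's real API.
# pipe = StableDiffusionPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
#                                unet=unet, scheduler=scheduler)
# lora_state = load_file("example_lora.safetensors")  # hypothetical file name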