ironjr committed on
Commit
b5de3c9
1 Parent(s): 2042463

added files without images

.gitignore ADDED
@@ -0,0 +1 @@
+ */.ipynb_checkpoints
README.md CHANGED
@@ -1,13 +1,15 @@
 ---
 title: StreamMultiDiffusion
- emoji: 🌍
- colorFrom: blue
- colorTo: purple
 sdk: gradio
- sdk_version: 4.21.0
 app_file: app.py
- pinned: false
 license: mit
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

 ---
 title: StreamMultiDiffusion
+ emoji: 🦦🦦🦦🦦
+ colorFrom: #feecd6
+ colorTo: #732a14
 sdk: gradio
+ sdk_version: 4.26.0
 app_file: app.py
+ pinned: true
 license: mit
+ models:
+ - KBlueLeaf/kohaku-v2.1
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1179 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import sys
22
+
23
+ sys.path.append('../../src')
24
+
25
+ import argparse
26
+ import random
27
+ import time
28
+ import json
29
+ import os
30
+ import glob
31
+ import pathlib
32
+ from functools import partial
33
+ from pprint import pprint
34
+
35
+ import numpy as np
36
+ from PIL import Image
37
+ import torch
38
+
39
+ import gradio as gr
40
+ from huggingface_hub import snapshot_download
41
+
42
+ # from model import StreamMultiDiffusionSDXL
43
+ from model import StreamMultiDiffusion
44
+ from util import seed_everything
45
+ from prompt_util import preprocess_prompts, _quality_dict, _style_dict
46
+
47
+
48
+ ### Utils
49
+
50
+
51
+
52
+
53
+ def log_state(state):
54
+ pprint(vars(opt))
55
+ if isinstance(state, gr.State):
56
+ state = state.value
57
+ pprint(vars(state))
58
+
59
+
60
+ def is_empty_image(im: Image.Image) -> bool:
61
+ if im is None:
62
+ return True
63
+ im = np.array(im)
64
+ has_alpha = (im.shape[2] == 4)
65
+ if not has_alpha:
66
+ return False
67
+ elif im.sum() == 0:
68
+ return True
69
+ else:
70
+ return False
71
+
72
+
73
+ ### Argument passing
74
+
75
+ # parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SDXL support.')
76
+ # parser.add_argument('-H', '--height', type=int, default=1024)
77
+ # parser.add_argument('-W', '--width', type=int, default=1024)
78
+ parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion.')
79
+ parser.add_argument('-H', '--height', type=int, default=768)
80
+ parser.add_argument('-W', '--width', type=int, default=768)
81
+ parser.add_argument('--model', type=str, default=None, help='Hugging Face model repository or local path of an SD1.5 checkpoint to run.')
82
+ parser.add_argument('--bootstrap_steps', type=int, default=1)
83
+ parser.add_argument('--guidance_scale', type=float, default=0) # 1.2
84
+ parser.add_argument('--run_time', type=float, default=60)
85
+ parser.add_argument('--seed', type=int, default=-1)
86
+ parser.add_argument('--device', type=int, default=0)
87
+ parser.add_argument('--port', type=int, default=8000)
88
+ opt = parser.parse_args()
89
+
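+ # Illustrative invocations (a sketch; the checkpoint file name is a made-up example).
+ # Flag names match the argparse definitions above; a local *.safetensors path is
+ # resolved relative to ./checkpoints below.
+ #   python app.py                               # 768x768 canvas, default model
+ #   python app.py --model mymodel.safetensors   # local SD1.5 checkpoint in ./checkpoints
+ #   python app.py -H 512 -W 1024 --port 7860    # custom canvas size and port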
90
+
91
+ ### Global variables and data structures
92
+
93
+ device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
94
+
95
+
96
+ if opt.model is None:
97
+ # opt.model = 'cagliostrolab/animagine-xl-3.1'
98
+ # opt.model = 'ironjr/BlazingDriveV11m'
99
+ opt.model = 'KBlueLeaf/kohaku-v2.1'
100
+ else:
101
+ if opt.model.endswith('.safetensors'):
102
+ opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
103
+
104
+ # model = StreamMultiDiffusionSDXL(
105
+ model = StreamMultiDiffusion(
106
+ device,
107
+ hf_key=opt.model,
108
+ height=opt.height,
109
+ width=opt.width,
110
+ cfg_type="full",
111
+ autoflush=True,
112
+ use_tiny_vae=True,
113
+ mask_type='continuous',
114
+ bootstrap_steps=opt.bootstrap_steps,
115
+ bootstrap_mix_steps=opt.bootstrap_steps,
116
+ guidance_scale=opt.guidance_scale,
117
+ seed=opt.seed,
118
+ )
119
+
120
+
121
+ prompt_suggestions = [
122
+ '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
123
+ '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
124
+ '1girl, arima kana, oshi no ko, solo, upper body, from behind',
125
+ ]
126
+
127
+ opt.max_palettes = 3
128
+ opt.default_prompt_strength = 1.0
129
+ opt.default_mask_strength = 1.0
130
+ opt.default_mask_std = 0.0
131
+ opt.default_negative_prompt = (
132
+ 'nsfw, worst quality, bad quality, normal quality, cropped, framed'
133
+ )
134
+ opt.verbose = True
135
+ opt.colors = [
136
+ '#000000',
137
+ '#2692F3',
138
+ '#F89E12',
139
+ '#16C232',
140
+ # '#F92F6C',
141
+ # '#AC6AEB',
142
+ # '#92C62C',
143
+ # '#92C6EC',
144
+ # '#FECAC0',
145
+ ]
146
+
147
+
148
+ ### Event handlers
149
+
150
+ def add_palette(state):
151
+ old_actives = state.active_palettes
152
+ state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
153
+
154
+ if opt.verbose:
155
+ log_state(state)
156
+
157
+ if state.active_palettes != old_actives:
158
+ return [state] + [
159
+ gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
160
+ ] + [
161
+ gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
162
+ for i in range(opt.max_palettes)
163
+ ]
164
+ else:
165
+ return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
166
+
167
+
168
+ def select_palette(state, button, idx):
169
+ if idx < 0 or idx > opt.max_palettes:
170
+ idx = 0
171
+ old_idx = state.current_palette
172
+ if old_idx == idx:
173
+ return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
174
+
175
+ state.current_palette = idx
176
+
177
+ if opt.verbose:
178
+ log_state(state)
179
+
180
+ updates = [state] + [
181
+ gr.update() if i not in (idx, old_idx) else
182
+ gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
183
+ for i in range(opt.max_palettes + 1)
184
+ ]
185
+ label = 'Background' if idx == 0 else f'Palette {idx}'
186
+ updates.extend([
187
+ gr.update(value=button, interactive=(idx > 0)),
188
+ gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
189
+ gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
190
+ (
191
+ gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
192
+ gr.update(value=opt.default_mask_strength, interactive=False)
193
+ ),
194
+ (
195
+ gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
196
+ gr.update(value=opt.default_prompt_strength, interactive=False)
197
+ ),
198
+ (
199
+ gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
200
+ gr.update(value=opt.default_mask_std, interactive=False)
201
+ ),
202
+ ])
203
+ return updates
204
+
205
+
206
+ def change_prompt_strength(state, strength):
207
+ if state.current_palette == 0:
208
+ return state
209
+
210
+ state.prompt_strengths[state.current_palette - 1] = strength
211
+ if opt.verbose:
212
+ log_state(state)
213
+
214
+ return state
215
+
216
+
217
+ def change_std(state, std):
218
+ if state.current_palette == 0:
219
+ return state
220
+
221
+ state.mask_stds[state.current_palette - 1] = std
222
+ if opt.verbose:
223
+ log_state(state)
224
+
225
+ return state
226
+
227
+
228
+ def change_mask_strength(state, strength):
229
+ if state.current_palette == 0:
230
+ return state
231
+
232
+ state.mask_strengths[state.current_palette - 1] = strength
233
+ if opt.verbose:
234
+ log_state(state)
235
+
236
+ return state
237
+
238
+
239
+ def reset_seed(state, seed):
240
+ state.seed = seed
241
+ if opt.verbose:
242
+ log_state(state)
243
+
244
+ return state
245
+
246
+
247
+ def rename_prompt(state, name):
248
+ state.prompt_names[state.current_palette] = name
249
+ if opt.verbose:
250
+ log_state(state)
251
+
252
+ return [state] + [
253
+ gr.update() if i != state.current_palette else gr.update(value=name)
254
+ for i in range(opt.max_palettes + 1)
255
+ ]
256
+
257
+
258
+ def change_prompt(state, prompt):
259
+ state.prompts[state.current_palette] = prompt
260
+ if opt.verbose:
261
+ log_state(state)
262
+
263
+ return state
264
+
265
+
266
+ def change_neg_prompt(state, neg_prompt):
267
+ state.neg_prompts[state.current_palette] = neg_prompt
268
+ if opt.verbose:
269
+ log_state(state)
270
+
271
+ return state
272
+
273
+
274
+ # def select_style(state, style_name):
275
+ # state.style_name = style_name
276
+ # if opt.verbose:
277
+ # log_state(state)
278
+
279
+ # return state
280
+
281
+
282
+ # def select_quality(state, quality_name):
283
+ # state.quality_name = quality_name
284
+ # if opt.verbose:
285
+ # log_state(state)
286
+
287
+ # return state
288
+
289
+
290
+ def import_state(state, json_text):
291
+ current_palette = state.current_palette
292
+ # active_palettes = state.active_palettes
293
+ state_dict = json.loads(json_text)
294
+ for k in ('inpainting_mode', 'is_running', 'active_palettes', 'current_palette'):
295
+ if k in state_dict:
296
+ del state_dict[k]
297
+ state = argparse.Namespace(**state_dict)
298
+ state.active_palettes = opt.max_palettes
299
+ return [state] + [
300
+ gr.update(value=v, visible=True) for v in state.prompt_names
301
+ ] + [
302
+ # state.style_name,
303
+ # state.quality_name,
304
+ state.prompts[current_palette],
305
+ state.prompt_names[current_palette],
306
+ state.neg_prompts[current_palette],
307
+ state.prompt_strengths[current_palette - 1],
308
+ state.mask_strengths[current_palette - 1],
309
+ state.mask_stds[current_palette - 1],
310
+ state.seed,
311
+ ]
312
+
313
+
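+ # A minimal sketch of the palette JSON accepted by `import_state` above (and produced
+ # by the export button defined further below). Field names come from `_define_state`;
+ # the values here are made-up examples. Runtime-only keys ("is_running",
+ # "inpainting_mode", "active_palettes", "current_palette") are stripped before the
+ # state is rebuilt.
+ # {
+ #   "prompt_names": ["🌄 Background", "👧 Girl", "🐶 Dog", "💐 Garden"],
+ #   "prompts": ["", "A girl smiling at viewer", "Doggy body part", "Flower garden"],
+ #   "neg_prompts": ["nsfw, worst quality, ...", "...", "...", "..."],
+ #   "prompt_strengths": [1.0, 1.0, 1.0],
+ #   "mask_strengths": [1.0, 1.0, 1.0],
+ #   "mask_stds": [0.0, 0.0, 0.0],
+ #   "seed": -1
+ # }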
314
+ ### Main worker
315
+
316
+ def generate():
317
+ return model()
318
+
319
+
320
+ def register(state, drawpad):
321
+ seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
322
+ print('Generate!')
323
+
324
+ background = drawpad['background'].convert('RGBA')
325
+ inpainting_mode = np.asarray(background).sum() != 0
326
+ if not inpainting_mode:
327
+ background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
328
+ print('Inpainting mode: ', inpainting_mode)
329
+
330
+ user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
331
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
332
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
333
+
334
+ palette = torch.tensor([
335
+ tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
336
+ for s in opt.colors[1:]
337
+ ]) # (N, 3)
338
+ masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
339
+ # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
340
+ has_masks = list(range(opt.max_palettes))
341
+ print('Has mask: ', has_masks)
342
+ masks = masks * foreground_mask
343
+ masks = masks[has_masks]
344
+
345
+ # if inpainting_mode:
346
+ prompts = [state.prompts[v + 1] for v in has_masks]
347
+ negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
348
+ mask_strengths = [state.mask_strengths[v] for v in has_masks]
349
+ mask_stds = [state.mask_stds[v] for v in has_masks]
350
+ prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
351
+ # else:
352
+ # masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
353
+ # prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
354
+ # negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
355
+ # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
356
+ # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
357
+ # prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
358
+
359
+ # prompts, negative_prompts = preprocess_prompts(
360
+ # prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
361
+
362
+ model.update_background(
363
+ background.convert('RGB'),
364
+ prompt=None,
365
+ negative_prompt=None,
366
+ )
367
+ state.prompts[0] = model.background.prompt
368
+ state.neg_prompts[0] = model.background.negative_prompt
369
+
370
+ model.update_layers(
371
+ prompts=prompts,
372
+ negative_prompts=negative_prompts,
373
+ masks=masks.to(device),
374
+ mask_strengths=mask_strengths,
375
+ mask_stds=mask_stds,
376
+ prompt_strengths=prompt_strengths,
377
+ )
378
+
379
+ state.inpainting_mode = inpainting_mode
380
+ return state
381
+
382
+
383
+ def run(state, drawpad):
384
+ state = register(state, drawpad)
385
+ state.is_running = True
386
+
387
+ tic = time.time()
388
+ while True:
389
+ yield [state, generate()]
390
+ toc = time.time()
391
+ tdelta = toc - tic
392
+ if tdelta > opt.run_time:
393
+ state.is_running = False
394
+ return [state, generate()]
395
+
396
+
397
+ def hide_element():
398
+ return gr.update(visible=False)
399
+
400
+
401
+ def show_element():
402
+ return gr.update(visible=True)
403
+
404
+
405
+ def draw(state, drawpad):
406
+ if not state.is_running:
407
+ return
408
+
409
+ user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
410
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
411
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
412
+
413
+ palette = torch.tensor([
414
+ tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
415
+ for s in opt.colors[1:]
416
+ ]) # (N, 3)
417
+ masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
418
+ # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
419
+ has_masks = list(range(opt.max_palettes))
420
+ print('Has mask: ', has_masks)
421
+ masks = masks * foreground_mask
422
+ masks = masks[has_masks]
423
+
424
+ # if state.inpainting_mode:
425
+ mask_strengths = [state.mask_strengths[v] for v in has_masks]
426
+ mask_stds = [state.mask_stds[v] for v in has_masks]
427
+ # else:
428
+ # masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
429
+ # mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
430
+ # mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
431
+
432
+ for i in range(len(has_masks)):
433
+ model.update_single_layer(
434
+ idx=i,
435
+ mask=masks[i],
436
+ mask_strength=mask_strengths[i],
437
+ mask_std=mask_stds[i],
438
+ )
439
+
440
+ ### Load examples
441
+
442
+
443
+ root = pathlib.Path(__file__).parent
444
+ print(root)
445
+ example_root = os.path.join(root, 'examples')
446
+ example_images = glob.glob(os.path.join(example_root, '*.png'))
447
+ example_images = [Image.open(i) for i in example_images]
448
+
449
+ # with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
450
+ # prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
451
+
452
+ # with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
453
+ # prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
454
+
455
+ # with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
456
+ # prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
457
+
458
+ # with open(os.path.join(example_root, 'prompt_props.txt')) as f:
459
+ # prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
460
+ # prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
461
+
462
+ # prompt_background = lambda: random.choice(prompts_background)
463
+ # prompt_girl = lambda: random.choice(prompts_girl)
464
+ # prompt_boy = lambda: random.choice(prompts_boy)
465
+ # prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
466
+
467
+
468
+ ### Main application
469
+
470
+ css = f"""
471
+ #run-button {{
472
+ font-size: 18pt;
473
+ background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
474
+ margin: 0;
475
+ padding: 15px 45px;
476
+ text-align: center;
477
+ // text-transform: uppercase;
478
+ transition: 0.5s;
479
+ background-size: 200% auto;
480
+ color: white;
481
+ box-shadow: 0 0 20px #eee;
482
+ border-radius: 10px;
483
+ // display: block;
484
+ background-position: right center;
485
+ }}
486
+
487
+ #run-button:hover {{
488
+ background-position: left center;
489
+ color: #fff;
490
+ text-decoration: none;
491
+ }}
492
+
493
+ #run-anim {{
494
+ padding: 40px 45px;
495
+ }}
496
+
497
+ #semantic-palette {{
498
+ border-style: solid;
499
+ border-width: 0.2em;
500
+ border-color: #eee;
501
+ }}
502
+
503
+ #semantic-palette:hover {{
504
+ box-shadow: 0 0 20px #eee;
505
+ }}
506
+
507
+ #output-screen {{
508
+ width: 100%;
509
+ aspect-ratio: {opt.width} / {opt.height};
510
+ }}
511
+
512
+ .layer-wrap {{
513
+ display: none;
514
+ }}
515
+ """
516
+
517
+ for i in range(opt.max_palettes + 1):
518
+ css = css + f"""
519
+ .secondary#semantic-palette-{i} {{
520
+ background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
521
+ color: white;
522
+ }}
523
+
524
+ .primary#semantic-palette-{i} {{
525
+ background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
526
+ color: white;
527
+ }}
528
+ """
529
+
530
+ css = css + f"""
531
+
532
+ .mask-red {{
533
+ left: 0;
534
+ width: 0;
535
+ color: #BE002A;
536
+ -webkit-animation: text-red {opt.run_time:.1f}s ease infinite;
537
+ animation: text-red {opt.run_time:.1f}s ease infinite;
538
+ z-index: 2;
539
+ background: transparent;
540
+ }}
541
+ .mask-white {{
542
+ right: 0;
543
+ }}
544
+
545
+ /* Flames */
546
+
547
+ #red-flame {{
548
+ opacity: 0;
549
+ -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 120ms ease infinite;
550
+ animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 120ms ease infinite;
551
+ transform-origin: center bottom;
552
+ }}
553
+
554
+ #yellow-flame {{
555
+ opacity: 0;
556
+ -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, yellow-flame 120ms ease infinite;
557
+ animation: show-flames {opt.run_time:.1f}s ease infinite, yellow-flame 120ms ease infinite;
558
+ transform-origin: center bottom;
559
+ }}
560
+
561
+ #white-flame {{
562
+ opacity: 0;
563
+ -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 100ms ease infinite;
564
+ animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 100ms ease infinite;
565
+ transform-origin: center bottom;
566
+ }}
567
+ """
568
+
569
+ with open(os.path.join(root, 'timer', 'style.css')) as f:
570
+ added_css = ''.join(f.readlines())
571
+ css = css + added_css
572
+
573
+ # js = ''
574
+
575
+ # with open(os.path.join(root, 'timer', 'script.js')) as f:
576
+ # added_js = ''.join(f.readlines())
577
+ # js = js + added_js
578
+
579
+ head = f"""
580
+ <link href='https://fonts.googleapis.com/css?family=Oswald' rel='stylesheet' type='text/css'>
581
+ <script src='https://code.jquery.com/jquery-2.2.4.min.js'></script>
582
+ """
583
+
584
+
585
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:
586
+
587
+ iface = argparse.Namespace()
588
+
589
+ def _define_state():
590
+ state = argparse.Namespace()
591
+
592
+ # Cursor.
593
+ state.is_running = False
594
+ state.inpainting_mode = False
595
+ state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
596
+ state.model_id = opt.model
597
+ state.style_name = '(None)'
598
+ state.quality_name = 'Standard v3.1'
599
+
600
+ # State variables (one-hot).
601
+ state.active_palettes = 5
602
+
603
+ # Front-end initialized to the default values.
604
+ # prompt_props_ = prompt_props()
605
+ state.prompt_names = [
606
+ '🌄 Background',
607
+ '👧 Girl',
608
+ '🐶 Dog',
609
+ '💐 Garden',
610
+ ] + ['🎨 New Palette' for _ in range(opt.max_palettes - 3)]
611
+ state.prompts = [
612
+ '',
613
+ 'A girl smiling at viewer',
614
+ 'Doggy body part',
615
+ 'Flower garden',
616
+ ] + ['' for _ in range(opt.max_palettes - 3)]
617
+ state.neg_prompts = [
618
+ opt.default_negative_prompt
619
+ + (', humans, humans, humans' if i == 0 else '')
620
+ for i in range(opt.max_palettes + 1)
621
+ ]
622
+ state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
623
+ state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
624
+ state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
625
+ state.seed = opt.seed
626
+ return state
627
+
628
+ state = gr.State(value=_define_state)
629
+
630
+
631
+ ### Demo user interface
632
+
633
+ gr.HTML(
634
+ """
635
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
636
+ <div>
637
+ <h1>🦦🦦 StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control 🦦🦦</h1>
638
+ <h5 style="margin: 0;">If you ❤️ our project, please visit our Github and give us a 🌟!</h5>
639
+ </br>
640
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
641
+ <a href='https://arxiv.org/abs/2403.09055'>
642
+ <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
643
+ </a>
644
+ &nbsp;
645
+ <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
646
+ <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
647
+ </a>
648
+ &nbsp;
649
+ <a href='https://github.com/ironjr/StreamMultiDiffusion'>
650
+ <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
651
+ </a>
652
+ &nbsp;
653
+ <a href='https://twitter.com/_ironjr_'>
654
+ <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
655
+ </a>
656
+ &nbsp;
657
+ <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
658
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
659
+ </a>
660
+ &nbsp;
661
+ <a href='https://huggingface.co/papers/2403.09055'>
662
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Paper-StreamMultiDiffusion-yellow'>
663
+ </a>
664
+ &nbsp;
665
+ <a href='https://huggingface.co/spaces/ironjr/StreamMultiDiffusion'>
666
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-StreamMultiDiffusion-yellow'>
667
+ </a>
668
+ &nbsp;
669
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
670
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SemanticPaletteSD1.5-yellow'>
671
+ </a>
672
+ &nbsp;
673
+ <a href='https://huggingface.co/spaces/ironjr/SemanticPaletteXL'>
674
+ <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SemanticPaletteSDXL-yellow'>
675
+ </a>
676
+ &nbsp;
677
+ <a href='https://colab.research.google.com/github/camenduru/SemanticPalette-jupyter/blob/main/SemanticPalette_jupyter.ipynb'>
678
+ <img src='https://colab.research.google.com/assets/colab-badge.svg'>
679
+ </a>
680
+ </div>
681
+ </div>
682
+ </div>
683
+ <div>
684
+ </br>
685
+ </div>
686
+ """
687
+ )
688
+
689
+ with gr.Row():
690
+
691
+ with gr.Column(scale=1):
692
+
693
+ with gr.Group(elem_id='semantic-palette'):
694
+
695
+ gr.HTML(
696
+ """
697
+ <div style="justify-content: center; align-items: center;">
698
+ <br/>
699
+ <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
700
+ <br/>
701
+ </div>
702
+ """
703
+ )
704
+
705
+ iface.btn_semantics = [gr.Button(
706
+ value=state.value.prompt_names[0],
707
+ variant='primary',
708
+ elem_id='semantic-palette-0',
709
+ )]
710
+ for i in range(opt.max_palettes):
711
+ iface.btn_semantics.append(gr.Button(
712
+ value=state.value.prompt_names[i + 1],
713
+ variant='secondary',
714
+ visible=(i < state.value.active_palettes),
715
+ elem_id=f'semantic-palette-{i + 1}'
716
+ ))
717
+
718
+ iface.btn_add_palette = gr.Button(
719
+ value='Create New Semantic Brush',
720
+ variant='primary',
721
+ visible=(state.value.active_palettes < opt.max_palettes),
722
+ )
723
+
724
+ with gr.Accordion(label='Import/Export Semantic Palette', open=True):
725
+ iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
726
+ iface.json_state_export = gr.JSON(label='Exported Palette')
727
+ iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
728
+ iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')
729
+
730
+ gr.HTML(
731
+ """
732
+ <div>
733
+ </br>
734
+ </div>
735
+ <div style="justify-content: center; align-items: center;">
736
+ <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
737
+ </br>
738
+ <div style="justify-content: center; align-items: left; text-align: left;">
739
+ <p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
740
+ <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image switches the app into inpainting mode; removing the image returns it to creation mode. In inpainting mode, increasing the <em>Mask Blur STD</em> above 8 for every colored palette is recommended for smooth boundaries.</p>
+ <p>2. Select a semantic brush by clicking one in the <b>Semantic Palette</b> above, then edit its prompt.</p>
+ <p>2-1. If you want to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
743
+ <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
744
+ <p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
745
+ </div>
746
+ </div>
747
+ """
748
+ )
749
+
750
+ gr.HTML(
751
+ """
752
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
753
+ <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
754
+ </div>
755
+ """
756
+ )
757
+
758
+ gr.DuplicateButton()
759
+
760
+ with gr.Column(scale=4):
761
+
762
+ with gr.Row():
763
+
764
+ with gr.Column(scale=2):
765
+
766
+ iface.ctrl_semantic = gr.ImageEditor(
767
+ image_mode='RGBA',
768
+ sources=['upload', 'clipboard', 'webcam'],
769
+ transforms=['crop'],
770
+ crop_size=(opt.width, opt.height),
771
+ brush=gr.Brush(
772
+ colors=opt.colors[1:],
773
+ color_mode="fixed",
774
+ ),
775
+ type='pil',
776
+ label='Semantic Drawpad',
777
+ elem_id='drawpad',
778
+ )
779
+
780
+ # with gr.Accordion(label='Prompt Engineering', open=False):
781
+ # iface.quality_select = gr.Dropdown(
782
+ # label='Quality Presets',
783
+ # interactive=True,
784
+ # choices=list(_quality_dict.keys()),
785
+ # value='Standard v3.1',
786
+ # )
787
+
788
+ # iface.style_select = gr.Radio(
789
+ # label='Style Preset',
790
+ # container=True,
791
+ # interactive=True,
792
+ # choices=list(_style_dict.keys()),
793
+ # value='(None)',
794
+ # )
795
+
796
+ with gr.Column(scale=2):
797
+
798
+ iface.image_slot = gr.Image(
799
+ interactive=False,
800
+ show_label=False,
801
+ show_download_button=True,
802
+ type='pil',
803
+ label='Generated Result',
804
+ elem_id='output-screen',
805
+ value=lambda: random.choice(example_images),
806
+ )
807
+
808
+ iface.btn_generate = gr.Button(
809
+ value=f'Lemme try! ({int(opt.run_time // 60)} min)',
810
+ variant='primary',
811
+ # scale=1,
812
+ elem_id='run-button'
813
+ )
814
+
815
+ iface.run_animation = gr.HTML(
816
+ f"""
817
+ <div id="deadline">
818
+ <svg preserveAspectRatio="none" id="line" viewBox="0 0 581 158" enable-background="new 0 0 581 158">
819
+ <g id="fire">
820
+ <rect id="mask-fire-black" x="511" y="41" width="38" height="34"/>
821
+ <g>
822
+ <defs>
823
+ <rect id="mask_fire" x="511" y="41" width="38" height="34"/>
824
+ </defs>
825
+ <clipPath id="mask-fire_1_">
826
+ <use xlink:href="#mask_fire" overflow="visible"/>
827
+ </clipPath>
828
+ <g id="group-fire" clip-path="url(#mask-fire_1_)">
829
+ <path id="red-flame" fill="#B71342" d="M528.377,100.291c6.207,0,10.947-3.272,10.834-8.576 c-0.112-5.305-2.934-8.803-8.237-10.383c-5.306-1.581-3.838-7.9-0.79-9.707c-7.337,2.032-7.581,5.891-7.11,8.238 c0.789,3.951,7.56,4.402,5.077,9.48c-2.482,5.079-8.012,1.129-6.319-2.257c-2.843,2.233-4.78,6.681-2.259,9.703 C521.256,98.809,524.175,100.291,528.377,100.291z"/>
830
+ <path id="yellow-flame" opacity="0.71" fill="#F7B523" d="M528.837,100.291c4.197,0,5.108-1.854,5.974-5.417 c0.902-3.724-1.129-6.207-5.305-9.931c-2.396-2.137-1.581-4.176-0.565-6.32c-4.401,1.918-3.384,5.304-2.482,6.658 c1.511,2.267,2.099,2.364,0.42,5.8c-1.679,3.435-5.42,0.764-4.275-1.527c-1.921,1.512-2.373,4.04-1.528,6.563 C522.057,99.051,525.994,100.291,528.837,100.291z"/>
831
+ <path id="white-flame" opacity="0.81" fill="#FFFFFF" d="M529.461,100.291c-2.364,0-4.174-1.322-4.129-3.469 c0.04-2.145,1.117-3.56,3.141-4.198c2.022-0.638,1.463-3.195,0.302-3.925c2.798,0.821,2.89,2.382,2.711,3.332 c-0.301,1.597-2.883,1.779-1.938,3.834c0.912,1.975,3.286,0.938,2.409-0.913c1.086,0.903,1.826,2.701,0.864,3.924 C532.18,99.691,531.064,100.291,529.461,100.291z"/>
832
+ </g>
833
+ </g>
834
+ </g>
835
+ <g id="progress-trail">
836
+ <path fill="#FFFFFF" d="M491.979,83.878c1.215-0.73-0.62-5.404-3.229-11.044c-2.583-5.584-5.034-10.066-7.229-8.878
837
+ c-2.854,1.544-0.192,6.286,2.979,11.628C487.667,80.917,490.667,84.667,491.979,83.878z"/>
838
+ <path fill="#FFFFFF" d="M571,76v-5h-23.608c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125
839
+ c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333
840
+ s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17
841
+ c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17c0-2.762-2.238-5-5-5h-3V76H571z"/>
842
+ <path fill="#FFFFFF" d="M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
843
+ s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
844
+ </g>
845
+ <g>
846
+ <defs>
847
+ <path id="SVGID_1_" d="M484.5,75.584c-3.172-5.342-5.833-10.084-2.979-11.628c2.195-1.188,4.646,3.294,7.229,8.878
848
+ c2.609,5.64,4.444,10.313,3.229,11.044C490.667,84.667,487.667,80.917,484.5,75.584z M571,76v-5h-23.608
849
+ c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25
850
+ v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917
851
+ c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17
852
+ c0-2.762-2.238-5-5-5h-3V76H571z M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
853
+ s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
854
+ </defs>
855
+ <clipPath id="SVGID_2_">
856
+ <use xlink:href="#SVGID_1_" overflow="visible"/>
857
+ </clipPath>
858
+ <rect id="progress-time-fill" x="-100%" y="34" clip-path="url(#SVGID_2_)" fill="#BE002A" width="586" height="103"/>
859
+ </g>
860
+
861
+ <g id="death-group">
862
+ <path id="death" fill="#BE002A" d="M-46.25,40.416c-5.42-0.281-8.349,3.17-13.25,3.918c-5.716,0.871-10.583-0.918-10.583-0.918
863
+ C-67.5,49-65.175,50.6-62.083,52c5.333,2.416,4.083,3.5,2.084,4.5c-16.5,4.833-15.417,27.917-15.417,27.917L-75.5,84.75
864
+ c-1,12.25-20.25,18.75-20.25,18.75s39.447,13.471,46.25-4.25c3.583-9.333-1.553-16.869-1.667-22.75
865
+ c-0.076-3.871,2.842-8.529,6.084-12.334c3.596-4.22,6.958-10.374,6.958-15.416C-38.125,43.186-39.833,40.75-46.25,40.416z
866
+ M-40,51.959c-0.882,3.004-2.779,6.906-4.154,6.537s-0.939-4.32,0.112-7.704c0.82-2.64,2.672-5.96,3.959-5.583
867
+ C-39.005,45.523-39.073,48.8-40,51.959z"/>
868
+ <path id="death-arm" fill="#BE002A" d="M-53.375,75.25c0,0,9.375,2.25,11.25,0.25s2.313-2.342,3.375-2.791
869
+ c1.083-0.459,4.375-1.75,4.292-4.75c-0.101-3.627,0.271-4.594,1.333-5.043c1.083-0.457,2.75-1.666,2.75-1.666
870
+ s0.708-0.291,0.5-0.875s-0.791-2.125-1.583-2.959c-0.792-0.832-2.375-1.874-2.917-1.332c-0.542,0.541-7.875,7.166-7.875,7.166
871
+ s-2.667,2.791-3.417,0.125S-49.833,61-49.833,61s-3.417,1.416-3.417,1.541s-1.25,5.834-1.25,5.834l-0.583,5.833L-53.375,75.25z"/>
872
+ <path id="death-tool" fill="#BE002A" d="M-20.996,26.839l-42.819,91.475l1.812,0.848l38.342-81.909c0,0,8.833,2.643,12.412,7.414
873
+ c5,6.668,4.75,14.084,4.75,14.084s4.354-7.732,0.083-17.666C-10,32.75-19.647,28.676-19.647,28.676l0.463-0.988L-20.996,26.839z"/>
874
+ </g>
875
+ <path id="designer-body" fill="#FEFFFE" d="M514.75,100.334c0,0,1.25-16.834-6.75-16.5c-5.501,0.229-5.583,3-10.833,1.666
876
+ c-3.251-0.826-5.084-15.75-0.834-22c4.948-7.277,12.086-9.266,13.334-7.833c2.25,2.583-2,10.833-4.5,14.167
877
+ c-2.5,3.333-1.833,10.416,0.5,9.916s8.026-0.141,10,2.25c3.166,3.834,4.916,17.667,4.916,17.667l0.917,2.5l-4,0.167L514.75,100.334z
878
+ "/>
879
+
880
+ <circle id="designer-head" fill="#FEFFFE" cx="516.083" cy="53.25" r="6.083"/>
881
+
882
+ <g id="designer-arm-grop">
883
+ <path id="designer-arm" fill="#FEFFFE" d="M505.875,64.875c0,0,5.875,7.5,13.042,6.791c6.419-0.635,11.833-2.791,13.458-4.041s2-3.5,0.25-3.875
884
+ s-11.375,5.125-16,3.25c-5.963-2.418-8.25-7.625-8.25-7.625l-2,1.125L505.875,64.875z"/>
885
+ <path id="designer-pen" fill="#FEFFFE" d="M525.75,59.084c0,0-0.423-0.262-0.969,0.088c-0.586,0.375-0.547,0.891-0.547,0.891l7.172,8.984l1.261,0.453
886
+ l-0.104-1.328L525.75,59.084z"/>
887
+ </g>
888
+ </svg>
889
+
890
+ <div class="deadline-timer">
891
+ Remaining <span class="day">{opt.run_time}</span> <span class="days">s</span>
892
+ </div>
893
+
894
+ </div>
895
+ """,
896
+ elem_id='run-anim',
897
+ visible=False,
898
+ )
899
+
900
+ with gr.Group(elem_id='control-panel'):
901
+
902
+ with gr.Row():
903
+ iface.tbox_prompt = gr.Textbox(
904
+ label='Edit Prompt for Background',
905
+ info='What do you want to draw?',
906
+ value=state.value.prompts[0],
907
+ placeholder=lambda: random.choice(prompt_suggestions),
908
+ scale=2,
909
+ )
910
+
911
+ iface.slider_strength = gr.Slider(
912
+ label='Prompt Strength',
913
+ info='Blends fg & bg at the prompt level; >0.8 preferred.',
914
+ minimum=0.5,
915
+ maximum=1.0,
916
+ value=opt.default_prompt_strength,
917
+ scale=1,
918
+ )
919
+
920
+ with gr.Row():
921
+ iface.tbox_neg_prompt = gr.Textbox(
922
+ label='Edit Negative Prompt for Background',
923
+ info='Add unwanted objects for this semantic brush.',
924
+ value=opt.default_negative_prompt,
925
+ scale=2,
926
+ )
927
+
928
+ iface.tbox_name = gr.Textbox(
929
+ label='Edit Brush Name',
930
+ info='Just for your convenience.',
931
+ value=state.value.prompt_names[0],
932
+ placeholder='🌄 Background',
933
+ scale=1,
934
+ )
935
+
936
+ with gr.Row():
937
+ iface.slider_alpha = gr.Slider(
938
+ label='Mask Alpha',
939
+ info='Factor multiplied with the mask before quantization. Extremely sensitive; >0.98 preferred.',
940
+ minimum=0.5,
941
+ maximum=1.0,
942
+ value=opt.default_mask_strength,
943
+ )
944
+
945
+ iface.slider_std = gr.Slider(
946
+ label='Mask Blur STD',
947
+ info='Blends fg & bg at the latent level; 0 for generation, 8-32 for inpainting.',
948
+ minimum=0.0001,
949
+ maximum=100.0,
950
+ value=opt.default_mask_std,
951
+ )
952
+
953
+ iface.slider_seed = gr.Slider(
954
+ label='Seed',
955
+ info='The global seed.',
956
+ minimum=-1,
957
+ maximum=2147483647,
958
+ step=1,
959
+ value=opt.seed,
960
+ )
961
+
962
+ ### Attach event handlers
963
+
964
+ for idx, btn in enumerate(iface.btn_semantics):
965
+ btn.click(
966
+ fn=partial(select_palette, idx=idx),
967
+ inputs=[state, btn],
968
+ outputs=[state] + iface.btn_semantics + [
969
+ iface.tbox_name,
970
+ iface.tbox_prompt,
971
+ iface.tbox_neg_prompt,
972
+ iface.slider_alpha,
973
+ iface.slider_strength,
974
+ iface.slider_std,
975
+ ],
976
+ api_name=f'select_palette_{idx}',
977
+ )
978
+
979
+ iface.btn_add_palette.click(
980
+ fn=add_palette,
981
+ inputs=state,
982
+ outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
983
+ api_name='create_new',
984
+ )
985
+
986
+ run_event = iface.btn_generate.click(
987
+ fn=hide_element,
988
+ inputs=None,
989
+ outputs=iface.btn_generate,
990
+ api_name='hide_run_button',
991
+ ).then(
992
+ fn=show_element,
993
+ inputs=None,
994
+ outputs=iface.run_animation,
995
+ api_name='show_run_animation',
996
+ )
997
+
998
+ run_event.then(
999
+ fn=run,
1000
+ inputs=[state, iface.ctrl_semantic],
1001
+ outputs=[state, iface.image_slot],
1002
+ api_name='run',
1003
+ ).then(
1004
+ fn=hide_element,
1005
+ inputs=None,
1006
+ outputs=iface.run_animation,
1007
+ api_name='hide_run_animation',
1008
+ ).then(
1009
+ fn=show_element,
1010
+ inputs=None,
1011
+ outputs=iface.btn_generate,
1012
+ api_name='show_run_button',
1013
+ )
1014
+
1015
+ run_event.then(
1016
+ fn=None,
1017
+ inputs=None,
1018
+ outputs=None,
1019
+ api_name='run_animation',
1020
+ js=f"""
1021
+ async () => {{
1022
+ // timer arguments:
1023
+ // #1 - time of animation in seconds,
1024
+ // #2 - days to deadline
1025
+ const animationTime = {opt.run_time};
1026
+ const days = {opt.run_time};
1027
+
1028
+ jQuery('#progress-time-fill, #death-group').css({{'animation-duration': animationTime+'s'}});
1029
+
1030
+ var deadlineAnimation = function () {{
1031
+ setTimeout(function() {{
1032
+ jQuery('#designer-arm-grop').css({{'animation-duration': '1.5s'}});
1033
+ }}, 0);
1034
+
1035
+ setTimeout(function() {{
1036
+ jQuery('#designer-arm-grop').css({{'animation-duration': '1.0s'}});
1037
+ }}, {int(opt.run_time * 1000 * 0.2)});
1038
+
1039
+ setTimeout(function() {{
1040
+ jQuery('#designer-arm-grop').css({{'animation-duration': '0.7s'}});
1041
+ }}, {int(opt.run_time * 1000 * 0.4)});
1042
+
1043
+ setTimeout(function() {{
1044
+ jQuery('#designer-arm-grop').css({{'animation-duration': '0.3s'}});
1045
+ }}, {int(opt.run_time * 1000 * 0.6)});
1046
+
1047
+ setTimeout(function() {{
1048
+ jQuery('#designer-arm-grop').css({{'animation-duration': '0.2s'}});
1049
+ }}, {int(opt.run_time * 1000 * 0.75)});
1050
+ }};
1051
+
1052
+ var deadlineTextFinished = function () {{
1053
+ var el = jQuery('.deadline-timer');
1054
+ var html = 'Done! Retry?';
1055
+ el.html(html);
1056
+ }};
1057
+
1058
+ function timer(totalTime, deadline) {{
1059
+ var time = totalTime * 1000;
1060
+ var dayDuration = time / deadline;
1061
+ var actualDay = deadline;
1062
+
1063
+ var timer = setInterval(countTime, dayDuration);
1064
+
1065
+ function countTime() {{
1066
+ --actualDay;
1067
+ jQuery('.deadline-timer .day').text(actualDay);
1068
+
1069
+ if (actualDay == 0) {{
1070
+ clearInterval(timer);
1071
+ // jQuery('.deadline-timer .day').text(deadline);
1072
+ deadlineTextFinished();
1073
+ }}
1074
+ }}
1075
+ }}
1076
+
1077
+ var deadlineText = function () {{
1078
+ var el = jQuery('.deadline-timer');
1079
+ var htmlBase = 'Remaining <span class="day">{opt.run_time}</span> <span class="days">s</span>';
+ var html = '<div class="mask-red"><div class="inner">' + htmlBase + '</div></div><div class="mask-white"><div class="inner">' + htmlBase + '</div></div>';
+ el.html(html);
1083
+ }};
1084
+
1085
+ var runAnimation = function() {{
1086
+ timer(animationTime, days);
1087
+ deadlineAnimation();
1088
+ deadlineText();
1089
+
1090
+ console.log('begin interval', animationTime * 1000);
1091
+ }};
1092
+ runAnimation();
1093
+ }}
1094
+ """
1095
+ )
1096
+
1097
+ iface.slider_alpha.input(
1098
+ fn=change_mask_strength,
1099
+ inputs=[state, iface.slider_alpha],
1100
+ outputs=state,
1101
+ api_name='change_alpha',
1102
+ )
1103
+ iface.slider_std.input(
1104
+ fn=change_std,
1105
+ inputs=[state, iface.slider_std],
1106
+ outputs=state,
1107
+ api_name='change_std',
1108
+ )
1109
+ iface.slider_strength.input(
1110
+ fn=change_prompt_strength,
1111
+ inputs=[state, iface.slider_strength],
1112
+ outputs=state,
1113
+ api_name='change_strength',
1114
+ )
1115
+ iface.slider_seed.input(
1116
+ fn=reset_seed,
1117
+ inputs=[state, iface.slider_seed],
1118
+ outputs=state,
1119
+ api_name='reset_seed',
1120
+ )
1121
+
1122
+ iface.tbox_name.input(
1123
+ fn=rename_prompt,
1124
+ inputs=[state, iface.tbox_name],
1125
+ outputs=[state] + iface.btn_semantics,
1126
+ api_name='prompt_rename',
1127
+ )
1128
+ iface.tbox_prompt.input(
1129
+ fn=change_prompt,
1130
+ inputs=[state, iface.tbox_prompt],
1131
+ outputs=state,
1132
+ api_name='prompt_edit',
1133
+ )
1134
+ iface.tbox_neg_prompt.input(
1135
+ fn=change_neg_prompt,
1136
+ inputs=[state, iface.tbox_neg_prompt],
1137
+ outputs=state,
1138
+ api_name='neg_prompt_edit',
1139
+ )
1140
+
1141
+ # iface.style_select.change(
1142
+ # fn=select_style,
1143
+ # inputs=[state, iface.style_select],
1144
+ # outputs=state,
1145
+ # api_name='style_select',
1146
+ # )
1147
+ # iface.quality_select.change(
1148
+ # fn=select_quality,
1149
+ # inputs=[state, iface.quality_select],
1150
+ # outputs=state,
1151
+ # api_name='quality_select',
1152
+ # )
1153
+
1154
+ iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
1155
+ iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
1156
+ state,
1157
+ *iface.btn_semantics,
1158
+ # iface.style_select,
1159
+ # iface.quality_select,
1160
+ iface.tbox_prompt,
1161
+ iface.tbox_name,
1162
+ iface.tbox_neg_prompt,
1163
+ iface.slider_strength,
1164
+ iface.slider_alpha,
1165
+ iface.slider_std,
1166
+ iface.slider_seed,
1167
+ ])
1168
+
1169
+ # Realtime user input.
1170
+ iface.ctrl_semantic.change(
1171
+ fn=draw,
1172
+ inputs=[state, iface.ctrl_semantic],
1173
+ outputs=None,
1174
+ api_name='draw',
1175
+ )
1176
+
1177
+
1178
+ if __name__ == '__main__':
1179
+ demo.launch(server_port=opt.port)
checkpoints/put_checkpoint_models_here.txt ADDED
File without changes
data.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import copy
22
+ from typing import Optional, Union
23
+ from PIL import Image
24
+ import torch
25
+
26
+
27
+ class BackgroundObject:
28
+ def __init__(
29
+ self,
30
+ image: Optional[Image.Image] = None,
31
+ prompt: Optional[str] = None,
32
+ negative_prompt: Optional[str] = None,
33
+ ) -> None:
34
+ self.image = image
35
+ self.prompt = prompt
36
+ self.negative_prompt = negative_prompt
37
+
38
+ @property
39
+ def is_empty(self) -> bool:
40
+ return (
41
+ self.image is None and
42
+ self.prompt is None and
43
+ self.negative_prompt is None
44
+ )
45
+
46
+ def extra_repr(self) -> str:
47
+ return ''
48
+
49
+ def __repr__(self) -> str:
50
+ strings = []
51
+ if self.image is not None:
52
+ if isinstance(self.image, Image.Image):
53
+ image_str = f'Image(size={str(self.image.size)})'
54
+ else:
55
+ image_str = f'Tensor(shape={str(self.image.shape)})'
56
+ strings.append(f'image={image_str}')
57
+ if self.prompt is not None:
58
+ strings.append(f'prompt="{self.prompt}"')
59
+ if self.negative_prompt is not None:
60
+ strings.append(f'negative_prompt="{self.negative_prompt}"')
61
+ extra_repr = self.extra_repr()
62
+ if extra_repr != '':
63
+ strings.append(extra_repr)
64
+ return f'{type(self).__name__}({", ".join(strings)})'
65
+
66
+
67
+ class LayerObject:
68
+ def __init__(
69
+ self,
70
+ idx: Optional[int] = None,
71
+ prompt: Optional[str] = None,
72
+ negative_prompt: Optional[str] = None,
73
+ suffix: Optional[str] = None,
74
+ prompt_strength: Optional[float] = None,
75
+ mask: Optional[Union[torch.Tensor, Image.Image]] = None,
76
+ mask_std: Optional[float] = None,
77
+ mask_strength: Optional[float] = None,
78
+ ) -> None:
79
+ self.idx = idx
80
+ self.prompt = prompt
81
+ self.negative_prompt = negative_prompt
82
+ self.suffix = suffix
83
+ self.prompt_strength = prompt_strength
84
+ self.mask = mask
85
+ self.mask_std = mask_std
86
+ self.mask_strength = mask_strength
87
+
88
+ @property
89
+ def is_empty(self) -> bool:
90
+ return (
91
+ self.prompt is None and
92
+ self.negative_prompt is None and
93
+ self.prompt_strength is None and
94
+ self.mask is None and
95
+ self.mask_strength is None and
96
+ self.mask_std is None
97
+ )
98
+
99
+ def merge(self, other: 'LayerObject') -> bool: # Overridden or not.
100
+ if self.idx != other.idx:
101
+ # Merge only the modification requests for the same layer.
102
+ return False
103
+
104
+ if self.prompt is None and other.prompt is not None:
105
+ self.prompt = copy.deepcopy(other.prompt)
106
+ if self.negative_prompt is None and other.negative_prompt is not None:
107
+ self.negative_prompt = copy.deepcopy(other.negative_prompt)
108
+ if self.suffix is None and other.suffix is not None:
109
+ self.suffix = copy.deepcopy(other.suffix)
110
+ if self.prompt_strength is None and other.prompt_strength is not None:
111
+ self.prompt_strength = copy.deepcopy(other.prompt_strength)
112
+ if self.mask is None and other.mask is not None:
113
+ self.mask = copy.deepcopy(other.mask)
114
+ if self.mask_strength is None and other.mask_strength is not None:
115
+ self.mask_strength = copy.deepcopy(other.mask_strength)
116
+ if self.mask_std is None and other.mask_std is not None:
117
+ self.mask_std = copy.deepcopy(other.mask_std)
118
+ return True
119
+
120
+ def extra_repr(self) -> str:
121
+ return ''
122
+
123
+ def __repr__(self) -> str:
124
+ strings = []
125
+ if self.idx is not None:
126
+ strings.append(f'idx={self.idx}')
127
+ if self.prompt is not None:
128
+ strings.append(f'prompt="{self.prompt}"')
129
+ if self.negative_prompt is not None:
130
+ strings.append(f'negative_prompt="{self.negative_prompt}"')
131
+ if self.suffix is not None:
132
+ strings.append(f'suffix="{self.suffix}"')
133
+ if self.prompt_strength is not None:
134
+ strings.append(f'prompt_strength={self.prompt_strength}')
135
+ if self.mask is not None:
136
+ if isinstance(self.mask, Image.Image):
137
+ mask_str = f'Image(size={str(self.mask.size)})'
138
+ else:
139
+ mask_str = f'Tensor(shape={str(self.mask.shape)})'
140
+ strings.append(f'mask={mask_str}')
141
+ if self.mask_std is not None:
142
+ strings.append(f'mask_std={self.mask_std}')
143
+ if self.mask_strength is not None:
144
+ strings.append(f'mask_strength={self.mask_strength}')
145
+ extra_repr = self.extra_repr()
146
+ if extra_repr != '':
147
+ strings.append(extra_repr)
148
+ return f'{type(self).__name__}({", ".join(strings)})'
149
+
150
+
151
+ class BackgroundState(BackgroundObject):
152
+ def __init__(
153
+ self,
154
+ image: Optional[Image.Image] = None,
155
+ prompt: Optional[str] = None,
156
+ negative_prompt: Optional[str] = None,
157
+ latent: Optional[torch.Tensor] = None,
158
+ embed: Optional[torch.Tensor] = None,
159
+ ) -> None:
160
+ super().__init__(image, prompt, negative_prompt)
161
+ self.latent = latent
162
+ self.embed = embed
163
+
164
+ @property
165
+ def is_incomplete(self) -> bool:
166
+ return (
167
+ self.image is None or
168
+ self.prompt is None or
169
+ self.negative_prompt is None or
170
+ self.latent is None or
171
+ self.embed is None
172
+ )
173
+
174
+ def extra_repr(self) -> str:
175
+ strings = []
176
+ if self.latent is not None:
177
+ strings.append(f'latent=Tensor(shape={str(self.latent.shape)})')
178
+ if self.embed is not None:
179
+ strings.append(f'embed=Tuple[Tensor(shape={str(self.embed[0].shape)})]')
180
+ return ', '.join(strings)
181
+
182
+
183
+ # TODO
184
+ # class LayerState:
185
+ # def __init__(
186
+ # self,
187
+ # prompst: List[str] = [],
188
+ # negative_prompts: List[str] = [],
189
+ # suffix: List[str] = [],
190
+ # masks: Optional[torch.Tensor] = None,
191
+ # mask_std: Optional[torch.Tensor] = None,
192
+ # mask_strength: Optional[torch.Tensor] = None,
193
+ # original_masks: Optional[Union[torch.Tensor, List[Image.Image]]] = None,
194
+ # ) -> None:
195
+ # self.prompts = prompts
196
+ # self.negative_prompts = negative_prompts
197
+ # self.suffix = suffix
198
+ # self.masks = masks
199
+ # self.mask_std = mask_std
200
+ # self.mask_strength = mask_strength
201
+ # self.original_masks = original_masks
202
+
203
+ # def __len__(self) -> int:
204
+ # self.check_integrity(True)
205
+ # return len(self.prompts)
206
+
207
+ # @property
208
+ # def is_empty(self) -> bool:
209
+ # self.check_integrity(True)
210
+ # return len(self.prompt) == 0
211
+
212
+ # def check_integrity(self, throw_error: bool = True) -> bool:
213
+ # p = len(self.prompts)
214
+ # flag = (
215
+ # p != len(self.negative_prompts) or
216
+ # p != len(self.suffix) or
217
+ # p != len(self.masks) or
218
+ # p != len(self.mask_std) or
219
+ # p != len(self.mask_strength) or
220
+ # p != len(self.original_masks)
221
+ # )
222
+ # if flag and throw_error:
223
+ # print(
224
+ # f'LayerState(\n\tlen(prompts): {p},\n\tlen(negative_prompts): {len(self.negative_prompts)},\n\t'
225
+ # f'len(suffix): {len(self.suffix)},\n\tlen(masks): {len(self.masks)},\n\t'
226
+ # f'len(mask_std): {len(self.mask_std)},\n\tlen(mask_strength): {len(self.mask_strength)},\n\t'
227
+ # f'len(original_masks): {len(self.original_masks)}\n)'
228
+ # )
229
+ # raise ValueError('LayerState is corrupted!')
230
+ # return not flag
231
+
232
+ # def extra_repr(self) -> str:
233
+ # strings = []
234
+ # if self.idx is not None:
235
+ # strings.append(f'idx={self.idx}')
236
+ # if self.prompt is not None:
237
+ # strings.append(f'prompt="{self.prompt}"')
238
+ # if self.negative_prompt is not None:
239
+ # strings.append(f'negative_prompt="{self.negative_prompt}"')
240
+ # if self.suffix is not None:
241
+ # strings.append(f'suffix="{self.suffix}"')
242
+ # if self.mask is not None:
243
+ # if isinstance(self.mask, Image.Image):
244
+ # mask_str = f'PIL.Image.Image(size={str(self.mask.size)})'
245
+ # else:
246
+ # mask_str = f'torch.Tensor(shape={str(self.mask.shape)})'
247
+ # strings.append(f'mask={mask_str}')
248
+ # if self.mask_std is not None:
249
+ # strings.append(f'mask_std={self.mask_std}')
250
+ # if self.mask_strength is not None:
251
+ # strings.append(f'mask_strength={self.mask_strength}')
252
+ # return f'{type(self).__name__}({", ".join(strings)})'
examples/prompt_background.txt ADDED
@@ -0,0 +1,8 @@
+ Maximalism, best quality, high quality, no humans, background, clear sky, black sky, starry universe, planets
+ Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
+ Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
+ Maximalism, best quality, high quality, no humans, background, galaxy
+ Maximalism, best quality, high quality, no humans, background, sky, daylight
+ Maximalism, best quality, high quality, no humans, background, skyscrapers, rooftop, city of light, helicopters, bright night, sky
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
+ Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/prompt_boy.txt ADDED
@@ -0,0 +1,15 @@
+ 1boy, looking at viewer, brown hair, blue shirt
+ 1boy, looking at viewer, brown hair, red shirt
+ 1boy, looking at viewer, brown hair, purple shirt
+ 1boy, looking at viewer, brown hair, orange shirt
+ 1boy, looking at viewer, brown hair, yellow shirt
+ 1boy, looking at viewer, brown hair, green shirt
+ 1boy, looking back, side shaved hair, cyberpunk clothes, robotic suit, large body
+ 1boy, looking back, short hair, renaissance clothes, noble boy
+ 1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
+ 1boy, looking at viewer, a king, kingly grace, majestic clothes, crown
+ 1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
+ 1boy, looking at viewer, a medieval knight, helmet, swordsman, plate armour
+ 1boy, looking at viewer, black haired, old eastern clothes
+ 1boy, looking back, messy hair, suit, short beard, noir
+ 1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt ADDED
@@ -0,0 +1,16 @@
1
+ 1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
2
+ 1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
3
+ 1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
4
+ 1girl, looking at viewer, fantasy adventurer, backpack
5
+ 1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
6
+ 1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
7
+ 1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
8
+ 1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
9
+ 1girl, looking at viewer, evil smile, very short hair, suit, evil genius
10
+ 1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
11
+ 1girl, looking at viewer, purple hair, happy face, black leather jacket
12
+ 1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
13
+ 1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
14
+ 1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
15
+ 1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
16
+ 1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt ADDED
@@ -0,0 +1,43 @@
1
+ 🏯 Palace, Gyeongbokgung palace
2
+ 🌳 Garden, Chinese garden
3
+ 🏛️ Rome, Ancient city of Rome
4
+ 🧱 Wall, Castle wall
5
+ 🔴 Mars, Martian desert, Red rocky desert
6
+ 🌻 Grassland, Grasslands
7
+ 🏡 Village, A fantasy village
8
+ 🐉 Dragon, a flying chinese dragon
9
+ 🌏 Earth, Earth seen from ISS
10
+ 🚀 Space Station, the international space station
11
+ 🪻 Grassland, Rusty grassland with flowers
12
+ 🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
13
+ 🏙️ City Ruin, city, ruins, ruins, ruins, deserted
14
+ 🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
15
+ 🌷 Flowers, Flower garden
16
+ 🌼 Flowers, Flower garden, spring garden
17
+ 🌹 Flowers, Flowers flowers, flowers
18
+ ⛰️ Dolomites Mountains, Dolomites
19
+ ⛰️ Himalayas Mountains, Himalayas
20
+ ⛰️ Alps Mountains, Alps
21
+ ⛰️ Mountains, Mountains
22
+ ❄️⛰️ Mountains, Winter mountains
23
+ 🌷⛰️ Mountains, Spring mountains
24
+ 🌞⛰️ Mountains, Summer mountains
25
+ 🌵 Desert, A sandy desert, dunes
26
+ 🪨🌵 Desert, A rocky desert
27
+ 💦 Waterfall, A giant waterfall
28
+ 🌊 Ocean, Ocean
29
+ ⛱️ Seashore, Seashore
30
+ 🌅 Sea Horizon, Sea horizon
31
+ 🌊 Lake, Clear blue lake
32
+ 💻 Computer, A giant supercomputer
33
+ 🌳 Tree, A giant tree
34
+ 🌳 Forest, A forest
35
+ 🌳🌳 Forest, A dense forest
36
+ 🌲 Forest, Winter forest
37
+ 🌴 Forest, Summer forest, tropical forest
38
+ 👒 Hat, A hat
39
+ 🐶 Dog, Doggy body parts
40
+ 😻 Cat, A cat
41
+ 🦉 Owl, A small sitting owl
42
+ 🦅 Eagle, A small sitting eagle
43
+ 🚀 Rocket, A flying rocket
model.py ADDED
@@ -0,0 +1,1212 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
22
+ from diffusers import DiffusionPipeline, LCMScheduler, EulerDiscreteScheduler, AutoencoderTiny
23
+ from huggingface_hub import hf_hub_download
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import torchvision.transforms as T
29
+ from einops import rearrange
30
+
31
+ from collections import deque
32
+ from typing import Tuple, List, Literal, Optional, Union
33
+ from PIL import Image
34
+
35
+ from util import load_model, gaussian_lowpass, shift_to_mask_bbox_center
36
+ from data import BackgroundObject, LayerObject, BackgroundState #, LayerState
37
+
38
+
39
+ class StreamMultiDiffusion(nn.Module):
40
+ def __init__(
41
+ self,
42
+ device: torch.device,
43
+ dtype: torch.dtype = torch.float16,
44
+ sd_version: Literal['1.5'] = '1.5',
45
+ hf_key: Optional[str] = None,
46
+ lora_key: Optional[str] = None,
47
+ use_tiny_vae: bool = True,
48
+ t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], Magic number.
49
+ width: int = 512,
50
+ height: int = 512,
51
+ frame_buffer_size: int = 1,
52
+ num_inference_steps: int = 50,
53
+ guidance_scale: float = 1.2,
54
+ delta: float = 1.0,
55
+ cfg_type: Literal['none', 'full', 'self', 'initialize'] = 'none',
56
+ seed: int = 2024,
57
+ autoflush: bool = True,
58
+ default_mask_std: float = 8.0,
59
+ default_mask_strength: float = 1.0,
60
+ default_prompt_strength: float = 0.95,
61
+ bootstrap_steps: int = 1,
62
+ bootstrap_mix_steps: float = 1.0,
63
+ # bootstrap_leak_sensitivity: float = 0.2,
64
+ preprocess_mask_cover_alpha: float = 0.3, # TODO
65
+ prompt_queue_capacity: int = 256,
66
+ mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'continuous',
67
+ use_xformers: bool = True,
68
+ ) -> None:
69
+ super().__init__()
70
+
71
+ self.device = device
72
+ self.dtype = dtype
73
+ self.seed = seed
74
+ self.sd_version = sd_version
75
+
76
+ self.autoflush = autoflush
77
+ self.default_mask_std = default_mask_std
78
+ self.default_mask_strength = default_mask_strength
79
+ self.default_prompt_strength = default_prompt_strength
80
+ self.bootstrap_steps = (
81
+ bootstrap_steps > torch.arange(len(t_index_list))).to(dtype=self.dtype, device=self.device)
82
+ self.bootstrap_mix_steps = bootstrap_mix_steps
83
+ self.bootstrap_mix_ratios = (
84
+ bootstrap_mix_steps - torch.arange(len(t_index_list), dtype=self.dtype, device=self.device)).clip_(0, 1)
85
+ # self.bootstrap_leak_sensitivity = bootstrap_leak_sensitivity
86
+ self.preprocess_mask_cover_alpha = preprocess_mask_cover_alpha
87
+ self.mask_type = mask_type
88
+
89
+ ### State definition
90
+
91
+ # [0. Start] -(prepare)-> [1. Initialized]
92
+ # [1. Initialized] -(update_background)-> [2. Background Registered] (len(self.prompts)==0)
93
+ # [2. Background Registered] -(update_layers)-> [3. Unflushed] (len(self.prompts)>0)
94
+
95
+ # [3. Unflushed] -(flush)-> [4. Ready]
96
+ # [4. Ready] -(any updates)-> [3. Unflushed]
97
+ # [4. Ready] -(__call__)-> [4. Ready], continuously returns generated image.
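+
+ # A minimal lifecycle sketch (illustrative only; the file names and the variable
+ # `smd` are placeholders, not part of this module):
+ # smd = StreamMultiDiffusion(device=torch.device('cuda')) # [1. Initialized]
+ # smd.update_background(Image.open('bg.png')) # [2. Background Registered]
+ # smd.update_layers(prompts=['1girl'], masks=[Image.open('mask.png')]) # [3. Unflushed]
+ # smd.flush() # or rely on autoflush # [4. Ready]
+ # frame = smd() # one PIL.Image per call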
98
+
99
+ self.ready_checklist = {
100
+ 'initialized': False,
101
+ 'background_registered': False,
102
+ 'layers_ready': False,
103
+ 'flushed': False,
104
+ }
105
+
106
+ ### Session state update queue: for lazy update policy for streaming applications.
107
+
108
+ self.update_buffer = {
109
+ 'background': None, # Maintains a single instance of BackgroundObject
110
+ 'layers': deque(maxlen=prompt_queue_capacity), # Maintains a queue of LayerObjects
111
+ }
112
+
113
+ print(f'[INFO] Loading Stable Diffusion...')
114
+ get_scheduler = lambda pipe: LCMScheduler.from_config(pipe.scheduler.config)
115
+ lora_weight_name = None
116
+ if self.sd_version == '1.5':
117
+ if hf_key is not None:
118
+ print(f'[INFO] Using custom model key: {hf_key}')
119
+ model_key = hf_key
120
+ else:
121
+ model_key = 'runwayml/stable-diffusion-v1-5'
122
+ lora_key = 'latent-consistency/lcm-lora-sdv1-5'
123
+ lora_weight_name = 'pytorch_lora_weights.safetensors'
124
+ # elif self.sd_version == 'xl':
125
+ # model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
126
+ # lora_key = 'latent-consistency/lcm-lora-sdxl'
127
+ # lora_weight_name = 'pytorch_lora_weights.safetensors'
128
+ else:
129
+ raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
130
+
131
+ ### Internally stored "Session" states
132
+
133
+ self.state = {
134
+ 'background': BackgroundState(), # Maintains a single instance of BackgroundState
135
+ # 'layers': LayerState(), # Maintains a single instance of LayerState
136
+ 'model_key': model_key, # The Hugging Face model ID.
137
+ }
138
+
139
+ # Create model
140
+ self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
141
+ self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
142
+
143
+ self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
144
+
145
+ self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
146
+ self.pipe.fuse_lora(
147
+ fuse_unet=True,
148
+ fuse_text_encoder=True,
149
+ lora_scale=1.0,
150
+ safe_fusing=False,
151
+ )
152
+ if use_xformers:
153
+ self.pipe.enable_xformers_memory_efficient_attention()
154
+
155
+ self.vae = (
156
+ AutoencoderTiny.from_pretrained('madebyollin/taesd').to(device=self.device, dtype=self.dtype)
157
+ if use_tiny_vae else self.pipe.vae
158
+ )
159
+ # self.tokenizer = self.pipe.tokenizer
160
+ self.text_encoder = self.pipe.text_encoder
161
+ self.unet = self.pipe.unet
162
+ self.vae_scale_factor = self.pipe.vae_scale_factor
163
+
164
+ self.scheduler = get_scheduler(self.pipe)
165
+ self.scheduler.set_timesteps(num_inference_steps)
166
+
167
+ self.generator = None
168
+
169
+ # Lock the canvas size--changing the canvas size can be implemented by reloading the module.
170
+ self.height = height
171
+ self.width = width
172
+ self.latent_height = int(height // self.pipe.vae_scale_factor)
173
+ self.latent_width = int(width // self.pipe.vae_scale_factor)
174
+
175
+ # For bootstrapping.
176
+ self.white = self.encode_imgs(torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device))
177
+
178
+ # StreamDiffusion setting.
179
+ self.t_list = t_index_list
180
+ assert len(self.t_list) > 1, 'Current version only supports diffusion models with multiple steps.'
181
+ self.frame_bff_size = frame_buffer_size # f
182
+ self.denoising_steps_num = len(self.t_list) # t
183
+ self.cfg_type = cfg_type
184
+ self.num_inference_steps = num_inference_steps
185
+ self.guidance_scale = 1.0 if self.cfg_type == 'none' else guidance_scale
186
+ self.delta = delta
187
+
188
+ self.batch_size = self.denoising_steps_num * frame_buffer_size # T = t*f
189
+ if self.cfg_type == 'initialize':
190
+ self.trt_unet_batch_size = (self.denoising_steps_num + 1) * self.frame_bff_size
191
+ elif self.cfg_type == 'full':
192
+ self.trt_unet_batch_size = 2 * self.denoising_steps_num * self.frame_bff_size
193
+ else:
194
+ self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
195
+
196
+ print(f'[INFO] Model is loaded!')
197
+
198
+ self.reset_seed(self.generator, seed)
199
+ self.reset_latent()
200
+ self.prepare()
201
+
202
+ print(f'[INFO] Parameters prepared!')
203
+
204
+ self.ready_checklist['initialized'] = True
205
+
206
+ @property
207
+ def background(self) -> BackgroundState:
208
+ return self.state['background']
209
+
210
+ # @property
211
+ # def layers(self) -> LayerState:
212
+ # return self.state['layers']
213
+
214
+ @property
215
+ def num_layers(self) -> int:
216
+ return len(self.prompts) if hasattr(self, 'prompts') else 0
217
+
218
+ @property
219
+ def is_ready_except_flush(self) -> bool:
220
+ return all(v for k, v in self.ready_checklist.items() if k != 'flushed')
221
+
222
+ @property
223
+ def is_flush_needed(self) -> bool:
224
+ return self.autoflush and not self.ready_checklist['flushed']
225
+
226
+ @property
227
+ def is_ready(self) -> bool:
228
+ return self.is_ready_except_flush and not self.is_flush_needed
229
+
230
+ @property
231
+ def is_dirty(self) -> bool:
232
+ return not (self.update_buffer['background'] is None and len(self.update_buffer['layers']) == 0)
233
+
234
+ @property
235
+ def has_background(self) -> bool:
236
+ return not self.background.is_empty
237
+
238
+ # @property
239
+ # def has_layers(self) -> bool:
240
+ # return len(self.layers) > 0
241
+
242
+ def __repr__(self) -> str:
243
+ return (
244
+ f'{type(self).__name__}(\n\tbackground: {str(self.background)},\n\t'
245
+ f'model_key: {self.state["model_key"]}\n)'
246
+ # f'layers: {str(self.layers)},\n\tmodel_key: {self.state["model_key"]}\n)'
247
+ )
248
+
249
+ def check_integrity(self, throw_error: bool = True) -> bool:
250
+ p = len(self.prompts)
251
+ flag = (
252
+ p != len(self.negative_prompts) or
253
+ p != len(self.prompt_strengths) or
254
+ p != len(self.masks) or
255
+ p != len(self.mask_strengths) or
256
+ p != len(self.mask_stds) or
257
+ p != len(self.original_masks)
258
+ )
259
+ if flag and throw_error:
260
+ print(
261
+ f'LayerState(\n\tlen(prompts): {p},\n\tlen(negative_prompts): {len(self.negative_prompts)},\n\t'
262
+ f'len(prompt_strengths): {len(self.prompt_strengths)},\n\tlen(masks): {len(self.masks)},\n\t'
263
+ f'len(mask_stds): {len(self.mask_stds)},\n\tlen(mask_strengths): {len(self.mask_strengths)},\n\t'
264
+ f'len(original_masks): {len(self.original_masks)}\n)'
265
+ )
266
+ raise ValueError('[ERROR] LayerState is corrupted!')
267
+ return not flag
268
+
269
+ def check_ready(self) -> bool:
270
+ all_except_flushed = all(v for k, v in self.ready_checklist.items() if k != 'flushed')
271
+ if all_except_flushed:
272
+ if self.is_flush_needed:
273
+ self.flush()
274
+ return True
275
+
276
+ print('[WARNING] StreamMultiDiffusion module is not ready yet! Complete the checklist:')
277
+ for k, v in self.ready_checklist.items():
278
+ prefix = ' [ v ] ' if v else ' [ x ] '
279
+ print(prefix + k.replace('_', ' '))
280
+ return False
281
+
282
+ def reset_seed(self, generator: Optional[torch.Generator] = None, seed: Optional[int] = None) -> None:
283
+ generator = torch.Generator(self.device) if generator is None else generator
284
+ seed = self.seed if seed is None else seed
285
+ self.generator = generator
286
+ self.generator.manual_seed(seed)
287
+
288
+ self.init_noise = torch.randn((self.batch_size, 4, self.latent_height, self.latent_width),
289
+ generator=generator, device=self.device, dtype=self.dtype)
290
+ self.stock_noise = torch.zeros_like(self.init_noise)
291
+
292
+ self.ready_checklist['flushed'] = False
293
+
294
+ def reset_latent(self) -> None:
295
+ # initialize x_t_latent (it can be any random tensor)
296
+ b = (self.denoising_steps_num - 1) * self.frame_bff_size
297
+ self.x_t_latent_buffer = torch.zeros(
298
+ (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device)
299
+
300
+ def reset_state(self) -> None:
301
+ # TODO Reset states for context switch between multiple users.
302
+ pass
303
+
304
+ def prepare(self) -> None:
305
+ # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
306
+ self.timesteps = self.scheduler.timesteps.to(self.device)
307
+ self.sub_timesteps = []
308
+ for t in self.t_list:
309
+ self.sub_timesteps.append(self.timesteps[t])
310
+ sub_timesteps_tensor = torch.tensor(self.sub_timesteps, dtype=torch.long, device=self.device)
311
+ self.sub_timesteps_tensor = sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0)
312
+
313
+ c_skip_list = []
314
+ c_out_list = []
315
+ for timestep in self.sub_timesteps:
316
+ c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
317
+ c_skip_list.append(c_skip)
318
+ c_out_list.append(c_out)
319
+ self.c_skip = torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
320
+ self.c_out = torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
321
+
322
+ alpha_prod_t_sqrt_list = []
323
+ beta_prod_t_sqrt_list = []
324
+ for timestep in self.sub_timesteps:
325
+ alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
326
+ beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
327
+ alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
328
+ beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
329
+ alpha_prod_t_sqrt = (torch.stack(alpha_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
330
+ .to(dtype=self.dtype, device=self.device))
331
+ beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
332
+ .to(dtype=self.dtype, device=self.device))
333
+ self.alpha_prod_t_sqrt = alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
334
+ self.beta_prod_t_sqrt = beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
335
+
336
+ noise_lvs = ((1 - self.scheduler.alphas_cumprod.to(self.device)[self.sub_timesteps_tensor]) ** 0.5)
337
+ self.noise_lvs = noise_lvs[None, :, None, None, None]
338
+ self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
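+
+ # Note: noise_lvs[i] = sqrt(1 - alpha_bar(t_i)) for each scheduled timestep.
+ # With the default schedule [0, 4, 12, 25, 37] these are approximately
+ # [0.9977, 0.9912, 0.9735, 0.8499, 0.5840]; process_mask() below uses them as
+ # thresholds when quantizing the blurred masks.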
339
+
340
+ @torch.no_grad()
341
+ def get_text_prompts(self, image: Image.Image) -> str:
342
+ r"""A convenient method to extract text prompt from an image.
343
+
344
+ This is called if the user does not provide a background prompt but only
345
+ the background image. We use BLIP-2 to automatically generate prompts.
346
+
347
+ Args:
348
+ image (Image.Image): A PIL image.
349
+
350
+ Returns:
351
+ A single string of text prompt.
352
+ """
353
+ question = 'Question: What are in the image? Answer:'
354
+ inputs = self.i2t_processor(image, question, return_tensors='pt')
355
+ out = self.i2t_model.generate(**inputs, max_new_tokens=77)
356
+ prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
357
+ return prompt
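+
+ # Illustrative example (the caption varies with the image and the BLIP-2
+ # checkpoint): a photo of a lake might yield 'a lake surrounded by mountains',
+ # which is then used as the background prompt when the user provides none.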
358
+
359
+ @torch.no_grad()
360
+ def encode_imgs(
361
+ self,
362
+ imgs: torch.Tensor,
363
+ generator: Optional[torch.Generator] = None,
364
+ add_noise: bool = False,
365
+ ) -> torch.Tensor:
366
+ r"""A wrapper function for VAE encoder of the latent diffusion model.
367
+
368
+ Args:
369
+ imgs (torch.Tensor): An image to get StableDiffusion latents.
370
+ Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
371
+ generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
372
+ add_noise (bool): Turn this on for a noisy latent.
373
+
374
+ Returns:
375
+ An image latent embedding with 1/8 size (depending on the auto-
376
+ encoder). Shape: (B, 4, H//8, W//8).
377
+ """
378
+ def _retrieve_latents(
379
+ encoder_output: torch.Tensor,
380
+ generator: Optional[torch.Generator] = None,
381
+ sample_mode: str = 'sample',
382
+ ):
383
+ if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
384
+ return encoder_output.latent_dist.sample(generator)
385
+ elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
386
+ return encoder_output.latent_dist.mode()
387
+ elif hasattr(encoder_output, 'latents'):
388
+ return encoder_output.latents
389
+ else:
390
+ raise AttributeError('[ERROR] Could not access latents of provided encoder_output')
391
+
392
+ imgs = 2 * imgs - 1
393
+ latents = self.vae.config.scaling_factor * _retrieve_latents(self.vae.encode(imgs), generator=generator)
394
+ if add_noise:
395
+ latents = self.alpha_prod_t_sqrt[0] * latents + self.beta_prod_t_sqrt[0] * self.init_noise[0]
396
+ return latents
397
+
398
+ @torch.no_grad()
399
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
400
+ r"""A wrapper function for VAE decoder of the latent diffusion model.
401
+
402
+ Args:
403
+ latents (torch.Tensor): An image latent to get associated images.
404
+ Expected shape: (B, 4, H//8, W//8).
405
+
406
+ Returns:
407
+ A decoded image batch in pixel space with values in [0, 1].
408
+ Shape: (B, 3, H, W).
409
+ """
410
+ latents = 1 / self.vae.config.scaling_factor * latents
411
+ imgs = self.vae.decode(latents).sample
412
+ imgs = (imgs / 2 + 0.5).clip_(0, 1)
413
+ return imgs
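+
+ # Round-trip sketch (illustrative; `img` is an assumed (B, 3, H, W) tensor in [0, 1]):
+ # latent = self.encode_imgs(img) # (B, 4, H//8, W//8)
+ # recon = self.decode_latents(latent) # (B, 3, H, W), values in [0, 1]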
414
+
415
+ @torch.no_grad()
416
+ def update_background(
417
+ self,
418
+ image: Optional[Image.Image] = None,
419
+ prompt: Optional[str] = None,
420
+ negative_prompt: Optional[str] = None,
421
+ ) -> bool:
422
+ flag_changed = False
423
+ if image is not None:
424
+ image_ = image.resize((self.width, self.height))
425
+ prompt = self.get_text_prompts(image_) if prompt is None else prompt
426
+ negative_prompt = '' if negative_prompt is None else negative_prompt
427
+ embed = self.pipe.encode_prompt(
428
+ prompt=[prompt],
429
+ device=self.device,
430
+ num_images_per_prompt=1,
431
+ do_classifier_free_guidance=(self.guidance_scale > 1.0),
432
+ negative_prompt=[negative_prompt],
433
+ ) # ((1, 77, 768): cond, (1, 77, 768): uncond)
434
+
435
+ self.state['background'].image = image
436
+ self.state['background'].latent = (
437
+ self.encode_imgs(T.ToTensor()(image_)[None].to(self.device, self.dtype))
438
+ ) # (1, 3, H, W)
439
+ self.state['background'].prompt = prompt
440
+ self.state['background'].negative_prompt = negative_prompt
441
+ self.state['background'].embed = embed
442
+
443
+ if self.bootstrap_steps[0] > 0:
444
+ mix_ratio = self.bootstrap_mix_ratios[:, None, None, None]
445
+ self.bootstrap_latent = mix_ratio * self.white + (1.0 - mix_ratio) * self.state['background'].latent
446
+
447
+ self.ready_checklist['background_registered'] = True
448
+ flag_changed = True
449
+ else:
450
+ if not self.ready_checklist['background_registered']:
451
+ print('[WARNING] Register background image first! Request ignored.')
452
+ return False
453
+
454
+ if prompt is not None:
455
+ self.background.prompt = prompt
456
+ flag_changed = True
457
+ if negative_prompt is not None:
458
+ self.background.negative_prompt = negative_prompt
459
+ flag_changed = True
460
+ if flag_changed:
461
+ self.background.embed = self.pipe.encode_prompt(
462
+ prompt=[self.background.prompt],
463
+ device=self.device,
464
+ num_images_per_prompt=1,
465
+ do_classifier_free_guidance=(self.guidance_scale > 1.0),
466
+ negative_prompt=[self.background.negative_prompt],
467
+ ) # ((1, 77, 768): cond, (1, 77, 768): uncond)
468
+
469
+ self.ready_checklist['flushed'] = not flag_changed
470
+ return flag_changed
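+
+ # Usage sketch (illustrative; the file name is a placeholder):
+ # self.update_background(Image.open('bg.png'), prompt='a starry night sky')
+ # self.update_background(Image.open('bg.png')) # prompt auto-generated with BLIP-2
+ # self.update_background(prompt='a blue sky') # re-encode the prompt for the registered image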
471
+
472
+ @torch.no_grad()
473
+ def process_mask(
474
+ self,
475
+ masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
476
+ strength: Optional[Union[torch.Tensor, float]] = None,
477
+ std: Optional[Union[torch.Tensor, float]] = None,
478
+ ) -> Tuple[torch.Tensor]:
479
+ r"""Fast preprocess of masks for region-based generation with fine-
480
+ grained controls.
481
+
482
+ Mask preprocessing is done in four steps:
483
+ 1. Resizing: Resize the masks into the specified width and height by
484
+ nearest neighbor interpolation.
485
+ 2. (Optional) Ordering: Masks with higher indices are considered to
486
+ cover the masks with smaller indices. Covered masks are decayed
487
+ in its alpha value by the specified factor of
488
+ `preprocess_mask_cover_alpha`.
489
+ 3. Blurring: Gaussian blur is applied to the mask with the specified
490
+ standard deviation (isotropic). This results in gradual increase of
491
+ masked region as the timesteps evolve, naturally blending fore-
492
+ ground and the predesignated background. Not strictly required if
493
+ you want to produce images from scratch without a background.
494
+ 4. Quantization: Split the real-numbered masks of value between [0, 1]
495
+ into predefined noise levels for each quantized scheduling step of
496
+ the diffusion sampler. For example, if the diffusion model sampler
497
+ has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
498
+ is the default noise level of this module with schedule [0, 4, 12,
499
+ 25, 37], the masks are split into binary masks whose values are
500
+ greater than these levels. This results in gradual increase of mask
501
+ region as the timesteps increase. Details are described in our
502
+ paper at https://arxiv.org/pdf/2403.09055.pdf.
503
+
504
+ On the Three Modes of `mask_type`:
505
+ `self.mask_type` is predefined at the initialization stage of this
506
+ pipeline. Three possible modes are available: 'discrete', 'semi-
507
+ continuous', and 'continuous'. These define the mask quantization
508
+ modes we use. Basically, this (subtly) controls the smoothness of
509
+ foreground-background blending. Continuous mode produces nonbinary
510
+ masks to further blend foreground and background latents by linear-
511
+ ly interpolating between them. Semi-continuous mode only applies the
512
+ continuous mask at the last step of the LCM sampler. Due to the
513
+ large step size of the LCM scheduler, we find that our continuous
514
+ blending helps generating seamless inpainting and editing results.
515
+
516
+ Args:
517
+ masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
518
+ strength (Optional[Union[torch.Tensor, float]]): Mask strength that
519
+ overrides the default value. A globally multiplied factor to
520
+ the mask at the initial stage of processing. Can be applied
521
+ separately for each mask.
522
+ std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
523
+ kernel's standard deviation. Overrides the default value. Can
524
+ be applied separately for each mask.
525
+
526
+ Returns: A tuple of tensors.
527
+ - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
528
+ binary masks (see the explanation on `mask_type` above) for
529
+ region-based image synthesis.
530
+ - strengths: Return mask strengths for caching.
531
+ - std: Return mask blur standard deviations for caching.
532
+ - original_masks: Return original masks for caching.
533
+ """
534
+ if masks is None:
535
+ kwargs = {'dtype': self.dtype, 'device': self.device}
536
+ original_masks = torch.zeros((0, 1, self.latent_height, self.latent_width), dtype=self.dtype)
537
+ masks = torch.zeros((0, self.batch_size, 1, self.latent_height, self.latent_width), **kwargs)
538
+ strength = torch.zeros((0,), **kwargs)
539
+ std = torch.zeros((0,), **kwargs)
540
+ return masks, strength, std, original_masks
541
+
542
+ if isinstance(masks, Image.Image):
543
+ masks = [masks]
544
+ if isinstance(masks, (tuple, list)):
545
+ # Assumes white background for Image.Image;
546
+ # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
547
+ masks = torch.cat([
548
+ # (T.ToTensor()(mask.resize((self.width, self.height), Image.NEAREST)) < 0.5)[None, :1]
549
+ (1.0 - T.ToTensor()(mask.resize((self.width, self.height), Image.BILINEAR)))[None, :1]
550
+ for mask in masks
551
+ ], dim=0).float().clip_(0, 1)
552
+ original_masks = masks
553
+ masks = masks.float().to(self.device)
554
+
555
+ # Background mask alpha is decayed by the specified factor where foreground masks covers it.
556
+ if self.preprocess_mask_cover_alpha > 0:
557
+ masks = torch.stack([
558
+ torch.where(
559
+ masks[i + 1:].sum(dim=0) > 0,
560
+ mask * self.preprocess_mask_cover_alpha,
561
+ mask,
562
+ ) if i < len(masks) - 1 else mask
563
+ for i, mask in enumerate(masks)
564
+ ], dim=0)
565
+
566
+ if std is None:
567
+ std = self.default_mask_std
568
+ if isinstance(std, (int, float)):
569
+ std = [std] * len(masks)
570
+ if isinstance(std, (list, tuple)):
571
+ std = torch.as_tensor(std, dtype=torch.float, device=self.device)
572
+
573
+ # Mask preprocessing parameters are fetched from the default settings.
574
+ if strength is None:
575
+ strength = self.default_mask_strength
576
+ if isinstance(strength, (int, float)):
577
+ strength = [strength] * len(masks)
578
+ if isinstance(strength, (list, tuple)):
579
+ strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
580
+
581
+ if (std > 0).any():
582
+ std = torch.where(std > 0, std, 1e-5)
583
+ masks = gaussian_lowpass(masks, std)
584
+ # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
585
+ # gives unpleasant results.
586
+ masks = masks * strength[:, None, None, None]
587
+ masks = masks.unsqueeze(1).repeat(1, self.noise_lvs.shape[1], 1, 1, 1)
588
+
589
+ if self.mask_type == 'discrete':
590
+ # Discrete mode.
591
+ masks = masks > self.noise_lvs
592
+ elif self.mask_type == 'semi-continuous':
593
+ # Semi-continuous mode (continuous at the last step only).
594
+ masks = torch.cat((
595
+ masks[:, :-1] > self.noise_lvs[:, :-1],
596
+ (
597
+ (masks[:, -1:] - self.next_noise_lvs[:, -1:])
598
+ / (self.noise_lvs[:, -1:] - self.next_noise_lvs[:, -1:])
599
+ ).clip_(0, 1),
600
+ ), dim=1)
601
+ elif self.mask_type == 'continuous':
602
+ # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
603
+ # decreases continuously after the discrete mode boundary to become `0` at the
604
+ # next lower threshold.
605
+ masks = ((masks - self.next_noise_lvs) / (self.noise_lvs - self.next_noise_lvs)).clip_(0, 1)
606
+
607
+ # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
608
+ # fine-grained mask alpha channel tuning is available with this form.
609
+ # masks = masks * strength[None, :, None, None, None]
610
+
611
+ masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
612
+ masks = F.interpolate(masks, size=(self.latent_height, self.latent_width), mode='nearest')
613
+ masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
614
+ return masks, strength, std, original_masks
615
+
616
+ @torch.no_grad()
617
+ def update_layers(
618
+ self,
619
+ prompts: Union[str, List[str]],
620
+ negative_prompts: Optional[Union[str, List[str]]] = None,
621
+ suffix: Optional[str] = None, #', background is ',
622
+ prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
623
+ masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
624
+ mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
625
+ mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
626
+ ) -> None:
627
+ if not self.ready_checklist['background_registered']:
628
+ print('[WARNING] Register background image first! Request ignored.')
629
+ return
630
+
631
+ ### Register prompts
632
+
633
+ if isinstance(prompts, str):
634
+ prompts = [prompts]
635
+ if negative_prompts is None:
636
+ negative_prompts = ''
637
+ if isinstance(negative_prompts, str):
638
+ negative_prompts = [negative_prompts]
639
+ fg_prompt = [p + suffix + self.background.prompt if suffix is not None else p for p in prompts]
640
+ self.prompts = fg_prompt
641
+ self.negative_prompts = negative_prompts
642
+ p = self.num_layers
643
+
644
+ e = self.pipe.encode_prompt(
645
+ prompt=fg_prompt,
646
+ device=self.device,
647
+ num_images_per_prompt=1,
648
+ do_classifier_free_guidance=(self.guidance_scale > 1.0),
649
+ negative_prompt=negative_prompts,
650
+ ) # (p, 77, 768)
651
+
652
+ if prompt_strengths is None:
653
+ prompt_strengths = self.default_prompt_strength
654
+ if isinstance(prompt_strengths, (int, float)):
655
+ prompt_strengths = [prompt_strengths] * p
656
+ if isinstance(prompt_strengths, (list, tuple)):
657
+ prompt_strengths = torch.as_tensor(prompt_strengths, dtype=self.dtype, device=self.device)
658
+ self.prompt_strengths = prompt_strengths
659
+
660
+ s = prompt_strengths[:, None, None]
661
+ self.prompt_embeds = torch.lerp(self.background.embed[0], e[0], s).repeat(self.batch_size, 1, 1) # (T * p, 77, 768)
662
+ if self.guidance_scale > 1.0 and self.cfg_type in ('initialize', 'full'):
663
+ b = self.batch_size if self.cfg_type == 'full' else self.frame_bff_size
664
+ uncond_prompt_embeds = torch.lerp(self.background.embed[1], e[1], s).repeat(b, 1, 1) # (T * p, 77, 768)
665
+ self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0) # (2 * T * p, 77, 768)
666
+
667
+ self.sub_timesteps_tensor_ = self.sub_timesteps_tensor.repeat_interleave(p) # (T * p,)
668
+ self.init_noise_ = self.init_noise.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
669
+ self.stock_noise_ = self.stock_noise.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
670
+ self.c_out_ = self.c_out.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
671
+ self.c_skip_ = self.c_skip.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
672
+ self.beta_prod_t_sqrt_ = self.beta_prod_t_sqrt.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
673
+ self.alpha_prod_t_sqrt_ = self.alpha_prod_t_sqrt.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
674
+
675
+ ### Register new masks
676
+
677
+ if isinstance(masks, Image.Image):
678
+ masks = [masks]
679
+ n = len(masks) if masks is not None else 0
680
+
681
+ # Preprocess the provided masks.
682
+ masks, mask_strengths, mask_stds, original_masks = self.process_mask(masks, mask_strengths, mask_stds)
683
+
684
+ self.counts = masks.sum(dim=0) # (T, 1, h, w)
685
+ self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
686
+ self.masks = masks # (p, T, 1, h, w)
687
+ self.mask_strengths = mask_strengths # (p,)
688
+ self.mask_stds = mask_stds # (p,)
689
+ self.original_masks = original_masks # (p, 1, h, w)
690
+
691
+ if p > n:
692
+ # Add more masks: counts and bg_masks are not changed, but only masks are changed.
693
+ self.masks = torch.cat((
694
+ self.masks,
695
+ torch.zeros(
696
+ (p - n, self.batch_size, 1, self.latent_height, self.latent_width),
697
+ dtype=self.dtype,
698
+ device=self.device,
699
+ ),
700
+ ), dim=0)
701
+ print(f'[WARNING] Detected more prompts ({p}) than masks ({n}). '
702
+ 'Automatically adding blank masks for the additional prompts.')
703
+ elif p < n:
704
+ # Warns user to add more prompts.
705
+ print(f'[WARNING] Detected more masks ({n}) than prompts ({p}). '
706
+ 'Additional masks are ignored until more prompts are provided.')
707
+
708
+ self.ready_checklist['layers_ready'] = True
709
+ self.ready_checklist['flushed'] = False
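+
+ # Usage sketch (illustrative; mask_knight and mask_astronaut are assumed PIL
+ # images painted over a white canvas):
+ # self.update_layers(
+ # prompts=['1girl, knight', '1boy, astronaut'],
+ # negative_prompts=['lowres', 'lowres'],
+ # masks=[mask_knight, mask_astronaut],
+ # mask_stds=8.0,
+ # )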
710
+
711
+ @torch.no_grad()
712
+ def update_single_layer(
713
+ self,
714
+ idx: Optional[int] = None,
715
+ prompt: Optional[str] = None,
716
+ negative_prompt: Optional[str] = None,
717
+ suffix: Optional[str] = None, #', background is ',
718
+ prompt_strength: Optional[float] = None,
719
+ mask: Optional[Union[torch.Tensor, Image.Image]] = None,
720
+ mask_strength: Optional[float] = None,
721
+ mask_std: Optional[float] = None,
722
+ ) -> None:
723
+
724
+ ### Possible input combinations and expected behaviors
725
+
726
+ # The module will consider a layer, a pair of (prompt, mask), to be 'active' only if a prompt
727
+ # is registered. A blank mask will be assigned if no mask is provided for the 'active' layer.
728
+ # The layers should be in either of ('active', 'inactive') states. 'inactive' layers will not
729
+ # receive any input unless equipped with prompt. 'active' layers receive any input and modify
730
+ # their states accordingly. In the actual implementation, only the 'active' layers are stored
731
+ # and can be accessed through the fields. The value len(self.prompts) == self.num_layers is the number
732
+ # of 'active' layers.
733
+
734
+ # If no background is registered, all layers should be 'inactive'.
735
+ if not self.ready_checklist['background_registered']:
736
+ print('[WARNING] Register background image first! Request ignored.')
737
+ return
738
+
739
+ # The first layer-creation request must carry a prompt. If only a mask is drawn without a
740
+ # prompt, the request is simply ignored--the user will update the request soon.
741
+ if self.num_layers == 0:
742
+ if prompt is not None:
743
+ self.update_layers(
744
+ prompts=prompt,
745
+ negative_prompts=negative_prompt,
746
+ suffix=suffix,
747
+ prompt_strengths=prompt_strength,
748
+ masks=mask,
749
+ mask_strengths=mask_strength,
750
+ mask_stds=mask_std,
751
+ )
752
+ return
753
+
754
+ # Invalid request indices -> considered as a layer add request.
755
+ if idx is None or idx > self.num_layers or idx < 0:
756
+ idx = self.num_layers
757
+
758
+ # Two modes for the layer edits: 'append mode' and 'edit mode'. 'append mode' appends a new
759
+ # layer at the end of the layers list. 'edit mode' modifies internal variables for the given
760
+ # index. 'append mode' is defined by the request index and strictly requires a prompt input.
761
+ is_appending = idx == self.num_layers
762
+ if is_appending and prompt is None:
763
+ print(f'[WARNING] Creating a new prompt at index ({idx}) but found no prompt. Request ignored.')
764
+ return
765
+
766
+ ### Register prompts
767
+
768
+ # | prompt | neg_prompt | append mode (idx==len) | edit mode (idx<len) |
769
+ # | --------- | ---------- | ----------------------- | -------------------- |
770
+ # | given | given | append new prompt embed | replace prompt embed |
771
+ # | given | not given | append new prompt embed | replace prompt embed |
772
+ # | not given | given | NOT ALLOWED | replace prompt embed |
773
+ # | not given | not given | NOT ALLOWED | do nothing |
774
+
775
+ # | prompt_strength | append mode (idx==len) | edit mode (idx<len) |
776
+ # | --------------- | ---------------------- | ---------------------------------------------- |
777
+ # | given | use given strength | use given strength |
778
+ # | not given | use default strength | replace strength / if no existing, use default |
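+
+ # For example (illustrative):
+ # self.update_single_layer(idx=None, prompt='1girl, knight') # append a new layer
+ # self.update_single_layer(idx=0, mask_std=16.0) # re-blur layer 0's existing mask
+ # self.update_single_layer(idx=0, negative_prompt='lowres') # replace its negative prompt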
779
+
780
+ p = self.num_layers
781
+
782
+ flag_prompt_edited = (
783
+ prompt is not None or
784
+ negative_prompt is not None or
785
+ prompt_strength is not None
786
+ )
787
+
788
+ if flag_prompt_edited:
789
+ is_double_cond = self.guidance_scale > 1.0 and self.cfg_type in ('initialize', 'full')
790
+
791
+ # Synchronize the internal state.
792
+
793
+ # We have asserted that prompt is not None if the mode is 'appending'.
794
+ if prompt is not None:
795
+ if suffix is not None:
796
+ prompt = prompt + suffix + self.background.prompt
797
+ if is_appending:
798
+ self.prompts.append(prompt)
799
+ else:
800
+ self.prompts[idx] = prompt
801
+
802
+ if negative_prompt is not None:
803
+ if is_appending:
804
+ self.negative_prompts.append(negative_prompt)
805
+ else:
806
+ self.negative_prompts[idx] = negative_prompt
807
+ elif is_appending:
808
+ # Make sure that negative prompts are well specified.
809
+ self.negative_prompts.append('')
810
+
811
+ if is_appending:
812
+ if prompt_strength is None:
813
+ prompt_strength = self.default_prompt_strength
814
+ self.prompt_strengths = torch.cat((
815
+ self.prompt_strengths,
816
+ torch.as_tensor([prompt_strength], dtype=self.dtype, device=self.device),
817
+ ), dim=0)
818
+ elif prompt_strength is not None:
819
+ self.prompt_strengths[idx] = prompt_strength
820
+
821
+ # Edit currently stored prompt embeddings.
822
+
823
+ if is_double_cond:
824
+ uncond_prompt_embed_, prompt_embed_ = torch.chunk(self.prompt_embeds, 2, dim=0)
825
+ uncond_prompt_embed_ = rearrange(uncond_prompt_embed_, '(t p) c1 c2 -> t p c1 c2', p=p)
826
+ prompt_embed_ = rearrange(prompt_embed_, '(t p) c1 c2 -> t p c1 c2', p=p)
827
+ else:
828
+ uncond_prompt_embed_ = None
829
+ prompt_embed_ = rearrange(self.prompt_embeds, '(t p) c1 c2 -> t p c1 c2', p=p)
830
+
831
+ e = self.pipe.encode_prompt(
832
+ prompt=self.prompts[idx],
833
+ device=self.device,
834
+ num_images_per_prompt=1,
835
+ do_classifier_free_guidance=(self.guidance_scale > 1.0),
836
+ negative_prompt=self.negative_prompts[idx],
837
+ ) # (1, 77, 768), (1, 77, 768)
838
+
839
+ s = self.prompt_strengths[idx]
840
+ t = prompt_embed_.shape[0]
841
+ prompt_embed = torch.lerp(self.background.embed[0], e[0], s)[None].repeat(t, 1, 1, 1) # (t, 1, 77, 768)
842
+ if is_double_cond:
843
+ uncond_prompt_embed = torch.lerp(self.background.embed[1], e[1], s)[None].repeat(t, 1, 1, 1) # (t, 1, 77, 768)
844
+
845
+ if is_appending:
846
+ prompt_embed_ = torch.cat((prompt_embed_, prompt_embed), dim=1)
847
+ if is_double_cond:
848
+ uncond_prompt_embed_ = torch.cat((uncond_prompt_embed_, uncond_prompt_embed), dim=1)
849
+ else:
850
+ prompt_embed_[:, idx:(idx + 1)] = prompt_embed
851
+ if is_double_cond:
852
+ uncond_prompt_embed_[:, idx:(idx + 1)] = uncond_prompt_embed
853
+
854
+ self.prompt_embeds = rearrange(prompt_embed_, 't p c1 c2 -> (t p) c1 c2')
855
+ if is_double_cond:
856
+ uncond_prompt_embeds = rearrange(uncond_prompt_embed_, 't p c1 c2 -> (t p) c1 c2')
857
+ self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0) # (2 * T * p, 77, 768)
858
+
859
+ self.ready_checklist['flushed'] = False
860
+
861
+ if is_appending:
862
+ p = self.num_layers
863
+ self.sub_timesteps_tensor_ = self.sub_timesteps_tensor.repeat_interleave(p) # (T * p,)
864
+ self.init_noise_ = self.init_noise.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
865
+ self.stock_noise_ = self.stock_noise.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
866
+ self.c_out_ = self.c_out.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
867
+ self.c_skip_ = self.c_skip.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
868
+ self.beta_prod_t_sqrt_ = self.beta_prod_t_sqrt.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
869
+ self.alpha_prod_t_sqrt_ = self.alpha_prod_t_sqrt.repeat_interleave(p, dim=0) # (T * p, 1, 1, 1)
870
+
871
+ ### Register new masks
872
+
873
+ # | mask | std / str | append mode (idx==len) | edit mode (idx<len) |
874
+ # | --------- | --------- | ---------------------------- | ----------------------------- |
875
+ # | given | given | create mask with given val | create mask with given val |
876
+ # | given | not given | create mask with default val | create mask with existing val |
877
+ # | not given | given | create blank mask | replace mask with given val |
878
+ # | not given | not given | create blank mask | do nothing |
879
+
880
+ flag_nonzero_mask = False
881
+ if mask is not None:
882
+ # Mask image is given -> create mask.
883
+ mask, strength, std, original_mask = self.process_mask(mask, mask_strength, mask_std)
884
+ flag_nonzero_mask = True
885
+
886
+ elif is_appending:
887
+ # No given mask & append mode -> create blank mask.
888
+ mask = torch.zeros(
889
+ (1, self.batch_size, 1, self.latent_height, self.latent_width),
890
+ dtype=self.dtype,
891
+ device=self.device,
892
+ )
893
+ strength = torch.as_tensor([self.default_mask_strength], dtype=self.dtype, device=self.device)
894
+ std = torch.as_tensor([self.default_mask_std], dtype=self.dtype, device=self.device)
895
+ original_mask = torch.zeros((1, 1, self.latent_height, self.latent_width), dtype=self.dtype)
896
+
897
+ elif mask_std is not None or mask_strength is not None:
898
+ # No given mask & edit mode & given std / str -> replace existing mask with given std / str.
899
+ if mask_std is None:
900
+ mask_std = self.mask_stds[idx:(idx + 1)]
901
+ if mask_strength is None:
902
+ mask_strength = self.mask_strengths[idx:(idx + 1)]
903
+ mask, strength, std, original_mask = self.process_mask(
904
+ self.original_masks[idx:(idx + 1)], mask_strength, mask_std)
905
+ flag_nonzero_mask = True
906
+
907
+ else:
908
+ # No given mask & no given std & edit mode -> Do nothing.
909
+ return
910
+
911
+ if is_appending:
912
+ # Append mode.
913
+ self.masks = torch.cat((self.masks, mask), dim=0) # (p, T, 1, h, w)
914
+ self.mask_strengths = torch.cat((self.mask_strengths, strength), dim=0) # (p,)
915
+ self.mask_stds = torch.cat((self.mask_stds, std), dim=0) # (p,)
916
+ self.original_masks = torch.cat((self.original_masks, original_mask), dim=0) # (p, 1, h, w)
917
+ if flag_nonzero_mask:
918
+ self.counts = self.counts + mask[0] if hasattr(self, 'counts') else mask[0] # (T, 1, h, w)
919
+ self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
920
+ else:
921
+ # Edit mode.
922
+ if flag_nonzero_mask:
923
+ self.counts = self.counts - self.masks[idx] + mask[0] # (T, 1, h, w)
924
+ self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
925
+ self.masks[idx:(idx + 1)] = mask # (p, T, 1, h, w)
926
+ self.mask_strengths[idx:(idx + 1)] = strength # (p,)
927
+ self.mask_stds[idx:(idx + 1)] = std # (p,)
928
+ self.original_masks[idx:(idx + 1)] = original_mask # (p, 1, h, w)
929
+
930
+ # if flag_nonzero_mask:
931
+ # self.ready_checklist['flushed'] = False
932
+
933
+ @torch.no_grad()
934
+ def register_all(
935
+ self,
936
+ prompts: Union[str, List[str]],
937
+ masks: Union[Image.Image, List[Image.Image]],
938
+ background: Image.Image,
939
+ background_prompt: Optional[str] = None,
940
+ background_negative_prompt: str = '',
941
+ negative_prompts: Union[str, List[str]] = '',
942
+ suffix: Optional[str] = None, #', background is ',
943
+ prompt_strengths: float = 1.0,
944
+ mask_strengths: float = 1.0,
945
+ mask_stds: Union[torch.Tensor, float] = 10.0,
946
+ ) -> None:
947
+ # The order of this registration should not be changed!
948
+ self.update_background(background, background_prompt, background_negative_prompt)
949
+ self.update_layers(prompts, negative_prompts, suffix, prompt_strengths, masks, mask_strengths, mask_stds)
950
+
951
+ def update(
952
+ self,
953
+ background: Optional[Image.Image] = None,
954
+ background_prompt: Optional[str] = None,
955
+ background_negative_prompt: Optional[str] = None,
956
+ idx: Optional[int] = None,
957
+ prompt: Optional[str] = None,
958
+ negative_prompt: Optional[str] = None,
959
+ suffix: Optional[str] = None,
960
+ prompt_strength: Optional[float] = None,
961
+ mask: Optional[Union[torch.Tensor, Image.Image]] = None,
962
+ mask_strength: Optional[float] = None,
963
+ mask_std: Optional[float] = None,
964
+ ) -> None:
965
+ # For lazy update (to solve a minor synchronization problem with Gradio).
966
+ bq = BackgroundObject(
967
+ image=background,
968
+ prompt=background_prompt,
969
+ negative_prompt=background_negative_prompt,
970
+ )
971
+ if not bq.is_empty:
972
+ self.update_buffer['background'] = bq
973
+
974
+ lq = LayerObject(
975
+ idx=idx,
976
+ prompt=prompt,
977
+ negative_prompt=negative_prompt,
978
+ suffix=suffix,
979
+ prompt_strength=prompt_strength,
980
+ mask=mask,
981
+ mask_strength=mask_strength,
982
+ mask_std=mask_std,
983
+ )
984
+ if not lq.is_empty:
985
+ limit = self.update_buffer['layers'].maxlen
986
+
987
+ # Optimize the prompt queue: Override uncommitted layers with the same idx.
988
+ new_q = deque(maxlen=limit)
989
+ for _ in range(len(self.update_buffer['layers'])):
990
+ # Check from the newest to the oldest.
991
+ # Copy old requests only if the current query does not carry those requests.
992
+ query = self.update_buffer['layers'].pop()
993
+ overriden = lq.merge(query)
994
+ if not overriden:
995
+ new_q.appendleft(query)
996
+ self.update_buffer['layers'] = new_q
997
+
998
+ if len(self.update_buffer['layers']) == limit:
999
+ print(f'[WARNING] Maximum prompt change query limit ({limit}) is reached. '
1000
+ f'Current query {lq} will be ignored.')
1001
+ else:
1002
+ self.update_buffer['layers'].append(lq)
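+
+ # Lazy-update sketch (illustrative; `smd` denotes an instance of this class):
+ # the UI thread only enqueues requests and the generation loop applies them.
+ # smd.update(idx=0, mask=new_mask, mask_std=8.0) # queued, not applied yet
+ # smd.update(background_prompt='a starry night') # queued as a background request
+ # smd.commit() # resolves all queued requests (also invoked inside __call__)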
1003
+
1004
+ @torch.no_grad()
1005
+ def commit(self) -> None:
1006
+ flag_changed = self.is_dirty
1007
+ bq = self.update_buffer['background']
1008
+ lq = self.update_buffer['layers']
1009
+ count_bq_req = int(bq is not None and not bq.is_empty)
1010
+ count_lq_req = len(lq)
1011
+
1012
+ if flag_changed:
1013
+ print(f'[INFO] Requests found: {count_bq_req} background requests '
1014
+ f'& {count_lq_req} layer requests:\n{str(bq)}, {", ".join([str(l) for l in lq])}')
1015
+
1016
+ bq = self.update_buffer['background']
1017
+ if bq is not None:
1018
+ self.update_background(**vars(bq))
1019
+ self.update_buffer['background'] = None
1020
+
1021
+ while len(lq) > 0:
1022
+ l = lq.popleft()
1023
+ self.update_single_layer(**vars(l))
1024
+
1025
+ if flag_changed:
1026
+ print(f'[INFO] Requests resolved: {count_bq_req} background requests '
1027
+ f'& {count_lq_req} layer requests.')
1028
+
1029
+ def scheduler_step_batch(
1030
+ self,
1031
+ model_pred_batch: torch.Tensor,
1032
+ x_t_latent_batch: torch.Tensor,
1033
+ idx: Optional[int] = None,
1034
+ ) -> torch.Tensor:
1035
+ r"""Denoise-only step for reverse diffusion scheduler.
1036
+
1037
+ Args:
1038
+ model_pred_batch (torch.Tensor): Noise prediction results.
1039
+ x_t_latent_batch (torch.Tensor): Noisy latent.
1040
+ idx (Optional[int]): Instead of timesteps (in [0, 1000]-scale) use
1041
+ indices for the timesteps tensor (ranged in
1042
+ [0, len(timesteps)-1]). Specify only if a single-index, not
1043
+ stream-batched inference is what you want.
1044
+
1045
+ Returns:
1046
+ A denoised tensor with the same size as latent.
1047
+ """
1048
+ if idx is None:
1049
+ F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt_ * model_pred_batch) / self.alpha_prod_t_sqrt_
1050
+ denoised_batch = self.c_out_ * F_theta + self.c_skip_ * x_t_latent_batch
1051
+ else:
1052
+ F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch) / self.alpha_prod_t_sqrt[idx]
1053
+ denoised_batch = self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
1054
+ return denoised_batch
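+
+ # In other words, this evaluates the LCM consistency update
+ # F_theta = (x_t - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t),
+ # x_0_hat = c_out * F_theta + c_skip * x_t,
+ # where beta_prod_t_sqrt = sqrt(1 - alpha_bar_t) and alpha_prod_t_sqrt = sqrt(alpha_bar_t).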
1055
+
1056
+ def unet_step(
1057
+ self,
1058
+ x_t_latent: torch.Tensor, # (T, 4, h, w)
1059
+ idx: Optional[int] = None,
1060
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1061
+ p = self.num_layers
1062
+ x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
1063
+
1064
+ if self.bootstrap_steps[0] > 0:
1065
+ # Background bootstrapping.
1066
+ bootstrap_latent = self.scheduler.add_noise(
1067
+ self.bootstrap_latent,
1068
+ self.stock_noise,
1069
+ torch.tensor(self.sub_timesteps_tensor, device=self.device),
1070
+ )
1071
+ x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
1072
+ bootstrap_mask = (
1073
+ self.masks * self.bootstrap_steps[None, :, None, None, None]
1074
+ + (1.0 - self.bootstrap_steps[None, :, None, None, None])
1075
+ ) # (p, t, c, h, w)
1076
+ x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
1077
+ x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
1078
+
1079
+ # Centering.
1080
+ x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
1081
+
1082
+ t_list = self.sub_timesteps_tensor_ # (T * p,)
1083
+ if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
1084
+ x_t_latent_plus_uc = torch.concat([x_t_latent[:p], x_t_latent], dim=0) # (T * p + 1, 4, h, w)
1085
+ t_list = torch.concat([t_list[:p], t_list], dim=0) # (T * p + 1, 4, h, w)
1086
+ elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
1087
+ x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) # (2 * T * p, 4, h, w)
1088
+ t_list = torch.concat([t_list, t_list], dim=0) # (2 * T * p,)
1089
+ else:
1090
+ x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1091
+
1092
+ model_pred = self.unet(
1093
+ x_t_latent_plus_uc, # (B, 4, h, w)
1094
+ t_list, # (B,)
1095
+ encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1096
+ return_dict=False,
1097
+ # TODO: Add SDXL Support.
1098
+ # added_cond_kwargs={'text_embeds': add_text_embeds, 'time_ids': add_time_ids},
1099
+ )[0] # (B, 4, h, w)
1100
+
1101
+ if self.bootstrap_steps[0] > 0:
1102
+ # Uncentering.
1103
+ bootstrap_mask = rearrange(self.masks, 'p t c h w -> (t p) c h w')
1104
+ if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
1105
+ bootstrap_mask_ = torch.concat([bootstrap_mask[:p], bootstrap_mask], dim=0)
1106
+ elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
1107
+ bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
1108
+ else:
1109
+ bootstrap_mask_ = bootstrap_mask
1110
+ model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
1111
+ x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
1112
+
1113
+ # # Remove leakage (optional).
1114
+ # leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
1115
+ # leak_sigmoid = torch.sigmoid(leak / self.bootstrap_leak_sensitivity) * 2 - 1
1116
+ # fg_mask_ = fg_mask_ * leak_sigmoid
1117
+
1118
+ ### noise_pred_text, noise_pred_uncond: (T * p, 4, h, w)
1119
+ ### self.stock_noise, init_noise: (T, 4, h, w)
1120
+
1121
+ if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
1122
+ noise_pred_text = model_pred[p:]
1123
+ self.stock_noise_ = torch.concat([model_pred[:p], self.stock_noise_[p:]], dim=0)
1124
+ elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
1125
+ noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
1126
+ else:
1127
+ noise_pred_text = model_pred
1128
+ if self.guidance_scale > 1.0 and self.cfg_type in ('self', 'initialize'):
1129
+ noise_pred_uncond = self.stock_noise_ * self.delta
1130
+
1131
+ if self.guidance_scale > 1.0 and self.cfg_type != 'none':
1132
+ model_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1133
+ else:
1134
+ model_pred = noise_pred_text
1135
+
1136
+ # compute the previous noisy sample x_t -> x_t-1
1137
+ denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
1138
+
1139
+ if self.cfg_type in ('self' , 'initialize'):
1140
+ scaled_noise = self.beta_prod_t_sqrt_ * self.stock_noise_
1141
+ delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
1142
+
1143
+ # Do mask edit.
1144
+ alpha_next = torch.concat([self.alpha_prod_t_sqrt_[p:], torch.ones_like(self.alpha_prod_t_sqrt_[:p])], dim=0)
1145
+ delta_x = alpha_next * delta_x
1146
+ beta_next = torch.concat([self.beta_prod_t_sqrt_[p:], torch.ones_like(self.beta_prod_t_sqrt_[:p])], dim=0)
1147
+ delta_x = delta_x / beta_next
1148
+ init_noise = torch.concat([self.init_noise_[p:], self.init_noise_[:p]], dim=0)
1149
+ self.stock_noise_ = init_noise + delta_x
1150
+
1151
+ p2 = len(self.t_list) - 1
1152
+ background = torch.concat([
1153
+ self.scheduler.add_noise(
1154
+ self.background.latent.repeat(p2, 1, 1, 1),
1155
+ self.stock_noise[1:],
1156
+ torch.tensor(self.t_list[1:], device=self.device),
1157
+ ),
1158
+ self.background.latent,
1159
+ ], dim=0)
1160
+
1161
+ denoised_batch = rearrange(denoised_batch, '(t p) c h w -> p t c h w', p=p)
1162
+ latent = (self.masks * denoised_batch).sum(dim=0) # (T, 4, h, w)
1163
+ latent = torch.where(self.counts > 0, latent / self.counts, latent)
1164
+
1165
+ # latent = (
1166
+ # (1 - self.bg_mask) * self.mask_strengths * latent
1167
+ # + ((1 - self.bg_mask) * (1.0 - self.mask_strengths) + self.bg_mask) * background
1168
+ # )
1169
+ latent = (1 - self.bg_mask) * latent + self.bg_mask * background
1170
+
1171
+ return latent
1172
+
1173
+ @torch.no_grad()
1174
+ def __call__(
1175
+ self,
1176
+ no_decode: bool = False,
1177
+ ignore_check_ready: bool = False,
1178
+ ) -> Optional[Union[torch.Tensor, Image.Image]]:
1179
+ if not ignore_check_ready and not self.check_ready():
1180
+ return
1181
+ if not ignore_check_ready and self.is_dirty:
1182
+ print("I'm so dirty now!")
1183
+ self.commit()
1184
+ self.flush()
1185
+
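+ # Stream batch: prepend one freshly sampled latent to the buffer of partially
+ # denoised latents, so each call advances every buffered frame by one
+ # denoising step.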
1186
+ latent = torch.randn((1, self.unet.config.in_channels, self.latent_height, self.latent_width),
1187
+ dtype=self.dtype, device=self.device) # (1, 4, h, w)
1188
+ latent = torch.cat((latent, self.x_t_latent_buffer), dim=0) # (t, 4, h, w)
1189
+ self.stock_noise = torch.cat((self.init_noise[:1], self.stock_noise[:-1]), dim=0) # (t, 4, h, w)
1190
+ if self.cfg_type in ('self', 'initialize'):
1191
+ self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 4, h, w)
1192
+
1193
+ x_0_pred_batch = self.unet_step(latent)
1194
+
1195
+ latent = x_0_pred_batch[-1:]
1196
+ self.x_t_latent_buffer = (
1197
+ self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
1198
+ + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
1199
+ )
1200
+
1201
+ # For pipeline flushing.
1202
+ if no_decode:
1203
+ return latent
1204
+
1205
+ imgs = self.decode_latents(latent.half()) # (1, 3, H, W)
1206
+ img = T.ToPILImage()(imgs[0].cpu())
1207
+ return img
1208
+
1209
+ def flush(self) -> None:
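+ # Run the pipeline once per timestep without decoding so the stream buffer is
+ # fully repopulated after prompt or mask edits.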
1210
+ for _ in self.t_list:
1211
+ self(True, True)
1212
+ self.ready_checklist['flushed'] = True
prompt_util.py ADDED
@@ -0,0 +1,154 @@
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+
4
+ quality_prompt_list = [
5
+ {
6
+ "name": "(None)",
7
+ "prompt": "{prompt}",
8
+ "negative_prompt": "nsfw, lowres",
9
+ },
10
+ {
11
+ "name": "Standard v3.0",
12
+ "prompt": "{prompt}, masterpiece, best quality",
13
+ "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
14
+ },
15
+ {
16
+ "name": "Standard v3.1",
17
+ "prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
18
+ "negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
19
+ },
20
+ {
21
+ "name": "Light v3.1",
22
+ "prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
23
+ "negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
24
+ },
25
+ {
26
+ "name": "Heavy v3.1",
27
+ "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
28
+ "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
29
+ },
30
+ ]
31
+
32
+ style_list = [
33
+ {
34
+ "name": "(None)",
35
+ "prompt": "{prompt}",
36
+ "negative_prompt": "",
37
+ },
38
+ {
39
+ "name": "Cinematic",
40
+ "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
41
+ "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
42
+ },
43
+ {
44
+ "name": "Photographic",
45
+ "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
46
+ "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
47
+ },
48
+ {
49
+ "name": "Anime",
50
+ "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
51
+ "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
52
+ },
53
+ {
54
+ "name": "Manga",
55
+ "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
56
+ "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
57
+ },
58
+ {
59
+ "name": "Digital Art",
60
+ "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
61
+ "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
62
+ },
63
+ {
64
+ "name": "Pixel art",
65
+ "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
66
+ "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
67
+ },
68
+ {
69
+ "name": "Fantasy art",
70
+ "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
71
+ "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
72
+ },
73
+ {
74
+ "name": "Neonpunk",
75
+ "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
76
+ "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
77
+ },
78
+ {
79
+ "name": "3D Model",
80
+ "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
81
+ "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
82
+ },
83
+ ]
84
+
85
+
86
+ _style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
87
+ _quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
88
+
89
+
90
+ def preprocess_prompt(
91
+ positive: str,
92
+ negative: str = "",
93
+ style_dict: Dict[str, Tuple[str, str]] = _quality_dict,
94
+ style_name: str = "Standard v3.1", # "Heavy v3.1"
95
+ add_style: bool = True,
96
+ ) -> Tuple[str, str]:
97
+ p, n = style_dict.get(style_name, style_dict["(None)"])
98
+
99
+ if add_style and positive.strip():
100
+ formatted_positive = p.format(prompt=positive)
101
+ else:
102
+ formatted_positive = positive
103
+
104
+ combined_negative = n
105
+ if negative.strip():
106
+ if combined_negative:
107
+ combined_negative += ", " + negative
108
+ else:
109
+ combined_negative = negative
110
+
111
+ return formatted_positive, combined_negative
112
+
113
+
114
+ def preprocess_prompts(
115
+ positives: List[str],
116
+ negatives: List[str] = None,
117
+ style_dict = _style_dict,
118
+ style_name: str = "Manga", # "(None)"
119
+ quality_dict = _quality_dict,
120
+ quality_name: str = "Standard v3.1", # "Heavy v3.1"
121
+ add_style: bool = True,
122
+ add_quality_tags = True,
123
+ ) -> Tuple[List[str], List[str]]:
124
+ if negatives is None:
125
+ negatives = ['' for _ in positives]
126
+
127
+ positives_ = []
128
+ negatives_ = []
129
+ for pos, neg in zip(positives, negatives):
130
+ pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
131
+ pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
132
+ positives_.append(pos)
133
+ negatives_.append(neg)
134
+ return positives_, negatives_
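+
+ # Illustrative usage (example values only): quality tags are applied first,
+ # then the style template wraps the result; the final negative prompt is the
+ # style negative followed by the quality negative.
+ # pos, neg = preprocess_prompts(['a boy riding an otter'], style_name='Anime', quality_name='Standard v3.1')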
135
+
136
+
137
+ def print_prompts(
138
+ positives: Union[str, List[str]],
139
+ negatives: Union[str, List[str]],
140
+ has_background: bool = False,
141
+ ) -> None:
142
+ if isinstance(positives, str):
143
+ positives = [positives]
144
+ if isinstance(negatives, str):
145
+ negatives = [negatives]
146
+
147
+ for i, prompt in enumerate(positives):
148
+ prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
149
+ if has_background else f'Prompt{i + 1}')
150
+ print(prefix + ': ' + prompt)
151
+ for i, prompt in enumerate(negatives):
152
+ prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
153
+ if has_background else f'Negative Prompt{i + 1}')
154
+ print(prefix + ': ' + prompt)
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch==2.0.1
+ torchvision
+ xformers==0.0.22
+ einops
+ diffusers
+ transformers
+ huggingface_hub[torch]
+ Pillow
+ emoji
+ numpy
+ tqdm
+ jupyterlab
+ gradio @ https://gradio-builds.s3.amazonaws.com/7129aa5719aaa95a75397a83d3e1f3b72adf8050/gradio-4.26.0-py3-none-any.whl
timer/LICENSE_timer.txt ADDED
@@ -0,0 +1,21 @@
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2024 Jonathan Trancozo (https://codepen.io/jtrancozo/pen/mEoEVw)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
timer/index.html ADDED
@@ -0,0 +1,95 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en" >
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>CodePen - #2 - Project Deadline - SVG animation with CSS3</title>
6
+ <link href='https://fonts.googleapis.com/css?family=Oswald' rel='stylesheet' type='text/css'>
7
+
8
+ <meta property="og:image" content="https://i.imgur.com/9xiPyyv.png" /><link rel="stylesheet" href="./style.css">
9
+
10
+ </head>
11
+ <body>
12
+ <!-- partial:index.partial.html -->
13
+ <div id="deadline">
14
+ <svg preserveAspectRatio="none" id="line" viewBox="0 0 581 158" enable-background="new 0 0 581 158">
15
+ <g id="fire">
16
+ <rect id="mask-fire-black" x="511" y="41" width="38" height="34"/>
17
+ <g>
18
+ <defs>
19
+ <rect id="mask_fire" x="511" y="41" width="38" height="34"/>
20
+ </defs>
21
+ <clipPath id="mask-fire_1_">
22
+ <use xlink:href="#mask_fire" overflow="visible"/>
23
+ </clipPath>
24
+ <g id="group-fire" clip-path="url(#mask-fire_1_)">
25
+ <path id="red-flame" fill="#B71342" d="M528.377,100.291c6.207,0,10.947-3.272,10.834-8.576 c-0.112-5.305-2.934-8.803-8.237-10.383c-5.306-1.581-3.838-7.9-0.79-9.707c-7.337,2.032-7.581,5.891-7.11,8.238 c0.789,3.951,7.56,4.402,5.077,9.48c-2.482,5.079-8.012,1.129-6.319-2.257c-2.843,2.233-4.78,6.681-2.259,9.703 C521.256,98.809,524.175,100.291,528.377,100.291z"/>
26
+ <path id="yellow-flame" opacity="0.71" fill="#F7B523" d="M528.837,100.291c4.197,0,5.108-1.854,5.974-5.417 c0.902-3.724-1.129-6.207-5.305-9.931c-2.396-2.137-1.581-4.176-0.565-6.32c-4.401,1.918-3.384,5.304-2.482,6.658 c1.511,2.267,2.099,2.364,0.42,5.8c-1.679,3.435-5.42,0.764-4.275-1.527c-1.921,1.512-2.373,4.04-1.528,6.563 C522.057,99.051,525.994,100.291,528.837,100.291z"/>
27
+ <path id="white-flame" opacity="0.81" fill="#FFFFFF" d="M529.461,100.291c-2.364,0-4.174-1.322-4.129-3.469 c0.04-2.145,1.117-3.56,3.141-4.198c2.022-0.638,1.463-3.195,0.302-3.925c2.798,0.821,2.89,2.382,2.711,3.332 c-0.301,1.597-2.883,1.779-1.938,3.834c0.912,1.975,3.286,0.938,2.409-0.913c1.086,0.903,1.826,2.701,0.864,3.924 C532.18,99.691,531.064,100.291,529.461,100.291z"/>
28
+ </g>
29
+ </g>
30
+ </g>
31
+ <g id="progress-trail">
32
+ <path fill="#FFFFFF" d="M491.979,83.878c1.215-0.73-0.62-5.404-3.229-11.044c-2.583-5.584-5.034-10.066-7.229-8.878
33
+ c-2.854,1.544-0.192,6.286,2.979,11.628C487.667,80.917,490.667,84.667,491.979,83.878z"/>
34
+ <path fill="#FFFFFF" d="M571,76v-5h-23.608c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125
35
+ c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333
36
+ s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17
37
+ c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17c0-2.762-2.238-5-5-5h-3V76H571z"/>
38
+ <path fill="#FFFFFF" d="M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
39
+ s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
40
+ </g>
41
+ <g>
42
+ <defs>
43
+ <path id="SVGID_1_" d="M484.5,75.584c-3.172-5.342-5.833-10.084-2.979-11.628c2.195-1.188,4.646,3.294,7.229,8.878
44
+ c2.609,5.64,4.444,10.313,3.229,11.044C490.667,84.667,487.667,80.917,484.5,75.584z M571,76v-5h-23.608
45
+ c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25
46
+ v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917
47
+ c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17
48
+ c0-2.762-2.238-5-5-5h-3V76H571z M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
49
+ s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
50
+ </defs>
51
+ <clipPath id="SVGID_2_">
52
+ <use xlink:href="#SVGID_1_" overflow="visible"/>
53
+ </clipPath>
54
+ <rect id="progress-time-fill" x="-100%" y="34" clip-path="url(#SVGID_2_)" fill="#BE002A" width="586" height="103"/>
55
+ </g>
56
+
57
+ <g id="death-group">
58
+ <path id="death" fill="#BE002A" d="M-46.25,40.416c-5.42-0.281-8.349,3.17-13.25,3.918c-5.716,0.871-10.583-0.918-10.583-0.918
59
+ C-67.5,49-65.175,50.6-62.083,52c5.333,2.416,4.083,3.5,2.084,4.5c-16.5,4.833-15.417,27.917-15.417,27.917L-75.5,84.75
60
+ c-1,12.25-20.25,18.75-20.25,18.75s39.447,13.471,46.25-4.25c3.583-9.333-1.553-16.869-1.667-22.75
61
+ c-0.076-3.871,2.842-8.529,6.084-12.334c3.596-4.22,6.958-10.374,6.958-15.416C-38.125,43.186-39.833,40.75-46.25,40.416z
62
+ M-40,51.959c-0.882,3.004-2.779,6.906-4.154,6.537s-0.939-4.32,0.112-7.704c0.82-2.64,2.672-5.96,3.959-5.583
63
+ C-39.005,45.523-39.073,48.8-40,51.959z"/>
64
+ <path id="death-arm" fill="#BE002A" d="M-53.375,75.25c0,0,9.375,2.25,11.25,0.25s2.313-2.342,3.375-2.791
65
+ c1.083-0.459,4.375-1.75,4.292-4.75c-0.101-3.627,0.271-4.594,1.333-5.043c1.083-0.457,2.75-1.666,2.75-1.666
66
+ s0.708-0.291,0.5-0.875s-0.791-2.125-1.583-2.959c-0.792-0.832-2.375-1.874-2.917-1.332c-0.542,0.541-7.875,7.166-7.875,7.166
67
+ s-2.667,2.791-3.417,0.125S-49.833,61-49.833,61s-3.417,1.416-3.417,1.541s-1.25,5.834-1.25,5.834l-0.583,5.833L-53.375,75.25z"/>
68
+ <path id="death-tool" fill="#BE002A" d="M-20.996,26.839l-42.819,91.475l1.812,0.848l38.342-81.909c0,0,8.833,2.643,12.412,7.414
69
+ c5,6.668,4.75,14.084,4.75,14.084s4.354-7.732,0.083-17.666C-10,32.75-19.647,28.676-19.647,28.676l0.463-0.988L-20.996,26.839z"/>
70
+ </g>
71
+ <path id="designer-body" fill="#FEFFFE" d="M514.75,100.334c0,0,1.25-16.834-6.75-16.5c-5.501,0.229-5.583,3-10.833,1.666
72
+ c-3.251-0.826-5.084-15.75-0.834-22c4.948-7.277,12.086-9.266,13.334-7.833c2.25,2.583-2,10.833-4.5,14.167
73
+ c-2.5,3.333-1.833,10.416,0.5,9.916s8.026-0.141,10,2.25c3.166,3.834,4.916,17.667,4.916,17.667l0.917,2.5l-4,0.167L514.75,100.334z
74
+ "/>
75
+
76
+ <circle id="designer-head" fill="#FEFFFE" cx="516.083" cy="53.25" r="6.083"/>
77
+
78
+ <g id="designer-arm-grop">
79
+ <path id="designer-arm" fill="#FEFFFE" d="M505.875,64.875c0,0,5.875,7.5,13.042,6.791c6.419-0.635,11.833-2.791,13.458-4.041s2-3.5,0.25-3.875
80
+ s-11.375,5.125-16,3.25c-5.963-2.418-8.25-7.625-8.25-7.625l-2,1.125L505.875,64.875z"/>
81
+ <path id="designer-pen" fill="#FEFFFE" d="M525.75,59.084c0,0-0.423-0.262-0.969,0.088c-0.586,0.375-0.547,0.891-0.547,0.891l7.172,8.984l1.261,0.453
82
+ l-0.104-1.328L525.75,59.084z"/>
83
+ </g>
84
+ </svg>
85
+
86
+ <div class="deadline-days">
87
+ Deadline <span class="day">7</span> <span class="days">days</span>
88
+ </div>
89
+
90
+ </div>
91
+ <!-- partial -->
92
+ <script src='https://code.jquery.com/jquery-2.2.4.min.js'></script><script src="./script.js"></script>
93
+
94
+ </body>
95
+ </html>
timer/script.js ADDED
@@ -0,0 +1,73 @@
1
+ // Init
2
+ var $ = jQuery;
3
+ var animationTime = 20,
4
+ days = 7;
5
+
6
+ $(document).ready(function(){
7
+
8
+ // timer arguments:
9
+ // #1 - total animation time in seconds,
10
+ // #2 - days to deadline
11
+
12
+ $('#progress-time-fill, #death-group').css({'animation-duration': animationTime+'s'});
13
+
14
+ var deadlineAnimation = function () {
15
+ setTimeout(function(){
16
+ $('#designer-arm-grop').css({'animation-duration': '1.5s'});
17
+ },0);
18
+
19
+ setTimeout(function(){
20
+ $('#designer-arm-grop').css({'animation-duration': '1s'});
21
+ },4000);
22
+
23
+ setTimeout(function(){
24
+ $('#designer-arm-grop').css({'animation-duration': '0.7s'});
25
+ },8000);
26
+
27
+ setTimeout(function(){
28
+ $('#designer-arm-grop').css({'animation-duration': '0.3s'});
29
+ },12000);
30
+
31
+ setTimeout(function(){
32
+ $('#designer-arm-grop').css({'animation-duration': '0.2s'});
33
+ },15000);
34
+ };
35
+
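+ // Counts down one "day" every (totalTime / deadline) seconds and resets the
+ // label to the full deadline once it reaches zero.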
36
+ function timer(totalTime, deadline) {
37
+ var time = totalTime * 1000;
38
+ var dayDuration = time / deadline;
39
+ var actualDay = deadline;
40
+
41
+ var timer = setInterval(countTime, dayDuration);
42
+
43
+ function countTime() {
44
+ --actualDay;
45
+ $('.deadline-days .day').text(actualDay);
46
+
47
+ if (actualDay == 0) {
48
+ clearInterval(timer);
49
+ $('.deadline-days .day').text(deadline);
50
+ }
51
+ }
52
+ }
53
+
54
+ var deadlineText = function () {
55
+ var $el = $('.deadline-days');
56
+ var html = '<div class="mask-red"><div class="inner">' + $el.html() + '</div></div><div class="mask-white"><div class="inner">' + $el.html() + '</div></div>';
57
+ $el.html(html);
58
+ };
59
+
60
+ deadlineText();
61
+
62
+ deadlineAnimation();
63
+ timer(animationTime, days);
64
+
65
+ setInterval(function(){
66
+ timer(animationTime, days);
67
+ deadlineAnimation();
68
+
69
+ console.log('begin interval', animationTime * 1000);
70
+
71
+ }, animationTime * 1000);
72
+
73
+ });
timer/style.css ADDED
@@ -0,0 +1,504 @@
1
+ /*
2
+ Svg Projects
3
+ Author: Jonathan Trancozo
4
+ Language: HTML, CSS3 and SVG
5
+ Project_version: V1
6
+ Project_description:
7
+ [translated from pt-br]
+ For years I saw this picture and thought "this would look great animated",
+ and today I managed to express a bit of my imagination. The drawing was
+ produced in Adobe Illustrator and exported as SVG. The animations were made
+ with CSS3, mainly using [transform].
+
+ See you next time.
19
+
20
+ */
21
+
22
+ html {
23
+ font-size: 1em;
24
+ line-height: 1.4;
25
+ }
26
+
27
+ html,
28
+ body {
29
+ height: 100%;
30
+ }
31
+
32
+ body {
33
+ margin: 0;
34
+ padding: 0;
35
+ background: transparent;
36
+ }
37
+
38
+
39
+ #deadline {
40
+ width:581px;
41
+ max-width: 100%;
42
+ height:158px;
43
+ position: absolute;
44
+ top: 50%;
45
+ left: 50%;
46
+ z-index: 1;
47
+ transform: translate(-50%, -50%);
48
+ }
49
+
50
+ #deadline svg {
51
+ width: 100%;
52
+ }
53
+
54
+ #progress-time-fill {
55
+ -webkit-animation-name: progress-fill;
56
+ animation-name: progress-fill;
57
+ -webkit-animation-timing-function: linear;
58
+ animation-timing-function: linear;
59
+ -webkit-animation-iteration-count: infinite;
60
+ animation-iteration-count: infinite;
61
+ }
62
+
63
+ /* Death */
64
+ #death-group {
65
+ -webkit-animation-name: walk;
66
+ animation-name: walk;
67
+ -webkit-animation-timing-function: ease;
68
+ animation-timing-function: ease;
69
+ -webkit-animation-iteration-count: infinite;
70
+ animation-iteration-count: infinite;
71
+ transform: translateX(0);
72
+ }
73
+
74
+ #death-arm {
75
+ -webkit-animation: move-arm 3s ease infinite;
76
+ animation: move-arm 3s ease infinite;
77
+ /* transform-origin: left center; */
78
+ transform-origin: -60px 74px;
79
+ }
80
+
81
+ #death-tool {
82
+ -webkit-animation: move-tool 3s ease infinite;
83
+ animation: move-tool 3s ease infinite;
84
+ transform-origin: -48px center;
85
+ }
86
+
87
+ /* Designer */
88
+
89
+ #designer-arm-grop {
90
+ -webkit-animation: write 1.5s ease infinite;
91
+ animation: write 1.5s ease infinite;
92
+ transform: translate(0, 0) rotate(0deg) scale(1, 1);
93
+ transform-origin: 90% top;
94
+ }
95
+
96
+ .deadline-timer {
97
+ color: #fff;
98
+ text-align: center;
99
+ width: 200px;
100
+ margin: 0 auto;
101
+ position: relative;
102
+ height: 40px;
103
+ font-family: 'Oswald', sans-serif;
104
+ font-size: 18pt;
105
+ margin-top: -90px;
106
+ }
107
+
108
+ .deadline-timer .inner {
109
+ width: 200px;
110
+ position: relative;
111
+ top: 0;
112
+ left: 0;
113
+ }
114
+
115
+ .mask-red,
116
+ .mask-white {
117
+ position: absolute;
118
+ top: 0;
119
+ width: 100%;
120
+ overflow: hidden;
121
+ height: 100%;
122
+ }
123
+
124
+ @-webkit-keyframes progress-fill {
125
+ 0% {
126
+ x: -100%;
127
+ }
128
+
129
+ 100% {
130
+ x: -3%;
131
+ }
132
+ }
133
+
134
+ @keyframes progress-fill {
135
+ 0% {
136
+ x: -100%;
137
+ }
138
+
139
+ 100% {
140
+ x: -3%;
141
+ }
142
+ }
143
+
144
+ @-webkit-keyframes walk {
145
+ 0% {
146
+ transform: translateX(0);
147
+ }
148
+ 6% {
149
+ transform: translateX(0);
150
+ }
151
+ 10% {
152
+ transform: translateX(100px);
153
+ }
154
+
155
+ 15% {
156
+ transform: translateX(140px);
157
+ }
158
+
159
+ 25% {
160
+ transform: translateX(170px);
161
+ }
162
+
163
+ 35% {
164
+ transform: translateX(220px);
165
+ }
166
+
167
+ 45% {
168
+ transform: translateX(280px);
169
+ }
170
+
171
+ 55% {
172
+ transform: translateX(340px);
173
+ }
174
+
175
+ 65% {
176
+ transform: translateX(370px);
177
+ }
178
+
179
+ 75% {
180
+ transform: translateX(430px);
181
+ }
182
+
183
+ 85% {
184
+ transform: translateX(460px);
185
+ }
186
+
187
+ 100% {
188
+ transform: translateX(520px);
189
+ }
190
+ }
191
+
192
+ @keyframes walk {
193
+ 0% {
194
+ transform: translateX(0);
195
+ }
196
+ 6% {
197
+ transform: translateX(0);
198
+ }
199
+ 10% {
200
+ transform: translateX(100px);
201
+ }
202
+
203
+ 15% {
204
+ transform: translateX(140px);
205
+ }
206
+
207
+ 25% {
208
+ transform: translateX(170px);
209
+ }
210
+
211
+ 35% {
212
+ transform: translateX(220px);
213
+ }
214
+
215
+ 45% {
216
+ transform: translateX(280px);
217
+ }
218
+
219
+ 55% {
220
+ transform: translateX(340px);
221
+ }
222
+
223
+ 65% {
224
+ transform: translateX(370px);
225
+ }
226
+
227
+ 75% {
228
+ transform: translateX(430px);
229
+ }
230
+
231
+ 85% {
232
+ transform: translateX(460px);
233
+ }
234
+
235
+ 100% {
236
+ transform: translateX(520px);
237
+ }
238
+ }
239
+
240
+ @-webkit-keyframes move-arm {
241
+ 0% {
242
+ transform: rotate(0);
243
+ }
244
+
245
+ 5% {
246
+ transform: rotate(0);
247
+ }
248
+
249
+ 9% {
250
+ transform: rotate(40deg);
251
+ }
252
+
253
+ 80% {
254
+ transform: rotate(0);
255
+ }
256
+ }
257
+
258
+ @keyframes move-arm {
259
+ 0% {
260
+ transform: rotate(0);
261
+ }
262
+
263
+ 5% {
264
+ transform: rotate(0);
265
+ }
266
+
267
+ 9% {
268
+ transform: rotate(40deg);
269
+ }
270
+
271
+ 80% {
272
+ transform: rotate(0);
273
+ }
274
+ }
275
+
276
+ @-webkit-keyframes move-tool {
277
+ 0% {
278
+ transform: rotate(0);
279
+ }
280
+
281
+ 5% {
282
+ transform: rotate(0);
283
+ }
284
+
285
+ 9% {
286
+ transform: rotate(50deg);
287
+ }
288
+
289
+ 80% {
290
+ transform: rotate(0);
291
+ }
292
+ }
293
+
294
+ @keyframes move-tool {
295
+ 0% {
296
+ transform: rotate(0);
297
+ }
298
+
299
+ 5% {
300
+ transform: rotate(0);
301
+ }
302
+
303
+ 9% {
304
+ transform: rotate(50deg);
305
+ }
306
+
307
+ 80% {
308
+ transform: rotate(0);
309
+ }
310
+ }
311
+
312
+ /* Design animations */
313
+
314
+ @-webkit-keyframes write {
315
+ 0% {
316
+ transform: translate(0, 0) rotate(0deg) scale(1, 1);
317
+ }
318
+
319
+ 16% {
320
+ transform: translate(0px, 0px) rotate(5deg) scale(0.8, 1);
321
+ }
322
+
323
+ 32% {
324
+ transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
325
+ }
326
+
327
+ 48% {
328
+ transform: translate(0px, 0px) rotate(6deg) scale(0.8, 1);
329
+ }
330
+
331
+ 65% {
332
+ transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
333
+ }
334
+
335
+ 83% {
336
+ transform: translate(0px, 0px) rotate(4deg) scale(0.8, 1);
337
+ }
338
+ }
339
+
340
+ @keyframes write {
341
+ 0% {
342
+ transform: translate(0, 0) rotate(0deg) scale(1, 1);
343
+ }
344
+
345
+ 16% {
346
+ transform: translate(0px, 0px) rotate(5deg) scale(0.8, 1);
347
+ }
348
+
349
+ 32% {
350
+ transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
351
+ }
352
+
353
+ 48% {
354
+ transform: translate(0px, 0px) rotate(6deg) scale(0.8, 1);
355
+ }
356
+
357
+ 65% {
358
+ transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
359
+ }
360
+
361
+ 83% {
362
+ transform: translate(0px, 0px) rotate(4deg) scale(0.8, 1);
363
+ }
364
+ }
365
+
366
+ @-webkit-keyframes text-red {
367
+ 0% {
368
+ width: 0%;
369
+ }
370
+
371
+ 100% {
372
+ width: 98%;
373
+ }
374
+ }
375
+
376
+ @keyframes text-red {
377
+ 0% {
378
+ width: 0%;
379
+ }
380
+
381
+ 100% {
382
+ width: 98%;
383
+ }
384
+ }
385
+
386
+ /* Flames */
387
+
388
+ /* @keyframes show-flames {
389
+ 0% {
390
+ transform: translateY(0);
391
+ }
392
+ 74% {
393
+ transform: translateY(0);
394
+ }
395
+ 80% {
396
+ transform: translateY(-30px);
397
+ }
398
+ 97% {
399
+ transform: translateY(-30px);
400
+ }
401
+ 100% {
402
+ transform: translateY(0px);
403
+ }
404
+ } */
405
+
406
+ @-webkit-keyframes show-flames {
407
+ 0% {
408
+ opacity: 0;
409
+ }
410
+ 74% {
411
+ opacity: 0;
412
+ }
413
+ 80% {
414
+ opacity: 1;
415
+ }
416
+ 99% {
417
+ opacity: 1;
418
+ }
419
+ 100% {
420
+ opacity: 0;
421
+ }
422
+ }
423
+
424
+ @keyframes show-flames {
425
+ 0% {
426
+ opacity: 0;
427
+ }
428
+ 74% {
429
+ opacity: 0;
430
+ }
431
+ 80% {
432
+ opacity: 1;
433
+ }
434
+ 99% {
435
+ opacity: 1;
436
+ }
437
+ 100% {
438
+ opacity: 0;
439
+ }
440
+ }
441
+
442
+ @-webkit-keyframes red-flame {
443
+ 0% {
444
+ transform: translateY(-30px) scale(1, 1);
445
+ }
446
+
447
+ 25% {
448
+ transform: translateY(-30px) scale(1.1, 1.1);
449
+ }
450
+
451
+ 75% {
452
+ transform: translateY(-30px) scale(0.8, 0.7);
453
+ }
454
+
455
+ 100% {
456
+ transform: translateY(-30px) scale(1, 1);
457
+ }
458
+ }
459
+
460
+ @keyframes red-flame {
461
+ 0% {
462
+ transform: translateY(-30px) scale(1, 1);
463
+ }
464
+
465
+ 25% {
466
+ transform: translateY(-30px) scale(1.1, 1.1);
467
+ }
468
+
469
+ 75% {
470
+ transform: translateY(-30px) scale(0.8, 0.7);
471
+ }
472
+
473
+ 100% {
474
+ transform: translateY(-30px) scale(1, 1);
475
+ }
476
+ }
477
+
478
+ @-webkit-keyframes yellow-flame {
479
+ 0% {
480
+ transform: translateY(-30px) scale(0.8, 0.7);
481
+ }
482
+
483
+ 50% {
484
+ transform: translateY(-30px) scale(1.1, 1.2);
485
+ }
486
+
487
+ 100% {
488
+ transform: translateY(-30px) scale(1, 1);
489
+ }
490
+ }
491
+
492
+ @keyframes yellow-flame {
493
+ 0% {
494
+ transform: translateY(-30px) scale(0.8, 0.7);
495
+ }
496
+
497
+ 50% {
498
+ transform: translateY(-30px) scale(1.1, 1.2);
499
+ }
500
+
501
+ 100% {
502
+ transform: translateY(-30px) scale(1, 1);
503
+ }
504
+ }
util.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) 2024 Jaerin Lee
2
+
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ import concurrent.futures
22
+ import time
23
+ from typing import Any, Callable, List, Literal, Tuple, Union
24
+
25
+ from PIL import Image
26
+ import numpy as np
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.cuda.amp as amp
31
+ import torchvision.transforms as T
32
+ import torchvision.transforms.functional as TF
33
+
34
+ from diffusers import (
35
+ DiffusionPipeline,
36
+ StableDiffusionPipeline,
37
+ StableDiffusionXLPipeline,
38
+ )
39
+
40
+
41
+ def seed_everything(seed: int) -> None:
42
+ torch.manual_seed(seed)
43
+ torch.cuda.manual_seed(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = True
46
+
47
+
48
+ def load_model(
49
+ model_key: str,
50
+ sd_version: Literal['1.5', 'xl'],
51
+ device: torch.device,
52
+ dtype: torch.dtype,
53
+ ) -> torch.nn.Module:
54
+ if model_key.endswith('.safetensors'):
55
+ if sd_version == '1.5':
56
+ pipeline = StableDiffusionPipeline
57
+ elif sd_version == 'xl':
58
+ pipeline = StableDiffusionXLPipeline
59
+ else:
60
+ raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
61
+ return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
62
+ try:
63
+ return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
64
+ except Exception:
65
+ return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
66
+
67
+
68
+ def get_cutoff(cutoff: float = None, scale: float = None) -> float:
69
+ if cutoff is not None:
70
+ return cutoff
71
+
72
+ if scale is not None and cutoff is None:
73
+ return 0.5 / scale
74
+
75
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
76
+
77
+
78
+ def get_scale(cutoff: float = None, scale: float = None) -> float:
79
+ if scale is not None:
80
+ return scale
81
+
82
+ if cutoff is not None and scale is None:
83
+ return 0.5 / cutoff
84
+
85
+ raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
86
+
87
+
88
+ def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
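+ # Separable filtering: the same 1D kernel is applied along the width and then
+ # the height, with replicate padding before each pass.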
89
+ assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
90
+ # assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
91
+
92
+ b, c, h, w = x.shape
93
+ ks = k.shape[-1]
94
+ k = k.view(1, 1, -1).repeat(c, 1, 1)
95
+
96
+ x = x.permute(0, 2, 1, 3)
97
+ x = x.reshape(b * h, c, w)
98
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
99
+ x = F.conv1d(x, k, groups=c)
100
+ x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
101
+ x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
102
+ x = F.conv1d(x, k, groups=c)
103
+ x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
104
+ return x
105
+
106
+
107
+ def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
108
+ assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
109
+
110
+ x = F.pad(x, (
111
+ k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
112
+ k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
113
+ ), mode='replicate')
114
+
115
+ b, c, _, _ = x.shape
116
+ if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
117
+ k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
118
+ x = F.conv2d(x, k, groups=c)
119
+ elif len(k.shape) == 3:
120
+ assert k.shape[0] == b, \
121
+ 'The number of kernels should match the batch size.'
122
+
123
+ k = k.unsqueeze(1)
124
+ x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
125
+ return x
126
+
127
+
128
+ @amp.autocast(False)
129
+ def filter_by_kernel(
130
+ x: torch.Tensor,
131
+ k: torch.Tensor,
132
+ is_batch: bool = False,
133
+ ) -> torch.Tensor:
134
+ k_dim = len(k.shape)
135
+ if k_dim == 1 or k_dim == 2 and is_batch:
136
+ return filter_2d_by_kernel_1d(x, k)
137
+ elif k_dim == 2 or k_dim == 3 and is_batch:
138
+ return filter_2d_by_kernel_2d(x, k)
139
+ else:
140
+ raise ValueError('Kernel size should be one of (1, 2, 3).')
141
+
142
+
143
+ def gen_gauss_lowpass_filter_2d(
144
+ std: torch.Tensor,
145
+ window_size: int = None,
146
+ ) -> torch.Tensor:
147
+ # Gaussian kernel size is odd in order to preserve the center.
148
+ if window_size is None:
149
+ window_size = (
150
+ 2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
151
+
152
+ y = torch.arange(
153
+ window_size, dtype=std.dtype, device=std.device
154
+ ).view(-1, 1).repeat(1, window_size)
155
+ grid = torch.stack((y.t(), y), dim=-1)
156
+ grid -= 0.5 * (window_size - 1) # (W, W)
157
+ var = (std * std).unsqueeze(-1).unsqueeze(-1)
158
+ distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
159
+ k = torch.exp(-0.5 * distsq / var)
160
+ k /= k.sum(dim=(-2, -1), keepdim=True)
161
+ return k
162
+
163
+
164
+ def gaussian_lowpass(
165
+ x: torch.Tensor,
166
+ std: Union[float, Tuple[float], torch.Tensor] = None,
167
+ cutoff: Union[float, torch.Tensor] = None,
168
+ scale: Union[float, torch.Tensor] = None,
169
+ ) -> torch.Tensor:
170
+ if std is None:
171
+ cutoff = get_cutoff(cutoff, scale)
172
+ std = 0.5 / (np.pi * cutoff)
173
+ if isinstance(std, (float, int)):
174
+ std = (std, std)
175
+ if isinstance(std, torch.Tensor):
176
+ """Using nn.functional.conv2d with Gaussian kernels built in runtime is
177
+ 80% faster than transforms.functional.gaussian_blur for individual
178
+ items.
179
+
180
+ (in GPU); However, in CPU, the result is exactly opposite. But you
181
+ won't gonna run this on CPU, right?
182
+ """
183
+ if len(list(s for s in std.shape if s != 1)) >= 2:
184
+ raise NotImplementedError(
185
+ 'Anisotropic Gaussian filter is not currently available.')
186
+
187
+ # k.shape == (B, W, W).
188
+ k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
189
+ if k.shape[0] == 1:
190
+ return filter_by_kernel(x, k[0], False)
191
+ else:
192
+ return filter_by_kernel(x, k, True)
193
+ else:
194
+ # Gaussian kernel size is odd in order to preserve the center.
195
+ window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
196
+ return TF.gaussian_blur(x, window_size, std)
197
+
198
+
199
+ def blend(
200
+ fg: Union[torch.Tensor, Image.Image],
201
+ bg: Union[torch.Tensor, Image.Image],
202
+ mask: Union[torch.Tensor, Image.Image],
203
+ std: float = 0.0,
204
+ ) -> Image.Image:
205
+ if not isinstance(fg, torch.Tensor):
206
+ fg = T.ToTensor()(fg)
207
+ if not isinstance(bg, torch.Tensor):
208
+ bg = T.ToTensor()(bg)
209
+ if not isinstance(mask, torch.Tensor):
210
+ mask = (T.ToTensor()(mask) < 0.5).float()[:1]
211
+ if std > 0:
212
+ mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
213
+ return T.ToPILImage()(fg * mask + bg * (1 - mask))
214
+
215
+
216
+ def get_panorama_views(
217
+ panorama_height: int,
218
+ panorama_width: int,
219
+ window_size: int = 64,
220
+ ) -> Tuple[List[Tuple[int, int, int, int]], torch.Tensor]:
221
+ stride = window_size // 2
222
+ is_horizontal = panorama_width > panorama_height
223
+ num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
224
+ num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
225
+ total_num_blocks = num_blocks_height * num_blocks_width
226
+
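+ # Tent-shaped blending weights: overlapping window halves ramp linearly from 0
+ # to 1 so adjacent windows sum to one, while the outer edges of border windows
+ # are clamped flat at 1.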
227
+ half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
228
+ half_rev = half_fwd.flip(0)
229
+ if window_size % 2 == 1:
230
+ half_rev = half_rev[1:]
231
+ c = torch.cat((half_fwd, half_rev))
232
+ one = torch.ones_like(c)
233
+ f = c.clone()
234
+ f[:window_size // 2] = 1
235
+ b = c.clone()
236
+ b[-(window_size // 2):] = 1
237
+
238
+ h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
239
+ w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
240
+
241
+ views = []
242
+ masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
243
+ for i in range(total_num_blocks):
244
+ hi, wi = i // num_blocks_width, i % num_blocks_width
245
+ h_start = hi * stride
246
+ h_end = min(h_start + window_size, panorama_height)
247
+ w_start = wi * stride
248
+ w_end = min(w_start + window_size, panorama_width)
249
+ views.append((h_start, h_end, w_start, w_end))
250
+
251
+ h_width = h_end - h_start
252
+ w_width = w_end - w_start
253
+ masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
254
+
255
+ # Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
256
+ return views, masks[None] # (1, n, h, w)
257
+
258
+
259
+ def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> torch.Tensor:
260
+ h, w = mask.shape[-2:]
261
+ device = mask.device
262
+ mask = mask.reshape(-1, h, w)
263
+ # assert mask.shape[0] == im.shape[0]
264
+ h_occupied = mask.sum(dim=-2) > 0
265
+ w_occupied = mask.sum(dim=-1) > 0
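+ # Weighted argmax over the occupancy indicators gives, per mask, the first and
+ # last occupied columns (l, r) and rows (t, b) of its bounding box.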
266
+ l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
267
+ r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
268
+ t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
269
+ b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
270
+ tb = (t + b + 1) // 2
271
+ lr = (l + r + 1) // 2
272
+ shifts = (tb - (h // 2), lr - (w // 2))
273
+ shifts = torch.cat(shifts, dim=1) # (p, 2)
274
+ if reverse:
275
+ shifts = shifts * -1
276
+ return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
277
+
278
+
279
+ class Streamer:
280
+ def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
281
+ self.fn = fn
282
+ self.ema_alpha = ema_alpha
283
+
284
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
285
+ self.future = self.executor.submit(fn)
286
+ self.image = None
287
+
288
+ self.prev_exec_time = 0
289
+ self.ema_exec_time = 0
290
+
291
+ @property
292
+ def throughput(self) -> float:
293
+ return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
294
+
295
+ def timed_fn(self) -> Any:
296
+ start = time.time()
297
+ res = self.fn()
298
+ end = time.time()
299
+ self.prev_exec_time = end - start
300
+ self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
301
+ return res
302
+
303
+ def __call__(self) -> Any:
304
+ if self.future.done() or self.image is None:
305
+ # get the result (the new image) and start a new task
306
+ image = self.future.result()
307
+ self.future = self.executor.submit(self.timed_fn)
308
+ self.image = image
309
+ return image
310
+ else:
311
+ # if self.fn() is not ready yet, use the previous image
312
+ # NOTE: This assumes that we have access to a previously generated image here.
313
+ # If there's no previous image (i.e., this is the first invocation), you could fall
314
+ # back to some default image or handle it differently based on your requirements.
315
+ return self.image