wwen1997 committed on
Commit
7615afe
1 Parent(s): 4d1afb9

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demos.gif filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,788 @@
1
+ import datetime
2
+ import uuid
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+ from scipy.interpolate import interp1d, PchipInterpolator
7
+ from packaging import version
8
+
9
+ import torch
10
+ import torchvision
11
+ import gradio as gr
12
+ # from moviepy.editor import *
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.utils import load_image, export_to_video, export_to_gif
15
+
16
+ import os
17
+ import sys
18
+ sys.path.insert(0, os.getcwd())
19
+ from models_diffusers.controlnet_svd import ControlNetSVDModel
20
+ from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
21
+ from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline
22
+ from gradio_demo.utils_drag import *
23
+
24
+ import warnings
25
+ print("gr file", gr.__file__)
26
+
27
+ from huggingface_hub import hf_hub_download
28
+ os.makedirs("checkpoints", exist_ok=True)
29
+ hf_hub_download(
30
+ "wwen1997/framer_512x320",
31
+ "checkpoints/framer_512x320",
32
+ token=os.environ["TOKEN"],
33
+ )
34
+
35
+
36
+ def get_args():
37
+ import argparse
38
+ parser = argparse.ArgumentParser()
39
+
40
+ parser.add_argument("--min_guidance_scale", type=float, default=1.0)
41
+ parser.add_argument("--max_guidance_scale", type=float, default=3.0)
42
+ parser.add_argument("--middle_max_guidance", type=int, default=0, choices=[0, 1])
43
+ parser.add_argument("--with_control", type=int, default=1, choices=[0, 1])
44
+
45
+ parser.add_argument("--controlnet_cond_scale", type=float, default=1.0)
46
+
47
+ parser.add_argument(
48
+ "--dataset",
49
+ type=str,
50
+ default='videoswap',
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--model", type=str,
55
+ default="checkpoints/framer_512x320",
56
+ help="Path to model.",
57
+ )
58
+
59
+ parser.add_argument("--output_dir", type=str, default="gradio_demo/outputs", help="Path to the output video.")
60
+
61
+ parser.add_argument("--seed", type=int, default=42, help="random seed.")
62
+
63
+ parser.add_argument("--noise_aug", type=float, default=0.02)
64
+
65
+ parser.add_argument("--num_frames", type=int, default=14)
66
+ parser.add_argument("--frame_interval", type=int, default=2)
67
+
68
+ parser.add_argument("--width", type=int, default=512)
69
+ parser.add_argument("--height", type=int, default=320)
70
+
71
+ parser.add_argument(
72
+ "--num_workers",
73
+ type=int,
74
+ default=8,
75
+ help=(
76
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
77
+ ),
78
+ )
79
+
80
+ args = parser.parse_args()
81
+
82
+ return args
83
+
84
+
85
+ args = get_args()
86
+ ensure_dirname(args.output_dir)
87
+
88
+
89
+ color_list = []
90
+ for i in range(20):
91
+ color = np.concatenate([np.random.random(4)*255], axis=0)
92
+ color_list.append(color)
93
+
94
+
95
+ def interpolate_trajectory(points, n_points):
96
+ x = [point[0] for point in points]
97
+ y = [point[1] for point in points]
98
+
99
+ t = np.linspace(0, 1, len(points))
100
+
101
+ # fx = interp1d(t, x, kind='cubic')
102
+ # fy = interp1d(t, y, kind='cubic')
103
+ fx = PchipInterpolator(t, x)
104
+ fy = PchipInterpolator(t, y)
105
+
106
+ new_t = np.linspace(0, 1, n_points)
107
+
108
+ new_x = fx(new_t)
109
+ new_y = fy(new_t)
110
+ new_points = list(zip(new_x, new_y))
111
+
112
+ return new_points
113
+
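+ # Hedged usage sketch (values are illustrative, not from the original code):
+ # interpolate_trajectory([(0, 0), (100, 50)], 4) densifies two user clicks into
+ # 4 evenly spaced samples, roughly [(0, 0), (33, 17), (67, 33), (100, 50)].
+ # PCHIP is preferred over the commented-out cubic interp1d because it stays
+ # monotone between clicks and avoids overshoot.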
114
+
115
+ def gen_gaussian_heatmap(imgSize=200):
116
+ circle_img = np.zeros((imgSize, imgSize), np.float32)
117
+ circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
118
+
119
+ isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
120
+
121
+ for i in range(imgSize):
122
+ for j in range(imgSize):
123
+ isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
124
+ -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
125
+
126
+ isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
127
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
128
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
129
+
130
+ return isotropicGrayscaleImage
131
+
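+ # Note: the function above builds an isotropic 2D Gaussian (sigma = 40 px on a
+ # 200x200 grid), masks it to a circle, and rescales it to uint8 [0, 255]; it is
+ # later resized and pasted around each control point as a soft point marker.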
132
+
133
+ def get_vis_image(
134
+ target_size=(512 , 512), points=None, side=20,
135
+ num_frames=14,
136
+ # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None,
137
+ ):
138
+
139
+ # images = []
140
+ vis_images = []
141
+ heatmap = gen_gaussian_heatmap()
142
+
143
+ trajectory_list = []
144
+ radius_list = []
145
+
146
+ for index, point in enumerate(points):
147
+ trajectories = [[int(i[0]), int(i[1])] for i in point]
148
+ trajectory_list.append(trajectories)
149
+
150
+ radius = 20
151
+ radius_list.append(radius)
152
+
153
+ if len(trajectory_list) == 0:
154
+ vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]
155
+ return vis_images
156
+
157
+ for idxx, point in enumerate(trajectory_list[0]):
158
+ new_img = np.zeros(target_size, np.uint8)
159
+ vis_img = new_img.copy()
160
+ # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
161
+
162
+ if idxx >= args.num_frames:
163
+ break
164
+
165
+ # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
166
+ for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)):
167
+
168
+ center_coordinate = trajectory[idxx]
169
+ trajectory_ = trajectory[:idxx]
170
+ side = min(radius, 50)
171
+
172
+ y1 = max(center_coordinate[1] - side,0)
173
+ y2 = min(center_coordinate[1] + side, target_size[0] - 1)
174
+ x1 = max(center_coordinate[0] - side, 0)
175
+ x2 = min(center_coordinate[0] + side, target_size[1] - 1)
176
+
177
+ if x2-x1>3 and y2-y1>3:
178
+ need_map = cv2.resize(heatmap, (x2-x1, y2-y1))
179
+ new_img[y1:y2, x1:x2] = need_map.copy()
180
+
181
+ if cc >= 0:
182
+ vis_img[y1:y2,x1:x2] = need_map.copy()
183
+ if len(trajectory_) == 1:
184
+ vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
185
+ else:
186
+ for itt in range(len(trajectory_)-1):
187
+ cv2.line(vis_img, (trajectory_[itt][0], trajectory_[itt][1]), (trajectory_[itt+1][0], trajectory_[itt+1][1]), (255, 255, 255), 3)
188
+
189
+ img = new_img
190
+
191
+ # Ensure all images are in RGB format
192
+ if len(img.shape) == 2: # Grayscale image
193
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
194
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB)
195
+ elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
196
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
197
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
198
+
199
+ # Convert the numpy array to a PIL image
200
+ # pil_img = Image.fromarray(img)
201
+ # images.append(pil_img)
202
+ vis_images.append(Image.fromarray(vis_img))
203
+
204
+ return vis_images
205
+
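+ # Note: get_vis_image returns one PIL image per frame, containing the Gaussian
+ # blob at each track's current position plus the trajectory drawn so far; these
+ # frames are visualized side by side with the generated video.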
206
+
207
+ def frames_to_video(frames_folder, output_video_path, fps=7):
208
+ frame_files = os.listdir(frames_folder)
209
+ # sort the frame files by their names
210
+ frame_files = sorted(frame_files, key=lambda x: int(x.split(".")[0]))
211
+
212
+ video = []
213
+ for frame_file in frame_files:
214
+ frame_path = os.path.join(frames_folder, frame_file)
215
+ frame = torchvision.io.read_image(frame_path)
216
+ video.append(frame)
217
+
218
+ video = torch.stack(video)
219
+ video = rearrange(video, 'T C H W -> T H W C')
220
+ torchvision.io.write_video(output_video_path, video, fps=fps)
221
+
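+ # Note: frames_to_video expects frames named "0.png", "1.png", ... inside
+ # frames_folder; `rearrange` comes from einops via the star import of
+ # gradio_demo.utils_drag above.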
222
+
223
+ def save_gifs_side_by_side(
224
+ batch_output,
225
+ validation_control_images,
226
+ output_folder,
227
+ target_size=(512 , 512),
228
+ duration=200,
229
+ point_tracks=None,
230
+ ):
231
+ flattened_batch_output = batch_output
232
+ def create_gif(image_list, gif_path, duration=100):
233
+ pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
234
+ pil_images = [img for img in pil_images if img is not None]
235
+ if pil_images:
236
+ pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], loop=0, duration=duration)
237
+
238
+ # also save all the pil_images
239
+ tmp_folder = gif_path.replace(".gif", "")
240
+ print(tmp_folder)
241
+ ensure_dirname(tmp_folder)
242
+ tmp_frame_list = []
243
+ for idx, pil_image in enumerate(pil_images):
244
+ tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png")
245
+ pil_image.save(tmp_frame_path)
246
+ tmp_frame_list.append(tmp_frame_path)
247
+
248
+ # also save as mp4
249
+ output_video_path = gif_path.replace(".gif", ".mp4")
250
+ frames_to_video(tmp_folder, output_video_path, fps=7)
251
+
252
+ # Creating GIFs for each image list
253
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
254
+ gif_paths = []
255
+
256
+ for idx, image_list in enumerate([validation_control_images, flattened_batch_output]):
257
+
258
+ gif_path = os.path.join(output_folder.replace("vis_gif.gif", ""), f"temp_{idx}_{timestamp}.gif")
259
+ create_gif(image_list, gif_path)
260
+ gif_paths.append(gif_path)
261
+
262
+ # also save the point_tracks
263
+ assert point_tracks is not None
264
+ point_tracks_path = gif_path.replace(".gif", ".npy")
265
+ np.save(point_tracks_path, point_tracks.cpu().numpy())
266
+
267
+ # Function to combine GIFs side by side
268
+ def combine_gifs_side_by_side(gif_paths, output_path):
269
+ print(gif_paths)
270
+ gifs = [Image.open(gif) for gif in gif_paths]
271
+
272
+ # Assuming all gifs have the same frame count and duration
273
+ frames = []
274
+ for frame_idx in range(gifs[-1].n_frames):
275
+ combined_frame = None
276
+ for gif in gifs:
277
+ if frame_idx >= gif.n_frames:
278
+ gif.seek(gif.n_frames - 1)
279
+ else:
280
+ gif.seek(frame_idx)
281
+ if combined_frame is None:
282
+ combined_frame = gif.copy()
283
+ else:
284
+ combined_frame = get_concat_h(combined_frame, gif.copy(), gap=10)
285
+ frames.append(combined_frame)
286
+
287
+ if output_path.endswith(".mp4"):
288
+ video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
289
+ video = torch.stack(video)
290
+ video = rearrange(video, 'T C H W -> T H W C')
291
+ torchvision.io.write_video(output_path, video, fps=7)
292
+ print(f"Saved video to {output_path}")
293
+ else:
294
+ frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
295
+
296
+ # Helper function to concatenate images horizontally
297
+ def get_concat_h(im1, im2, gap=10):
298
+ # # img first, heatmap second
299
+ # im1, im2 = im2, im1
300
+
301
+ dst = Image.new('RGB', (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
302
+ dst.paste(im1, (0, 0))
303
+ dst.paste(im2, (im1.width + gap, 0))
304
+ return dst
305
+
306
+ # Helper function to concatenate images vertically
307
+ def get_concat_v(im1, im2):
308
+ dst = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height))
309
+ dst.paste(im1, (0, 0))
310
+ dst.paste(im2, (0, im1.height))
311
+ return dst
312
+
313
+ # Combine the GIFs into a single file
314
+ combined_gif_path = output_folder
315
+ combine_gifs_side_by_side(gif_paths, combined_gif_path)
316
+
317
+ combined_gif_path_v = gif_path.replace(".gif", "_v.mp4")
318
+ ensure_dirname(combined_gif_path_v.replace(".mp4", ""))
319
+ combine_gifs_side_by_side(gif_paths, combined_gif_path_v)
320
+
321
+ # # Clean up temporary GIFs
322
+ # for gif_path in gif_paths:
323
+ # os.remove(gif_path)
324
+
325
+ return combined_gif_path
326
+
327
+
328
+ # Define functions
329
+ def validate_and_convert_image(image, target_size=(512 , 512)):
330
+ if image is None:
331
+ print("Encountered a None image")
332
+ return None
333
+
334
+ if isinstance(image, torch.Tensor):
335
+ # Convert PyTorch tensor to PIL Image
336
+ if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format
337
+ if image.shape[0] == 1: # Convert single-channel grayscale to RGB
338
+ image = image.repeat(3, 1, 1)
339
+ image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
340
+ image = Image.fromarray(image)
341
+ else:
342
+ print(f"Invalid image tensor shape: {image.shape}")
343
+ return None
344
+ elif isinstance(image, Image.Image):
345
+ # Resize PIL Image
346
+ image = image.resize(target_size)
347
+ else:
348
+ print("Image is not a PIL Image or a PyTorch tensor")
349
+ return None
350
+
351
+ return image
352
+
353
+
354
+ class Framer:
355
+
356
+ def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
357
+ self.device = device
358
+ self.dtype = dtype
359
+
360
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
361
+ os.path.join(args.model, "unet"),
362
+ torch_dtype=torch.float16,
363
+ low_cpu_mem_usage=True,
364
+ custom_resume=True,
365
+ )
366
+ unet = unet.to(device, dtype)
367
+
368
+ controlnet = ControlNetSVDModel.from_pretrained(
369
+ os.path.join(args.model, "controlnet"),
370
+ )
371
+ controlnet = controlnet.to(device, dtype)
372
+
373
+ if is_xformers_available():
374
+ import xformers
375
+ xformers_version = version.parse(xformers.__version__)
376
+ unet.enable_xformers_memory_efficient_attention()
377
+ # controlnet.enable_xformers_memory_efficient_attention()
378
+ else:
379
+ raise ValueError(
380
+ "xformers is not available. Make sure it is installed correctly")
381
+
382
+ pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
383
+ "stabilityai/stable-video-diffusion-img2vid-xt",
384
+ unet=unet,
385
+ controlnet=controlnet,
386
+ low_cpu_mem_usage=False,
387
+ torch_dtype=torch.float16, variant="fp16", local_files_only=True,
388
+ )
389
+ pipe.to(device)
390
+
391
+ self.pipeline = pipe
392
+ # self.pipeline.enable_model_cpu_offload()
393
+
394
+ self.height = height
395
+ self.width = width
396
+ self.args = args
397
+ self.model_length = model_length
398
+ self.use_sift = use_sift
399
+
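+ # Hedged usage sketch for this class (paths and values are illustrative;
+ # tracking_points is the gr.State holding the drag trajectories):
+ # framer = Framer("cuda:0", args, height=320, width=512, model_length=14)
+ # gif_path = framer.run("first.png", "last.png", tracking_points,
+ # controlnet_cond_scale=1.0, motion_bucket_id=100)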
400
+ def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
401
+ original_width, original_height = 512, 320 # TODO
402
+
403
+ # load_image
404
+ image = Image.open(first_frame_path).convert('RGB')
405
+ width, height = image.size
406
+ image = image.resize((self.width, self.height))
407
+
408
+ image_end = Image.open(last_frame_path).convert('RGB')
409
+ image_end = image_end.resize((self.width, self.height))
410
+
411
+ input_all_points = tracking_points.constructor_args['value']
412
+
413
+ sift_track_update = False
414
+ anchor_points_flag = None
415
+
416
+ if (len(input_all_points) == 0) and self.use_sift:
417
+ sift_track_update = True
418
+ controlnet_cond_scale = 0.5
419
+
420
+ from models_diffusers.sift_match import sift_match
421
+ from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
422
+
423
+ output_file_sift = os.path.join(args.output_dir, "sift.png")
424
+
425
+ # (f, topk, 2), f=2 (before interpolation)
426
+ pred_tracks = sift_match(
427
+ image,
428
+ image_end,
429
+ thr=0.5,
430
+ topk=5,
431
+ method="random",
432
+ output_path=output_file_sift,
433
+ )
434
+
435
+ if pred_tracks is not None:
436
+ # interpolate the tracks, following draganything gradio demo
437
+ pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=self.model_length)
438
+
439
+ anchor_points_flag = torch.zeros((self.model_length, pred_tracks.shape[1])).to(pred_tracks.device)
440
+ anchor_points_flag[0] = 1
441
+ anchor_points_flag[-1] = 1
442
+
443
+ pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2)
444
+
445
+ else:
446
+
447
+ resized_all_points = [
448
+ tuple([
449
+ tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
450
+ for e1 in e])
451
+ for e in input_all_points
452
+ ]
453
+
454
+ # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
455
+ # in image w & h scale
456
+
457
+ for idx, splited_track in enumerate(resized_all_points):
458
+ if len(splited_track) == 0:
459
+ warnings.warn("running without point trajectory control")
460
+ continue
461
+
462
+ if len(splited_track) == 1: # stationary point
463
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
464
+ splited_track = tuple([splited_track[0], displacement_point])
465
+ # interpolate the track
466
+ splited_track = interpolate_trajectory(splited_track, self.model_length)
467
+ splited_track = splited_track[:self.model_length]
468
+ resized_all_points[idx] = splited_track
469
+
470
+ pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
471
+
472
+ vis_images = get_vis_image(
473
+ target_size=(self.args.height, self.args.width),
474
+ points=pred_tracks,
475
+ num_frames=self.model_length,
476
+ )
477
+
478
+ if len(pred_tracks.shape) != 3:
479
+ print("pred_tracks.shape", pred_tracks.shape)
480
+ with_control = False
481
+ controlnet_cond_scale = 0.0
482
+ else:
483
+ with_control = True
484
+ pred_tracks = pred_tracks.permute(1, 0, 2).to(self.device, self.dtype) # (num_frames, num_points, 2)
485
+
486
+ point_embedding = None
487
+ video_frames = self.pipeline(
488
+ image,
489
+ image_end,
490
+ # trajectory control
491
+ with_control=with_control,
492
+ point_tracks=pred_tracks,
493
+ point_embedding=point_embedding,
494
+ with_id_feature=False,
495
+ controlnet_cond_scale=controlnet_cond_scale,
496
+ # others
497
+ num_frames=14,
498
+ width=width,
499
+ height=height,
500
+ # decode_chunk_size=8,
501
+ # generator=generator,
502
+ motion_bucket_id=motion_bucket_id,
503
+ fps=7,
504
+ num_inference_steps=30,
505
+ # track
506
+ sift_track_update=sift_track_update,
507
+ anchor_points_flag=anchor_points_flag,
508
+ ).frames[0]
509
+
510
+ vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
511
+ vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
512
+ vis_images = [Image.fromarray(img) for img in vis_images]
513
+
514
+ # video_frames = [img for sublist in video_frames for img in sublist]
515
+ val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
516
+ save_gifs_side_by_side(
517
+ video_frames,
518
+ vis_images[:self.model_length],
519
+ val_save_dir,
520
+ target_size=(self.width, self.height),
521
+ duration=110,
522
+ point_tracks=pred_tracks,
523
+ )
524
+
525
+ return val_save_dir
526
+
527
+
528
+ with gr.Blocks() as demo:
529
+ gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
530
+
531
+ gr.Markdown("""Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
532
+ Github Repo can be found at https://github.com/aim-uofa/Framer<br>
533
+ The template is inspired by DragAnything.""")
534
+
535
+ gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
536
+
537
+ gr.Markdown("""## Usage: <br>
538
+ 1. Upload images<br>
539
+ &ensp; 1.1. Upload the start image via the "Upload Start Image" button.<br>
540
+ &ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
541
+ 2. (Optional) Draw some drags.<br>
542
+ &ensp; 2.1. Click "Add New Drag Trajectory" to add a new motion trajectory.<br>
543
+ &ensp; 2.2. You can click several points on either the start or end image to form a path.<br>
544
+ &ensp; 2.3. Click "Delete last drag" to delete the whole latest path.<br>
545
+ &ensp; 2.4. Click "Delete last step" to delete the latest clicked control point.<br>
546
+ 3. Interpolate the images (according to the path) by clicking the "Run" button. <br>""")
547
+
548
+ # device, args, height, width, model_length
549
+ Framer = Framer("cuda:0", args, 320, 512, 14)
550
+ first_frame_path = gr.State()
551
+ last_frame_path = gr.State()
552
+ tracking_points = gr.State([])
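+ # Assumed layout (illustrative): tracking_points holds one inner list per drag
+ # trajectory, each a list of [x, y] clicks on the 512x320 canvas,
+ # e.g. [[[10, 20], [60, 80]], [[200, 150]]].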
553
+
554
+ def reset_states(first_frame_path, last_frame_path, tracking_points):
555
+ first_frame_path = gr.State()
556
+ last_frame_path = gr.State()
557
+ tracking_points = gr.State([])
558
+
559
+ return first_frame_path, last_frame_path, tracking_points
560
+
561
+
562
+ def preprocess_image(image):
563
+
564
+ image_pil = image2pil(image.name)
565
+
566
+ raw_w, raw_h = image_pil.size
567
+ # resize_ratio = max(512 / raw_w, 320 / raw_h)
568
+ # image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
569
+ # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
570
+ image_pil = image_pil.resize((512, 320), Image.BILINEAR)
571
+
572
+ first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
573
+
574
+ image_pil.save(first_frame_path)
575
+
576
+ return first_frame_path, first_frame_path, gr.State([])
577
+
578
+
579
+ def preprocess_image_end(image_end):
580
+
581
+ image_end_pil = image2pil(image_end.name)
582
+
583
+ raw_w, raw_h = image_end_pil.size
584
+ # resize_ratio = max(512 / raw_w, 320 / raw_h)
585
+ # image_end_pil = image_end_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
586
+ # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
587
+ image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
588
+
589
+ last_frame_path = os.path.join(args.output_dir, f"last_frame_{str(uuid.uuid4())[:4]}.png")
590
+
591
+ image_end_pil.save(last_frame_path)
592
+
593
+ return last_frame_path, last_frame_path, gr.State([])
594
+
595
+
596
+ def add_drag(tracking_points):
597
+ tracking_points.constructor_args['value'].append([])
598
+ return tracking_points
599
+
600
+
601
+ def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
602
+ tracking_points.constructor_args['value'].pop()
603
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
604
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
605
+ w, h = transparent_background.size
606
+ transparent_layer = np.zeros((h, w, 4))
607
+
608
+ for track in tracking_points.constructor_args['value']:
609
+ if len(track) > 1:
610
+ for i in range(len(track)-1):
611
+ start_point = track[i]
612
+ end_point = track[i+1]
613
+ vx = end_point[0] - start_point[0]
614
+ vy = end_point[1] - start_point[1]
615
+ arrow_length = np.sqrt(vx**2 + vy**2)
616
+ if i == len(track)-2:
617
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
618
+ else:
619
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
620
+ else:
621
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
622
+
623
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
624
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
625
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
626
+
627
+ return tracking_points, trajectory_map, trajectory_map_end
628
+
629
+
630
+ def delete_last_step(tracking_points, first_frame_path, last_frame_path):
631
+ tracking_points.constructor_args['value'][-1].pop()
632
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
633
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
634
+ w, h = transparent_background.size
635
+ transparent_layer = np.zeros((h, w, 4))
636
+
637
+ for track in tracking_points.constructor_args['value']:
638
+ if len(track) > 1:
639
+ for i in range(len(track)-1):
640
+ start_point = track[i]
641
+ end_point = track[i+1]
642
+ vx = end_point[0] - start_point[0]
643
+ vy = end_point[1] - start_point[1]
644
+ arrow_length = np.sqrt(vx**2 + vy**2)
645
+ if i == len(track)-2:
646
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
647
+ else:
648
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
649
+ else:
650
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
651
+
652
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
653
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
654
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
655
+
656
+ return tracking_points, trajectory_map, trajectory_map_end
657
+
658
+
659
+ def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
660
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
661
+ tracking_points.constructor_args['value'][-1].append(evt.index)
662
+
663
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
664
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
665
+
666
+ w, h = transparent_background.size
667
+ transparent_layer = 0
668
+ for idx, track in enumerate(tracking_points.constructor_args['value']):
669
+ # mask = cv2.imread(
670
+ # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
671
+ # )
672
+ mask = np.zeros((320, 512, 3))
673
+ color = color_list[idx+1]
674
+ transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
675
+
676
+ if len(track) > 1:
677
+ for i in range(len(track)-1):
678
+ start_point = track[i]
679
+ end_point = track[i+1]
680
+ vx = end_point[0] - start_point[0]
681
+ vy = end_point[1] - start_point[1]
682
+ arrow_length = np.sqrt(vx**2 + vy**2)
683
+ if i == len(track)-2:
684
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
685
+ else:
686
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
687
+ else:
688
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
689
+
690
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
691
+ alpha_coef = 0.99
692
+ im2_data = transparent_layer.getdata()
693
+ new_im2_data = [(r, g, b, int(a * alpha_coef)) for r, g, b, a in im2_data]
694
+ transparent_layer.putdata(new_im2_data)
695
+
696
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
697
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
698
+
699
+ return tracking_points, trajectory_map, trajectory_map_end
700
+
701
+ with gr.Row():
702
+ with gr.Column(scale=1):
703
+ image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
704
+ image_end_upload_button = gr.UploadButton(label="Upload End Image", file_types=["image"])
705
+ # select_area_button = gr.Button(value="Select Area with SAM")
706
+ add_drag_button = gr.Button(value="Add New Drag Trajectory")
707
+ reset_button = gr.Button(value="Reset")
708
+ run_button = gr.Button(value="Run")
709
+ delete_last_drag_button = gr.Button(value="Delete last drag")
710
+ delete_last_step_button = gr.Button(value="Delete last step")
711
+
712
+ with gr.Column(scale=7):
713
+ with gr.Row():
714
+ with gr.Column(scale=6):
715
+ input_image = gr.Image(
716
+ label="start frame",
717
+ interactive=True,
718
+ height=320,
719
+ width=512,
720
+ )
721
+
722
+ with gr.Column(scale=6):
723
+ input_image_end = gr.Image(
724
+ label="end frame",
725
+ interactive=True,
726
+ height=320,
727
+ width=512,
728
+ )
729
+
730
+ with gr.Row():
731
+ with gr.Column(scale=1):
732
+
733
+ controlnet_cond_scale = gr.Slider(
734
+ label='Control Scale',
735
+ minimum=0.0,
736
+ maximum=10,
737
+ step=0.1,
738
+ value=1.0,
739
+ )
740
+
741
+ motion_bucket_id = gr.Slider(
742
+ label='Motion Bucket',
743
+ minimum=1,
744
+ maximum=180,
745
+ step=1,
746
+ value=100,
747
+ )
748
+
749
+ with gr.Column(scale=5):
750
+ output_video = gr.Image(
751
+ label="Output Video",
752
+ height=320,
753
+ width=1152,
754
+ )
755
+
756
+
757
+ with gr.Row():
758
+ gr.Markdown("""
759
+ ## Citation
760
+ ```bibtex
761
+ @article{wang2024framer,
762
+ title={Framer: Interactive Frame Interpolation},
763
+ author={Wang, Wen and Wang, Qiuyu and Zheng, Kecheng and Ouyang, Hao and Chen, Zhekai and Gong, Biao and Chen, Hao and Shen, Yujun and Shen, Chunhua},
764
+ journal={arXiv preprint arXiv:2410.18978},
765
+ year={2024}
766
+ }
767
+ ```
768
+ """)
769
+
770
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
771
+
772
+ image_end_upload_button.upload(preprocess_image_end, image_end_upload_button, [input_image_end, last_frame_path, tracking_points])
773
+
774
+ add_drag_button.click(add_drag, tracking_points, [tracking_points, ])
775
+
776
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
777
+
778
+ delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
779
+
780
+ reset_button.click(reset_states, [first_frame_path, last_frame_path, tracking_points], [first_frame_path, last_frame_path, tracking_points])
781
+
782
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
783
+
784
+ input_image_end.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
785
+
786
+ run_button.click(Framer.run, [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id], output_video)
787
+
788
+ demo.queue().launch()
assets/demos.gif ADDED

Git LFS Details

  • SHA256: 77f65e13a8fb42a36ace4dd01dbbcc21e0d8bd65231a65a656e47f1cdc48fcf9
  • Pointer size: 132 Bytes
  • Size of remote file: 5.72 MB
gradio_demo/utils_drag.py ADDED
@@ -0,0 +1,271 @@
1
+ # -*- coding:utf-8 -*-
2
+ import os
3
+ import sys
4
+ import shutil
5
+ import logging
6
+ import colorlog
7
+ from tqdm import tqdm
8
+ import time
9
+ import yaml
10
+ import random
11
+ import importlib
12
+ from PIL import Image
13
+ from warnings import simplefilter
14
+ import imageio
15
+ import math
16
+ import collections
17
+ import json
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.optim import Adam
22
+ import torch.nn.functional as F
23
+ from torch.utils.data import DataLoader
24
+ from torch.utils.data import DataLoader, Dataset
25
+ from einops import rearrange, repeat
26
+ import torch.distributed as dist
27
+ from torchvision import datasets, transforms, utils
28
+
29
+ logging.getLogger().setLevel(logging.WARNING)
30
+ simplefilter(action='ignore', category=FutureWarning)
31
+
32
+
33
+ def get_logger(filename=None):
34
+ """
35
+ examples:
36
+ logger = get_logger('try_logging.txt')
37
+
38
+ logger.debug("Do something.")
39
+ logger.info("Start print log.")
40
+ logger.warning("Something maybe fail.")
41
+ try:
42
+ raise ValueError()
43
+ except ValueError:
44
+ logger.error("Error", exc_info=True)
45
+
46
+ tips:
47
+ Do NOT logger.info() big tensors, since the color formatting is not helpful for them.
48
+ """
49
+ logger = logging.getLogger('utils')
50
+ level = logging.DEBUG
51
+ logger.setLevel(level=level)
52
+ # Use propagate to avoid multiple loggings.
53
+ logger.propagate = False
54
+ # Remove %(levelname)s since we have colorlog to represent levelname.
55
+ format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'
56
+
57
+ streamHandler = logging.StreamHandler()
58
+ streamHandler.setLevel(level)
59
+ coloredFormatter = colorlog.ColoredFormatter(
60
+ '%(log_color)s' + format_str,
61
+ datefmt='%Y-%m-%d %H:%M:%S',
62
+ reset=True,
63
+ log_colors={
64
+ 'DEBUG': 'cyan',
65
+ # 'INFO': 'white',
66
+ 'WARNING': 'yellow',
67
+ 'ERROR': 'red',
68
+ 'CRITICAL': 'red,bg_white',
69
+ }
70
+ )
71
+
72
+ streamHandler.setFormatter(coloredFormatter)
73
+ logger.addHandler(streamHandler)
74
+
75
+ if filename:
76
+ fileHandler = logging.FileHandler(filename)
77
+ fileHandler.setLevel(level)
78
+ formatter = logging.Formatter(format_str)
79
+ fileHandler.setFormatter(formatter)
80
+ logger.addHandler(fileHandler)
81
+
82
+ # Fix multiple logging for torch.distributed
83
+ try:
84
+ class UniqueLogger:
85
+ def __init__(self, logger):
86
+ self.logger = logger
87
+ self.local_rank = torch.distributed.get_rank()
88
+
89
+ def info(self, msg, *args, **kwargs):
90
+ if self.local_rank == 0:
91
+ return self.logger.info(msg, *args, **kwargs)
92
+
93
+ def warning(self, msg, *args, **kwargs):
94
+ if self.local_rank == 0:
95
+ return self.logger.warning(msg, *args, **kwargs)
96
+
97
+ logger = UniqueLogger(logger)
98
+ # AssertionError for gpu with no distributed
99
+ # AttributeError for no gpu.
100
+ except Exception:
101
+ pass
102
+ return logger
103
+
104
+
105
+ logger = get_logger()
106
+
107
+ def split_filename(filename):
108
+ absname = os.path.abspath(filename)
109
+ dirname, basename = os.path.split(absname)
110
+ split_tmp = basename.rsplit('.', maxsplit=1)
111
+ if len(split_tmp) == 2:
112
+ rootname, extname = split_tmp
113
+ elif len(split_tmp) == 1:
114
+ rootname = split_tmp[0]
115
+ extname = None
116
+ else:
117
+ raise ValueError("programming error!")
118
+ return dirname, rootname, extname
119
+
120
+ def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
121
+ dirname, rootname, extname = split_filename(filename)
122
+ print_did_not_save_flag = True
123
+ if type:
124
+ extname = type
125
+ if not os.path.exists(dirname):
126
+ os.makedirs(dirname, exist_ok=True)
127
+
128
+ if not os.path.exists(filename) or override:
129
+ if extname in ['jpg', 'png', 'jpeg']:
130
+ utils.save_image(data, filename, **kwargs)
131
+ elif extname == 'gif':
132
+ imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
133
+ elif extname == 'txt':
134
+ if kwargs is None:
135
+ kwargs = {}
136
+ max_step = kwargs.get('max_step')
137
+ if max_step is None:
138
+ max_step = np.inf
139
+
140
+ with open(filename, 'w', encoding='utf-8') as f:
141
+ for i, e in enumerate(data):
142
+ if i < max_step:
143
+ f.write(str(e) + '\n')
144
+ else:
145
+ break
146
+ else:
147
+ raise ValueError('Do not support this type')
148
+ if printable: logger.info('Saved data to %s' % os.path.abspath(filename))
149
+ else:
150
+ if print_did_not_save_flag: logger.info(
151
+ 'Did not save data to %s because file exists and override is False' % os.path.abspath(
152
+ filename))
153
+
154
+
155
+ def file2data(filename, type=None, printable=True, **kwargs):
156
+ dirname, rootname, extname = split_filename(filename)
157
+ print_load_flag = True
158
+ if type:
159
+ extname = type
160
+
161
+ if extname in ['pth', 'ckpt']:
162
+ data = torch.load(filename, map_location=kwargs.get('map_location'))
163
+ elif extname == 'txt':
164
+ top = kwargs.get('top', None)
165
+ with open(filename, encoding='utf-8') as f:
166
+ if top:
167
+ data = [f.readline() for _ in range(top)]
168
+ else:
169
+ data = [e for e in f.read().split('\n') if e]
170
+ elif extname == 'yaml':
171
+ with open(filename, 'r') as f:
172
+ data = yaml.safe_load(f)
173
+ else:
174
+ raise ValueError('type can only support pth, ckpt, txt, yaml')
175
+ if printable:
176
+ if print_load_flag:
177
+ logger.info('Loaded data from %s' % os.path.abspath(filename))
178
+ return data
179
+
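+ # Hedged usage sketch (paths are illustrative):
+ # data2file(frames, "out/preview.gif", duration=100) # writes a GIF via imageio
+ # cfg = file2data("configs/train.yaml") # parses YAML into a Python object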
180
+
181
+ def ensure_dirname(dirname, override=False):
182
+ if os.path.exists(dirname) and override:
183
+ logger.info('Removing dirname: %s' % os.path.abspath(dirname))
184
+ try:
185
+ shutil.rmtree(dirname)
186
+ except OSError as e:
187
+ raise ValueError('Failed to delete %s because %s' % (dirname, e))
188
+
189
+ if not os.path.exists(dirname):
190
+ logger.info('Making dirname: %s' % os.path.abspath(dirname))
191
+ os.makedirs(dirname, exist_ok=True)
192
+
193
+
194
+ def import_filename(filename):
195
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
196
+ module = importlib.util.module_from_spec(spec)
197
+ sys.modules[spec.name] = module
198
+ spec.loader.exec_module(module)
199
+ return module
200
+
201
+
202
+ def adaptively_load_state_dict(target, state_dict):
203
+ target_dict = target.state_dict()
204
+
205
+ try:
206
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()}
207
+ except Exception as e:
208
+ logger.warning('load error %s', e)
209
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict}
210
+
211
+ if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \
212
+ target.state_dict()['param_groups'][0]['params']:
213
+ logger.warning('Detected mismatched params, auto-adapting state_dict to the current model')
214
+ common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params']
215
+ target_dict.update(common_dict)
216
+ target.load_state_dict(target_dict)
217
+
218
+ missing_keys = [k for k in target_dict.keys() if k not in common_dict]
219
+ unexpected_keys = [k for k in state_dict.keys() if k not in common_dict]
220
+
221
+ if len(unexpected_keys) != 0:
222
+ logger.warning(
223
+ f"Some weights of state_dict were not used in target: {unexpected_keys}"
224
+ )
225
+ if len(missing_keys) != 0:
226
+ logger.warning(
227
+ f"Some weights of state_dict are missing used in target {missing_keys}"
228
+ )
229
+ if len(unexpected_keys) == 0 and len(missing_keys) == 0:
230
+ logger.warning("Strictly Loaded state_dict.")
231
+
232
+
233
+ def set_seed(seed=42):
234
+ random.seed(seed)
235
+ os.environ['PYTHONHASHSEED'] = str(seed)
236
+ np.random.seed(seed)
237
+ torch.manual_seed(seed)
238
+ torch.cuda.manual_seed(seed)
239
+ torch.backends.cudnn.deterministic = True
240
+
241
+ def image2pil(filename):
242
+ return Image.open(filename)
243
+
244
+
245
+ def image2arr(filename):
246
+ pil = image2pil(filename)
247
+ return pil2arr(pil)
248
+
249
+
250
+ def pil2arr(pil):
251
+ if isinstance(pil, list):
252
+ arr = np.array(
253
+ [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil])
254
+ else:
255
+ arr = np.array(pil)
256
+ return arr
257
+
258
+
259
+ def arr2pil(arr):
260
+ if arr.ndim == 3:
261
+ return Image.fromarray(arr.astype('uint8'), 'RGB')
262
+ elif arr.ndim == 4:
263
+ return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)]
264
+ else:
265
+ raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim)
266
+
267
+
268
+ def notebook_show(*images):
269
+ from IPython.display import Image
270
+ from IPython.display import display
271
+ display(*[Image(e) for e in images])
models_diffusers/attention.py ADDED
@@ -0,0 +1,548 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from diffusers.utils import USE_PEFT_BACKEND
20
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
21
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
22
+ # from diffusers.models.attention_processor import Attention
23
+ from models_diffusers.attention_processor import Attention
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.lora import LoRACompatibleLinear
26
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
27
+
28
+
29
+ def _chunked_feed_forward(
30
+ ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
+ ):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ if lora_scale is None:
40
+ ff_output = torch.cat(
41
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
+ dim=chunk_dim,
43
+ )
44
+ else:
45
+ # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
+ ff_output = torch.cat(
47
+ [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
+ dim=chunk_dim,
49
+ )
50
+
51
+ return ff_output
52
+
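+ # Illustrative note: with hidden_states of shape (B, 4096, C) and chunk_size=1024
+ # along the sequence dim, the feed-forward runs on four 1024-token slices and the
+ # outputs are concatenated, lowering peak activation memory at a small speed cost.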
53
+
54
+ @maybe_allow_in_graph
55
+ class GatedSelfAttentionDense(nn.Module):
56
+ r"""
57
+ A gated self-attention dense layer that combines visual features and object features.
58
+
59
+ Parameters:
60
+ query_dim (`int`): The number of channels in the query.
61
+ context_dim (`int`): The number of channels in the context.
62
+ n_heads (`int`): The number of heads to use for attention.
63
+ d_head (`int`): The number of channels in each head.
64
+ """
65
+
66
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
67
+ super().__init__()
68
+
69
+ # we need a linear projection since we need cat visual feature and obj feature
70
+ self.linear = nn.Linear(context_dim, query_dim)
71
+
72
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
73
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
74
+
75
+ self.norm1 = nn.LayerNorm(query_dim)
76
+ self.norm2 = nn.LayerNorm(query_dim)
77
+
78
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
79
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
80
+
81
+ self.enabled = True
82
+
83
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
84
+ if not self.enabled:
85
+ return x
86
+
87
+ n_visual = x.shape[1]
88
+ objs = self.linear(objs)
89
+
90
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
91
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
92
+
93
+ return x
94
+
95
+
96
+ @maybe_allow_in_graph
97
+ class BasicTransformerBlock(nn.Module):
98
+ r"""
99
+ A basic Transformer block.
100
+
101
+ Parameters:
102
+ dim (`int`): The number of channels in the input and output.
103
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
104
+ attention_head_dim (`int`): The number of channels in each head.
105
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
106
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
107
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
108
+ num_embeds_ada_norm (:
109
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
110
+ attention_bias (:
111
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
112
+ only_cross_attention (`bool`, *optional*):
113
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
114
+ double_self_attention (`bool`, *optional*):
115
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
116
+ upcast_attention (`bool`, *optional*):
117
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
118
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
119
+ Whether to use learnable elementwise affine parameters for normalization.
120
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
121
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
122
+ final_dropout (`bool` *optional*, defaults to False):
123
+ Whether to apply a final dropout after the last feed-forward layer.
124
+ attention_type (`str`, *optional*, defaults to `"default"`):
125
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
126
+ positional_embeddings (`str`, *optional*, defaults to `None`):
127
+ The type of positional embeddings to apply to.
128
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
129
+ The maximum number of positional embeddings to apply.
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ dim: int,
135
+ num_attention_heads: int,
136
+ attention_head_dim: int,
137
+ dropout=0.0,
138
+ cross_attention_dim: Optional[int] = None,
139
+ activation_fn: str = "geglu",
140
+ num_embeds_ada_norm: Optional[int] = None,
141
+ attention_bias: bool = False,
142
+ only_cross_attention: bool = False,
143
+ double_self_attention: bool = False,
144
+ upcast_attention: bool = False,
145
+ norm_elementwise_affine: bool = True,
146
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
147
+ norm_eps: float = 1e-5,
148
+ final_dropout: bool = False,
149
+ attention_type: str = "default",
150
+ positional_embeddings: Optional[str] = None,
151
+ num_positional_embeddings: Optional[int] = None,
152
+ ):
153
+ super().__init__()
154
+ self.only_cross_attention = only_cross_attention
155
+
156
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
157
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
158
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
159
+ self.use_layer_norm = norm_type == "layer_norm"
160
+
161
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
162
+ raise ValueError(
163
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
164
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
165
+ )
166
+
167
+ if positional_embeddings and (num_positional_embeddings is None):
168
+ raise ValueError(
169
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
170
+ )
171
+
172
+ if positional_embeddings == "sinusoidal":
173
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
174
+ else:
175
+ self.pos_embed = None
176
+
177
+ # Define 3 blocks. Each block has its own normalization layer.
178
+ # 1. Self-Attn
179
+ if self.use_ada_layer_norm:
180
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
181
+ elif self.use_ada_layer_norm_zero:
182
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
183
+ else:
184
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
185
+
186
+ self.attn1 = Attention(
187
+ query_dim=dim,
188
+ heads=num_attention_heads,
189
+ dim_head=attention_head_dim,
190
+ dropout=dropout,
191
+ bias=attention_bias,
192
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
193
+ upcast_attention=upcast_attention,
194
+ )
195
+
196
+ # 2. Cross-Attn
197
+ if cross_attention_dim is not None or double_self_attention:
198
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
199
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
200
+ # the second cross attention block.
201
+ self.norm2 = (
202
+ AdaLayerNorm(dim, num_embeds_ada_norm)
203
+ if self.use_ada_layer_norm
204
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
205
+ )
206
+ self.attn2 = Attention(
207
+ query_dim=dim,
208
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
209
+ heads=num_attention_heads,
210
+ dim_head=attention_head_dim,
211
+ dropout=dropout,
212
+ bias=attention_bias,
213
+ upcast_attention=upcast_attention,
214
+ ) # is self-attn if encoder_hidden_states is none
215
+ else:
216
+ self.norm2 = None
217
+ self.attn2 = None
218
+
219
+ # 3. Feed-forward
220
+ if not self.use_ada_layer_norm_single:
221
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
222
+
223
+ self.ff = FeedForward(
224
+ dim,
225
+ dropout=dropout,
226
+ activation_fn=activation_fn,
227
+ final_dropout=final_dropout,
228
+ )
229
+
230
+ # 4. Fuser
231
+ if attention_type == "gated" or attention_type == "gated-text-image":
232
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
233
+
234
+ # 5. Scale-shift for PixArt-Alpha.
235
+ if self.use_ada_layer_norm_single:
236
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
237
+
238
+ # let chunk size default to None
239
+ self._chunk_size = None
240
+ self._chunk_dim = 0
241
+
242
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
243
+ # Sets chunk feed-forward
244
+ self._chunk_size = chunk_size
245
+ self._chunk_dim = dim
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states: torch.FloatTensor,
250
+ attention_mask: Optional[torch.FloatTensor] = None,
251
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
252
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
253
+ timestep: Optional[torch.LongTensor] = None,
254
+ cross_attention_kwargs: Dict[str, Any] = None,
255
+ class_labels: Optional[torch.LongTensor] = None,
256
+ ) -> torch.FloatTensor:
257
+ # Notice that normalization is always applied before the real computation in the following blocks.
258
+ # 0. Self-Attention
259
+ batch_size = hidden_states.shape[0]
260
+
261
+ if self.use_ada_layer_norm:
262
+ norm_hidden_states = self.norm1(hidden_states, timestep)
263
+ elif self.use_ada_layer_norm_zero:
264
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
265
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
266
+ )
267
+ elif self.use_layer_norm:
268
+ norm_hidden_states = self.norm1(hidden_states)
269
+ elif self.use_ada_layer_norm_single:
270
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
271
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
272
+ ).chunk(6, dim=1)
273
+ norm_hidden_states = self.norm1(hidden_states)
274
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
275
+ norm_hidden_states = norm_hidden_states.squeeze(1)
276
+ else:
277
+ raise ValueError("Incorrect norm used")
278
+
279
+ if self.pos_embed is not None:
280
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
281
+
282
+ # 1. Retrieve lora scale.
283
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
284
+
285
+ # 2. Prepare GLIGEN inputs
286
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
287
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
288
+
289
+ attn_output = self.attn1(
290
+ norm_hidden_states,
291
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
292
+ attention_mask=attention_mask,
293
+ **cross_attention_kwargs,
294
+ )
295
+ if self.use_ada_layer_norm_zero:
296
+ attn_output = gate_msa.unsqueeze(1) * attn_output
297
+ elif self.use_ada_layer_norm_single:
298
+ attn_output = gate_msa * attn_output
299
+
300
+ hidden_states = attn_output + hidden_states
301
+ if hidden_states.ndim == 4:
302
+ hidden_states = hidden_states.squeeze(1)
303
+
304
+ # 2.5 GLIGEN Control
305
+ if gligen_kwargs is not None:
306
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
307
+
308
+ # 3. Cross-Attention
309
+ if self.attn2 is not None:
310
+ if self.use_ada_layer_norm:
311
+ norm_hidden_states = self.norm2(hidden_states, timestep)
312
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
313
+ norm_hidden_states = self.norm2(hidden_states)
314
+ elif self.use_ada_layer_norm_single:
315
+ # For PixArt norm2 isn't applied here:
316
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
317
+ norm_hidden_states = hidden_states
318
+ else:
319
+ raise ValueError("Incorrect norm")
320
+
321
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
322
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
323
+
324
+ attn_output = self.attn2(
325
+ norm_hidden_states,
326
+ encoder_hidden_states=encoder_hidden_states,
327
+ attention_mask=encoder_attention_mask,
328
+ **cross_attention_kwargs,
329
+ )
330
+ hidden_states = attn_output + hidden_states
331
+
332
+ # 4. Feed-forward
333
+ if not self.use_ada_layer_norm_single:
334
+ norm_hidden_states = self.norm3(hidden_states)
335
+
336
+ if self.use_ada_layer_norm_zero:
337
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
338
+
339
+ if self.use_ada_layer_norm_single:
340
+ norm_hidden_states = self.norm2(hidden_states)
341
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
342
+
343
+ if self._chunk_size is not None:
344
+ # "feed_forward_chunk_size" can be used to save memory
345
+ ff_output = _chunked_feed_forward(
346
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
347
+ )
348
+ else:
349
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
350
+
351
+ if self.use_ada_layer_norm_zero:
352
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
353
+ elif self.use_ada_layer_norm_single:
354
+ ff_output = gate_mlp * ff_output
355
+
356
+ hidden_states = ff_output + hidden_states
357
+ if hidden_states.ndim == 4:
358
+ hidden_states = hidden_states.squeeze(1)
359
+
360
+ return hidden_states
361
+
362
+
363
+ @maybe_allow_in_graph
364
+ class TemporalBasicTransformerBlock(nn.Module):
365
+ r"""
366
+ A basic Transformer block for video like data.
367
+
368
+ Parameters:
369
+ dim (`int`): The number of channels in the input and output.
370
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
371
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
372
+ attention_head_dim (`int`): The number of channels in each head.
373
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ dim: int,
379
+ time_mix_inner_dim: int,
380
+ num_attention_heads: int,
381
+ attention_head_dim: int,
382
+ cross_attention_dim: Optional[int] = None,
383
+ ):
384
+ super().__init__()
385
+ self.is_res = dim == time_mix_inner_dim
386
+
387
+ self.norm_in = nn.LayerNorm(dim)
388
+
389
+ # Define 3 blocks. Each block has its own normalization layer.
390
+ # 1. Self-Attn
391
+ self.norm_in = nn.LayerNorm(dim)
392
+ self.ff_in = FeedForward(
393
+ dim,
394
+ dim_out=time_mix_inner_dim,
395
+ activation_fn="geglu",
396
+ )
397
+
398
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
399
+ self.attn1 = Attention(
400
+ query_dim=time_mix_inner_dim,
401
+ heads=num_attention_heads,
402
+ dim_head=attention_head_dim,
403
+ cross_attention_dim=None,
404
+ )
405
+
406
+ # 2. Cross-Attn
407
+ if cross_attention_dim is not None:
408
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
409
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
410
+ # the second cross attention block.
411
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
412
+ self.attn2 = Attention(
413
+ query_dim=time_mix_inner_dim,
414
+ cross_attention_dim=cross_attention_dim,
415
+ heads=num_attention_heads,
416
+ dim_head=attention_head_dim,
417
+ ) # is self-attn if encoder_hidden_states is none
418
+ else:
419
+ self.norm2 = None
420
+ self.attn2 = None
421
+
422
+ # 3. Feed-forward
423
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
424
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
425
+
426
+ # let chunk size default to None
427
+ self._chunk_size = None
428
+ self._chunk_dim = None
429
+
430
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
431
+ # Sets chunk feed-forward
432
+ self._chunk_size = chunk_size
433
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
434
+ self._chunk_dim = 1
435
+
436
+ def forward(
437
+ self,
438
+ hidden_states: torch.FloatTensor,
439
+ num_frames: int,
440
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
441
+ ) -> torch.FloatTensor:
442
+ # Notice that normalization is always applied before the real computation in the following blocks.
443
+ # 0. Self-Attention
444
+ batch_size = hidden_states.shape[0]
445
+
446
+ batch_frames, seq_length, channels = hidden_states.shape
447
+ batch_size = batch_frames // num_frames
448
+
449
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
450
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
451
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
452
+
453
+ residual = hidden_states
454
+ hidden_states = self.norm_in(hidden_states)
455
+
456
+ if self._chunk_size is not None:
457
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)  # chunk the input feed-forward (ff_in), mirroring the unchunked branch below
458
+ else:
459
+ hidden_states = self.ff_in(hidden_states)
460
+
461
+ if self.is_res:
462
+ hidden_states = hidden_states + residual
463
+
464
+ norm_hidden_states = self.norm1(hidden_states)
465
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
466
+ hidden_states = attn_output + hidden_states
467
+
468
+ # 3. Cross-Attention
469
+ if self.attn2 is not None:
470
+ norm_hidden_states = self.norm2(hidden_states)
471
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
472
+ hidden_states = attn_output + hidden_states
473
+
474
+ # 4. Feed-forward
475
+ norm_hidden_states = self.norm3(hidden_states)
476
+
477
+ if self._chunk_size is not None:
478
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
479
+ else:
480
+ ff_output = self.ff(norm_hidden_states)
481
+
482
+ if self.is_res:
483
+ hidden_states = ff_output + hidden_states
484
+ else:
485
+ hidden_states = ff_output
486
+
487
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
488
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
489
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
490
+
491
+ return hidden_states
492
+
493
+
494
+ class FeedForward(nn.Module):
495
+ r"""
496
+ A feed-forward layer.
497
+
498
+ Parameters:
499
+ dim (`int`): The number of channels in the input.
500
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
501
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
502
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
503
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
504
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
505
+ """
506
+
507
+ def __init__(
508
+ self,
509
+ dim: int,
510
+ dim_out: Optional[int] = None,
511
+ mult: int = 4,
512
+ dropout: float = 0.0,
513
+ activation_fn: str = "geglu",
514
+ final_dropout: bool = False,
515
+ ):
516
+ super().__init__()
517
+ inner_dim = int(dim * mult)
518
+ dim_out = dim_out if dim_out is not None else dim
519
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
520
+
521
+ if activation_fn == "gelu":
522
+ act_fn = GELU(dim, inner_dim)
523
+ if activation_fn == "gelu-approximate":
524
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
525
+ elif activation_fn == "geglu":
526
+ act_fn = GEGLU(dim, inner_dim)
527
+ elif activation_fn == "geglu-approximate":
528
+ act_fn = ApproximateGELU(dim, inner_dim)
529
+
530
+ self.net = nn.ModuleList([])
531
+ # project in
532
+ self.net.append(act_fn)
533
+ # project dropout
534
+ self.net.append(nn.Dropout(dropout))
535
+ # project out
536
+ self.net.append(linear_cls(inner_dim, dim_out))
537
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
538
+ if final_dropout:
539
+ self.net.append(nn.Dropout(dropout))
540
+
541
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
542
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
543
+ for module in self.net:
544
+ if isinstance(module, compatible_cls):
545
+ hidden_states = module(hidden_states, scale)
546
+ else:
547
+ hidden_states = module(hidden_states)
548
+ return hidden_states
models_diffusers/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
models_diffusers/controlnet_svd.py ADDED
@@ -0,0 +1,788 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalControlnetMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ # from diffusers.models.attention_processor import (
25
+ from models_diffusers.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ # from diffusers.models.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
35
+ from models_diffusers.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
36
+ from diffusers.models import UNetSpatioTemporalConditionModel
37
+ from einops import rearrange
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class ControlNetOutput(BaseOutput):
45
+ """
46
+ The output of [`ControlNetModel`].
47
+
48
+ Args:
49
+ down_block_res_samples (`tuple[torch.Tensor]`):
50
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
51
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
52
+ used to condition the original UNet's downsampling activations.
53
+ mid_block_res_sample (`torch.Tensor`):
54
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
55
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
56
+ Output can be used to condition the original UNet's middle block activation.
57
+ """
58
+
59
+ down_block_res_samples: Tuple[torch.Tensor]
60
+ mid_block_res_sample: torch.Tensor
61
+
62
+
63
+ class ControlNetConditioningEmbeddingSVD(nn.Module):
64
+ """
65
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
66
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
67
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
68
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
69
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
70
+ model) to encode image-space conditions ... into feature maps ..."
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ conditioning_embedding_channels: int,
76
+ conditioning_channels: int = 3,
77
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
78
+ with_id_feature: bool = False,
79
+ feature_channels: int = 160,
80
+ feature_out_channels: Tuple[int, ...] = (160, 160, 256, 256),
81
+ ):
82
+ super().__init__()
83
+
84
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
85
+
86
+ self.blocks = nn.ModuleList([])
87
+
88
+ for i in range(len(block_out_channels) - 1):
89
+ channel_in = block_out_channels[i]
90
+ channel_out = block_out_channels[i + 1]
91
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
92
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
93
+
94
+ self.conv_out = zero_module(
95
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
96
+ )
97
+
98
+ self.with_id_feature = with_id_feature
99
+
100
+ def forward(self, conditioning, point_embedding=None, point_tracks=None):
101
+ # Fold the frame dimension into the batch so the 2D convolutions below run independently per frame.
102
+ # (batch, frames, channels, height, width) -> (batch * frames, channels, height, width)
103
+ batch_size, frames, channels, height, width = conditioning.size()
104
+ conditioning = conditioning.view(batch_size * frames, channels, height, width)
105
+
106
+ embedding = self.conv_in(conditioning)
107
+ embedding = F.silu(embedding)
108
+
109
+ for block in self.blocks:
110
+ embedding = block(embedding)
111
+ embedding = F.silu(embedding)
112
+
113
+ embedding = self.conv_out(embedding)
114
+
115
+ assert not self.with_id_feature
116
+
117
+ return embedding
118
+
119
+
120
+ class ControlNetSVDModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
121
+ r"""
122
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
123
+ shaped output.
124
+
125
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
126
+ for all models (such as downloading or saving).
127
+
128
+ Parameters:
129
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
130
+ Height and width of input/output sample.
131
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
132
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
133
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
134
+ The tuple of downsample blocks to use.
135
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
136
+ The tuple of upsample blocks to use.
137
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
138
+ The tuple of output channels for each block.
139
+ addition_time_embed_dim: (`int`, defaults to 256):
140
+ Dimension used to encode the additional time ids.
141
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
142
+ The dimension of the projection of encoded `added_time_ids`.
143
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
144
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
145
+ The dimension of the cross attention features.
146
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
147
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
148
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
149
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
150
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
151
+ The number of attention heads.
152
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
153
+ """
154
+
155
+ _supports_gradient_checkpointing = True
156
+
157
+ @register_to_config
158
+ def __init__(
159
+ self,
160
+ sample_size: Optional[int] = None,
161
+ in_channels: int = 8,
162
+ out_channels: int = 4,
163
+ down_block_types: Tuple[str] = (
164
+ "CrossAttnDownBlockSpatioTemporal",
165
+ "CrossAttnDownBlockSpatioTemporal",
166
+ "CrossAttnDownBlockSpatioTemporal",
167
+ "DownBlockSpatioTemporal",
168
+ ),
169
+ up_block_types: Tuple[str] = (
170
+ "UpBlockSpatioTemporal",
171
+ "CrossAttnUpBlockSpatioTemporal",
172
+ "CrossAttnUpBlockSpatioTemporal",
173
+ "CrossAttnUpBlockSpatioTemporal",
174
+ ),
175
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
176
+ addition_time_embed_dim: int = 256,
177
+ projection_class_embeddings_input_dim: int = 768,
178
+ layers_per_block: Union[int, Tuple[int]] = 2,
179
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
180
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
181
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
182
+ num_frames: int = 14,
183
+ conditioning_channels: int = 3,
184
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
185
+ # NOTE: adapter for dift feature
186
+ with_id_feature: bool = False,
187
+ feature_channels: int = 160,
188
+ feature_out_channels: Tuple[int, ...] = (160, 160, 256, 256),
189
+ ):
190
+ super().__init__()
191
+ self.sample_size = sample_size
192
+
193
+ print("layers per block is", layers_per_block)
194
+
195
+ # Check inputs
196
+ if len(down_block_types) != len(up_block_types):
197
+ raise ValueError(
198
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
199
+ )
200
+
201
+ if len(block_out_channels) != len(down_block_types):
202
+ raise ValueError(
203
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
204
+ )
205
+
206
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
207
+ raise ValueError(
208
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
209
+ )
210
+
211
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
212
+ raise ValueError(
213
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
214
+ )
215
+
216
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
217
+ raise ValueError(
218
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
219
+ )
220
+
221
+ # input
222
+ self.conv_in = nn.Conv2d(
223
+ in_channels,
224
+ block_out_channels[0],
225
+ kernel_size=3,
226
+ padding=1,
227
+ )
228
+
229
+ # time
230
+ time_embed_dim = block_out_channels[0] * 4
231
+
232
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
233
+ timestep_input_dim = block_out_channels[0]
234
+
235
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
236
+
237
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
238
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
239
+
240
+ self.down_blocks = nn.ModuleList([])
241
+ self.controlnet_down_blocks = nn.ModuleList([])
242
+
243
+ if isinstance(num_attention_heads, int):
244
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
245
+
246
+ if isinstance(cross_attention_dim, int):
247
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
248
+
249
+ if isinstance(layers_per_block, int):
250
+ layers_per_block = [layers_per_block] * len(down_block_types)
251
+
252
+ if isinstance(transformer_layers_per_block, int):
253
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
254
+
255
+ blocks_time_embed_dim = time_embed_dim
256
+ self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(
257
+ conditioning_embedding_channels=block_out_channels[0],
258
+ block_out_channels=conditioning_embedding_out_channels,
259
+ conditioning_channels=conditioning_channels,
260
+ # optionally with point feature for conditioning
261
+ with_id_feature=with_id_feature,
262
+ feature_channels=feature_channels,
263
+ feature_out_channels=feature_out_channels,
264
+ )
265
+
266
+ # down
267
+ output_channel = block_out_channels[0]
268
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
269
+ controlnet_block = zero_module(controlnet_block)
270
+ self.controlnet_down_blocks.append(controlnet_block)
271
+
272
+ for i, down_block_type in enumerate(down_block_types):
273
+ input_channel = output_channel
274
+ output_channel = block_out_channels[i]
275
+ is_final_block = i == len(block_out_channels) - 1
276
+
277
+ down_block = get_down_block(
278
+ down_block_type,
279
+ num_layers=layers_per_block[i],
280
+ transformer_layers_per_block=transformer_layers_per_block[i],
281
+ in_channels=input_channel,
282
+ out_channels=output_channel,
283
+ temb_channels=blocks_time_embed_dim,
284
+ add_downsample=not is_final_block,
285
+ resnet_eps=1e-5,
286
+ cross_attention_dim=cross_attention_dim[i],
287
+ num_attention_heads=num_attention_heads[i],
288
+ resnet_act_fn="silu",
289
+ )
290
+ self.down_blocks.append(down_block)
291
+
292
+ for _ in range(layers_per_block[i]):
293
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
294
+ controlnet_block = zero_module(controlnet_block)
295
+ self.controlnet_down_blocks.append(controlnet_block)
296
+
297
+ if not is_final_block:
298
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
299
+ controlnet_block = zero_module(controlnet_block)
300
+ self.controlnet_down_blocks.append(controlnet_block)
301
+
302
+ # mid
303
+ mid_block_channel = block_out_channels[-1]
304
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
305
+ controlnet_block = zero_module(controlnet_block)
306
+ self.controlnet_mid_block = controlnet_block
307
+
308
+ self.mid_block = UNetMidBlockSpatioTemporal(
309
+ block_out_channels[-1],
310
+ temb_channels=blocks_time_embed_dim,
311
+ transformer_layers_per_block=transformer_layers_per_block[-1],
312
+ cross_attention_dim=cross_attention_dim[-1],
313
+ num_attention_heads=num_attention_heads[-1],
314
+ )
315
+
316
+ # # out
317
+ # self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
318
+ # self.conv_act = nn.SiLU()
319
+
320
+ # self.conv_out = nn.Conv2d(
321
+ # block_out_channels[0],
322
+ # out_channels,
323
+ # kernel_size=3,
324
+ # padding=1,
325
+ # )
326
+
327
+ @property
328
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
329
+ r"""
330
+ Returns:
331
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
332
+ indexed by its weight name.
333
+ """
334
+ # set recursively
335
+ processors = {}
336
+
337
+ def fn_recursive_add_processors(
338
+ name: str,
339
+ module: torch.nn.Module,
340
+ processors: Dict[str, AttentionProcessor],
341
+ ):
342
+ if hasattr(module, "get_processor"):
343
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
344
+
345
+ for sub_name, child in module.named_children():
346
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
347
+
348
+ return processors
349
+
350
+ for name, module in self.named_children():
351
+ fn_recursive_add_processors(name, module, processors)
352
+
353
+ return processors
354
+
355
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
356
+ r"""
357
+ Sets the attention processor to use to compute attention.
358
+
359
+ Parameters:
360
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
361
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
362
+ for **all** `Attention` layers.
363
+
364
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
365
+ processor. This is strongly recommended when setting trainable attention processors.
366
+
367
+ """
368
+ count = len(self.attn_processors.keys())
369
+
370
+ if isinstance(processor, dict) and len(processor) != count:
371
+ raise ValueError(
372
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
373
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
374
+ )
375
+
376
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
377
+ if hasattr(module, "set_processor"):
378
+ if not isinstance(processor, dict):
379
+ module.set_processor(processor)
380
+ else:
381
+ module.set_processor(processor.pop(f"{name}.processor"))
382
+
383
+ for sub_name, child in module.named_children():
384
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
385
+
386
+ for name, module in self.named_children():
387
+ fn_recursive_attn_processor(name, module, processor)
388
+
389
+ def set_default_attn_processor(self):
390
+ """
391
+ Disables custom attention processors and sets the default attention implementation.
392
+ """
393
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
394
+ processor = AttnProcessor()
395
+ else:
396
+ raise ValueError(
397
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
398
+ )
399
+
400
+ self.set_attn_processor(processor)
401
+
402
+ def _set_gradient_checkpointing(self, module, value=False):
403
+ if hasattr(module, "gradient_checkpointing"):
404
+ module.gradient_checkpointing = value
405
+
406
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
407
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
408
+ """
409
+ Sets the attention processor to use [feed forward
410
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
411
+
412
+ Parameters:
413
+ chunk_size (`int`, *optional*):
414
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
415
+ over each tensor of dim=`dim`.
416
+ dim (`int`, *optional*, defaults to `0`):
417
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
418
+ or dim=1 (sequence length).
419
+ """
420
+ if dim not in [0, 1]:
421
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
422
+
423
+ # By default chunk size is 1
424
+ chunk_size = chunk_size or 1
425
+
426
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
427
+ if hasattr(module, "set_chunk_feed_forward"):
428
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
429
+
430
+ for child in module.children():
431
+ fn_recursive_feed_forward(child, chunk_size, dim)
432
+
433
+ for module in self.children():
434
+ fn_recursive_feed_forward(module, chunk_size, dim)
435
+
436
+ def forward(
437
+ self,
438
+ sample: torch.FloatTensor,
439
+ timestep: Union[torch.Tensor, float, int],
440
+ encoder_hidden_states: torch.Tensor,
441
+ added_time_ids: torch.Tensor,
442
+ controlnet_cond: torch.FloatTensor = None,
443
+ point_embedding: torch.FloatTensor = None,
444
+ point_tracks: torch.FloatTensor = None,
445
+ image_only_indicator: Optional[torch.Tensor] = None,
446
+ return_dict: bool = True,
447
+ guess_mode: bool = False,
448
+ conditioning_scale: float = 1.0,
449
+ ) -> Union[ControlNetOutput, Tuple]:
450
+ r"""
451
+ The [`ControlNetSVDModel`] forward method.
452
+
453
+ Args:
454
+ sample (`torch.FloatTensor`):
455
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
456
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
457
+ encoder_hidden_states (`torch.FloatTensor`):
458
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
459
+ added_time_ids: (`torch.FloatTensor`):
460
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
461
+ embeddings and added to the time embeddings.
462
+ return_dict (`bool`, *optional*, defaults to `True`):
463
+ Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
464
+ tuple.
465
+ Returns:
466
+ [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
467
+ If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
468
+ a `tuple` is returned where the first element is the sample tensor.
469
+ """
470
+ # 1. time
471
+ timesteps = timestep
472
+ if not torch.is_tensor(timesteps):
473
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
474
+ # This would be a good case for the `match` statement (Python 3.10+)
475
+ is_mps = sample.device.type == "mps"
476
+ if isinstance(timestep, float):
477
+ dtype = torch.float32 if is_mps else torch.float64
478
+ else:
479
+ dtype = torch.int32 if is_mps else torch.int64
480
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
481
+ elif len(timesteps.shape) == 0:
482
+ timesteps = timesteps[None].to(sample.device)
483
+
484
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
485
+ batch_size, num_frames = sample.shape[:2]
486
+ timesteps = timesteps.expand(batch_size)
487
+
488
+ t_emb = self.time_proj(timesteps)
489
+
490
+ # `Timesteps` does not contain any weights and will always return f32 tensors
491
+ # but time_embedding might actually be running in fp16. so we need to cast here.
492
+ # there might be better ways to encapsulate this.
493
+ t_emb = t_emb.to(dtype=sample.dtype)
494
+
495
+ emb = self.time_embedding(t_emb)
496
+
497
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
498
+ time_embeds = time_embeds.reshape((batch_size, -1))
499
+ time_embeds = time_embeds.to(emb.dtype)
500
+ aug_emb = self.add_embedding(time_embeds)
501
+ emb = emb + aug_emb
502
+
503
+ # Flatten the batch and frames dimensions
504
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
505
+ sample = sample.flatten(0, 1)
506
+ # Repeat the embeddings num_video_frames times
507
+ # emb: [batch, channels] -> [batch * frames, channels]
508
+ emb = emb.repeat_interleave(num_frames, dim=0)
509
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
510
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
511
+
512
+ # 2. pre-process
513
+ sample = self.conv_in(sample)
514
+
515
+ # controlnet cond
516
+ if controlnet_cond is not None:
517
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond, point_embedding=point_embedding, point_tracks=point_tracks)
518
+ sample = sample + controlnet_cond
519
+
520
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
521
+
522
+ down_block_res_samples = (sample,)
523
+ for downsample_block in self.down_blocks:
524
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
525
+ # print('has_cross_attention', type(downsample_block))
526
+ # models_diffusers.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal
527
+
528
+ sample, res_samples = downsample_block(
529
+ hidden_states=sample,
530
+ temb=emb,
531
+ encoder_hidden_states=encoder_hidden_states,
532
+ image_only_indicator=image_only_indicator,
533
+ )
534
+ else:
535
+ # print('no_cross_attention', type(downsample_block))
536
+ # models_diffusers.unet_3d_blocks.DownBlockSpatioTemporal
537
+
538
+ sample, res_samples = downsample_block(
539
+ hidden_states=sample,
540
+ temb=emb,
541
+ image_only_indicator=image_only_indicator,
542
+ )
543
+
544
+ down_block_res_samples += res_samples
545
+
546
+ # 4. mid
547
+ sample = self.mid_block(
548
+ hidden_states=sample,
549
+ temb=emb,
550
+ encoder_hidden_states=encoder_hidden_states,
551
+ image_only_indicator=image_only_indicator,
552
+ )
553
+
554
+ controlnet_down_block_res_samples = ()
555
+
556
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
557
+ down_block_res_sample = controlnet_block(down_block_res_sample)
558
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
559
+
560
+ down_block_res_samples = controlnet_down_block_res_samples
561
+
562
+ mid_block_res_sample = self.controlnet_mid_block(sample)
563
+
564
+ # 6. scaling
565
+
566
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
567
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
568
+
569
+ if not return_dict:
570
+ return (down_block_res_samples, mid_block_res_sample)
571
+
572
+ return ControlNetOutput(
573
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
574
+ )
575
+
576
+ @classmethod
577
+ def from_unet(
578
+ cls,
579
+ unet: UNetSpatioTemporalConditionModel,
580
+ # controlnet_conditioning_channel_order: str = "rgb",
581
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
582
+ load_weights_from_unet: bool = True,
583
+ conditioning_channels: int = 3,
584
+ with_id_feature: bool = False,
585
+ ):
586
+ r"""
587
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
588
+
589
+ Parameters:
590
+ unet (`UNet2DConditionModel`):
591
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
592
+ where applicable.
593
+ """
594
+
595
+ # transformer_layers_per_block = (
596
+ # unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
597
+ # )
598
+ # encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
599
+ # encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
600
+ # addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
601
+ # addition_time_embed_dim = (
602
+ # unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
603
+ # )
604
+ print(unet.config)
605
+ controlnet = cls(
606
+ in_channels=unet.config.in_channels,
607
+ down_block_types=unet.config.down_block_types,
608
+ block_out_channels=unet.config.block_out_channels,
609
+ addition_time_embed_dim=unet.config.addition_time_embed_dim,
610
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
611
+ cross_attention_dim=unet.config.cross_attention_dim,
612
+ num_attention_heads=unet.config.num_attention_heads,
613
+ num_frames=unet.config.num_frames,
614
+ sample_size=unet.config.sample_size, # Added based on the dict
615
+ layers_per_block=unet.config.layers_per_block,
616
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
617
+ conditioning_channels=conditioning_channels,
618
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
619
+ with_id_feature=with_id_feature,
620
+ )
621
+ # controlnet rgb channel order is ignored; set so it does not make a difference by default
622
+
623
+ if load_weights_from_unet:
624
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
625
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
626
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
627
+
628
+ # if controlnet.class_embedding:
629
+ # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
630
+
631
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
632
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
633
+
634
+ return controlnet
635
+
636
+ @property
637
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
638
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
639
+ r"""
640
+ Returns:
641
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
642
+ indexed by its weight name.
643
+ """
644
+ # set recursively
645
+ processors = {}
646
+
647
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
648
+ if hasattr(module, "get_processor"):
649
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
650
+
651
+ for sub_name, child in module.named_children():
652
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
653
+
654
+ return processors
655
+
656
+ for name, module in self.named_children():
657
+ fn_recursive_add_processors(name, module, processors)
658
+
659
+ return processors
660
+
661
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
662
+ def set_attn_processor(
663
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
664
+ ):
665
+ r"""
666
+ Sets the attention processor to use to compute attention.
667
+
668
+ Parameters:
669
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
670
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
671
+ for **all** `Attention` layers.
672
+
673
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
674
+ processor. This is strongly recommended when setting trainable attention processors.
675
+
676
+ """
677
+ count = len(self.attn_processors.keys())
678
+
679
+ if isinstance(processor, dict) and len(processor) != count:
680
+ raise ValueError(
681
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
682
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
683
+ )
684
+
685
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
686
+ if hasattr(module, "set_processor"):
687
+ if not isinstance(processor, dict):
688
+ module.set_processor(processor, _remove_lora=_remove_lora)
689
+ else:
690
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
691
+
692
+ for sub_name, child in module.named_children():
693
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_attn_processor(name, module, processor)
697
+
698
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
699
+ def set_default_attn_processor(self):
700
+ """
701
+ Disables custom attention processors and sets the default attention implementation.
702
+ """
703
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
704
+ processor = AttnAddedKVProcessor()
705
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
706
+ processor = AttnProcessor()
707
+ else:
708
+ raise ValueError(
709
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
710
+ )
711
+
712
+ self.set_attn_processor(processor, _remove_lora=True)
713
+
714
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
715
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
716
+ r"""
717
+ Enable sliced attention computation.
718
+
719
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
720
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
721
+
722
+ Args:
723
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
724
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
725
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
726
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
727
+ must be a multiple of `slice_size`.
728
+ """
729
+ sliceable_head_dims = []
730
+
731
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
732
+ if hasattr(module, "set_attention_slice"):
733
+ sliceable_head_dims.append(module.sliceable_head_dim)
734
+
735
+ for child in module.children():
736
+ fn_recursive_retrieve_sliceable_dims(child)
737
+
738
+ # retrieve number of attention layers
739
+ for module in self.children():
740
+ fn_recursive_retrieve_sliceable_dims(module)
741
+
742
+ num_sliceable_layers = len(sliceable_head_dims)
743
+
744
+ if slice_size == "auto":
745
+ # half the attention head size is usually a good trade-off between
746
+ # speed and memory
747
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
748
+ elif slice_size == "max":
749
+ # make smallest slice possible
750
+ slice_size = num_sliceable_layers * [1]
751
+
752
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
753
+
754
+ if len(slice_size) != len(sliceable_head_dims):
755
+ raise ValueError(
756
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
757
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
758
+ )
759
+
760
+ for i in range(len(slice_size)):
761
+ size = slice_size[i]
762
+ dim = sliceable_head_dims[i]
763
+ if size is not None and size > dim:
764
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
765
+
766
+ # Recursively walk through all the children.
767
+ # Any children which exposes the set_attention_slice method
768
+ # gets the message
769
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
770
+ if hasattr(module, "set_attention_slice"):
771
+ module.set_attention_slice(slice_size.pop())
772
+
773
+ for child in module.children():
774
+ fn_recursive_set_attention_slice(child, slice_size)
775
+
776
+ reversed_slice_size = list(reversed(slice_size))
777
+ for module in self.children():
778
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
779
+
780
+ # def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
781
+ # if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
782
+ # module.gradient_checkpointing = value
783
+
784
+
785
+ def zero_module(module):
786
+ for p in module.parameters():
787
+ nn.init.zeros_(p)
788
+ return module
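For orientation, a hedged sketch of how this ControlNet is intended to be wired up: `from_unet` mirrors the SVD UNet's configuration and copies its `conv_in`, time-embedding, down-block, and mid-block weights, and `forward` returns the down/mid residuals that the pipeline adds to the UNet's features. The checkpoint id, resolution, and random tensors below are assumptions for illustration only, not values fixed by this commit.

import torch
from diffusers.models import UNetSpatioTemporalConditionModel
from models_diffusers.controlnet_svd import ControlNetSVDModel

# Mirror the SVD UNet's config and copy its conv_in / time-embedding / down / mid weights.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
)
controlnet = ControlNetSVDModel.from_unet(unet)

batch, frames, height, width = 1, 14, 320, 512
sample = torch.randn(batch, frames, unet.config.in_channels, height // 8, width // 8)
encoder_hidden_states = torch.randn(batch, 1, unet.config.cross_attention_dim)
added_time_ids = torch.randn(batch, 3)
controlnet_cond = torch.randn(batch, frames, 3, height, width)  # e.g. rendered trajectory maps

with torch.no_grad():
    down_res, mid_res = controlnet(
        sample,
        timestep=1,
        encoder_hidden_states=encoder_hidden_states,
        added_time_ids=added_time_ids,
        controlnet_cond=controlnet_cond,
        conditioning_scale=1.0,
        return_dict=False,
    )
# down_res / mid_res are the residuals the pipeline adds to the UNet's own features.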
models_diffusers/sift_match.py ADDED
@@ -0,0 +1,239 @@
1
+ from scipy.interpolate import interp1d, PchipInterpolator
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def sift_match(
10
+ img1, img2,
11
+ thr=0.5,
12
+ topk=5, method="max_dist",
13
+ output_path="sift_matches.png",
14
+ ):
15
+
16
+ assert method in ["max_dist", "random", "max_score", "max_score_even"]
17
+
18
+ # img1 and img2 are PIL images
19
+ # a smaller threshold keeps fewer (more reliable) points
20
+
21
+ # 1. to cv2 grayscale image
22
+ img1_rgb = np.array(img1).copy()
23
+ img2_rgb = np.array(img2).copy()
24
+ img1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR)
25
+ img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
26
+ img2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR)
27
+ img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
28
+
29
+ # 2. use sift to extract keypoints and descriptors
30
+ # Initiate SIFT detector
31
+ sift = cv2.SIFT_create()
32
+ # find the keypoints and descriptors with SIFT
33
+ kp1, des1 = sift.detectAndCompute(img1, None)
34
+ kp2, des2 = sift.detectAndCompute(img2, None)
35
+ # BFMatcher with default params
36
+ bf = cv2.BFMatcher()
37
+ # bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
38
+ # bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
39
+ matches = bf.knnMatch(des1, des2, k=2)
40
+
41
+ # Apply ratio test
42
+ good = []
43
+ point_list = []
44
+ distance_list = []
45
+
46
+ if method in ['max_score', 'max_score_even']:
47
+ matches = sorted(matches, key=lambda x: x[0].distance / x[1].distance)
48
+
49
+ anchor_points_list = []
50
+ for m, n in matches[:topk]:
51
+ print(m.distance / n.distance)
52
+
53
+ # skip matches whose source keypoint is too close to an already selected anchor
54
+ if method == 'max_score_even':
55
+ to_close = False
56
+ for anchor_point in anchor_points_list:
57
+ pt1 = kp1[m.queryIdx].pt
58
+ dist = np.linalg.norm(np.array(pt1) - np.array(anchor_point))
59
+ if dist < 50:
60
+ to_close = True
61
+ break
62
+ if to_close:
63
+ continue
64
+
65
+ good.append([m])
66
+
67
+ pt1 = kp1[m.queryIdx].pt
68
+ pt2 = kp2[m.trainIdx].pt
69
+ dist = np.linalg.norm(np.array(pt1) - np.array(pt2))
70
+ distance_list.append(dist)
71
+
72
+ anchor_points_list.append(pt1)
73
+
74
+ pt1 = torch.tensor(pt1)
75
+ pt2 = torch.tensor(pt2)
76
+ pt = torch.stack([pt1, pt2]) # (2, 2)
77
+ point_list.append(pt)
78
+
79
+ if method in ['max_dist', 'random']:
80
+ for m, n in matches:
81
+ if m.distance < thr * n.distance:
82
+ good.append([m])
83
+
84
+ pt1 = kp1[m.queryIdx].pt
85
+ pt2 = kp2[m.trainIdx].pt
86
+ dist = np.linalg.norm(np.array(pt1) - np.array(pt2))
87
+ distance_list.append(dist)
88
+
89
+ pt1 = torch.tensor(pt1)
90
+ pt2 = torch.tensor(pt2)
91
+ pt = torch.stack([pt1, pt2]) # (2, 2)
92
+ point_list.append(pt)
93
+
94
+ distance_list = np.array(distance_list)
95
+ # only keep the points with the largest topk distance
96
+ idx = np.argsort(distance_list)
97
+ if method == "max_dist":
98
+ idx = idx[-topk:]
99
+ elif method == "random":
100
+ topk = min(topk, len(idx))
101
+ idx = np.random.choice(idx, topk, replace=False)
102
+ elif method == "max_score":
103
+ # ranking purely by match score is not implemented for this selection step
104
+ raise NotImplementedError
105
+ # idx = np.argsort(distance_list)[:topk]
106
+ else:
107
+ raise ValueError(f"Unknown method {method}")
108
+
109
+ point_list = [point_list[i] for i in idx]
110
+ good = [good[i] for i in idx]
111
+
112
+ # # cv2.drawMatchesKnn expects list of lists as matches.
113
+ # draw_params = dict(
114
+ # matchColor=(255, 0, 0),
115
+ # singlePointColor=None,
116
+ # flags=2,
117
+ # )
118
+ # img3 = cv2.drawMatchesKnn(img1, kp1, img2, kp2, good, None, **draw_params)
119
+
120
+
121
+ # # manually draw the matches, the images are put in horizontal
122
+ # img3 = np.concatenate([img1_rgb, img2_rgb], axis=1) # (h, 2w, 3)
123
+ # for m in good:
124
+ # pt1 = kp1[m[0].queryIdx].pt
125
+ # pt2 = kp2[m[0].trainIdx].pt
126
+ # pt1 = (int(pt1[0]), int(pt1[1]))
127
+ # pt2 = (int(pt2[0]) + img1_rgb.shape[1], int(pt2[1]))
128
+ # cv2.line(img3, pt1, pt2, (255, 0, 0), 1)
129
+
130
+ # manually draw the matches, the images are put in vertical. with 10 pixels margin
131
+ margin = 10
132
+ img3 = np.zeros((img1_rgb.shape[0] + img2_rgb.shape[0] + margin, max(img1_rgb.shape[1], img2_rgb.shape[1]), 3), dtype=np.uint8)
133
+ # the margin is white
134
+ img3[:, :] = 255
135
+ img3[:img1_rgb.shape[0], :img1_rgb.shape[1]] = img1_rgb
136
+ img3[img1_rgb.shape[0] + margin:, :img2_rgb.shape[1]] = img2_rgb
137
+ # create a color list of 6 different colors
138
+ color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
139
+ for color_idx, m in enumerate(good):
140
+ pt1 = kp1[m[0].queryIdx].pt
141
+ pt2 = kp2[m[0].trainIdx].pt
142
+ pt1 = (int(pt1[0]), int(pt1[1]))
143
+ pt2 = (int(pt2[0]), int(pt2[1]) + img1_rgb.shape[0] + margin)
144
+ # cv2.line(img3, pt1, pt2, (255, 0, 0), 1)
145
+ # draw with anti-aliasing to avoid zigzag artifacts in the line
146
+ # random_color = tuple(np.random.randint(0, 255, 3).tolist())
147
+ color = color_list[color_idx % len(color_list)]
148
+ cv2.line(img3, pt1, pt2, color, 1, lineType=cv2.LINE_AA)
149
+ # add an unfilled circle at both the start and end points
150
+ cv2.circle(img3, pt1, 3, color, lineType=cv2.LINE_AA)
151
+ cv2.circle(img3, pt2, 3, color, lineType=cv2.LINE_AA)
152
+
153
+ Image.fromarray(img3).save(output_path)
154
+ print(f"Save the sift matches to {output_path}")
155
+
156
+ # (f, topk, 2), f=2 (before interpolation)
157
+ if len(point_list) == 0:
158
+ return None
159
+
160
+ point_list = torch.stack(point_list)
161
+ point_list = point_list.permute(1, 0, 2)
162
+
163
+ return point_list
164
+
165
+
166
+ def interpolate_trajectory(points_torch, num_frames, t=None):
167
+ # points:(f, topk, 2), f=2 (before interpolation)
168
+
169
+ num_points = points_torch.shape[1]
170
+ points_torch = points_torch.permute(1, 0, 2) # (topk, f, 2)
171
+
172
+ points_list = []
173
+ for i in range(num_points):
174
+ # points:(f, 2)
175
+ points = points_torch[i].cpu().numpy()
176
+
177
+ x = [point[0] for point in points]
178
+ y = [point[1] for point in points]
179
+
180
+ if t is None:
181
+ t = np.linspace(0, 1, len(points))
182
+
183
+ # fx = interp1d(t, x, kind='cubic')
184
+ # fy = interp1d(t, y, kind='cubic')
185
+ fx = PchipInterpolator(t, x)
186
+ fy = PchipInterpolator(t, y)
187
+
188
+ new_t = np.linspace(0, 1, num_frames)
189
+
190
+ new_x = fx(new_t)
191
+ new_y = fy(new_t)
192
+ new_points = list(zip(new_x, new_y))
193
+
194
+ points_list.append(new_points)
195
+
196
+ points = torch.tensor(points_list) # (topk, num_frames, 2)
197
+ points = points.permute(1, 0, 2) # (num_frames, topk, 2)
198
+
199
+ return points
200
+
201
+
202
+ # diffusion feature matching
203
+ def point_tracking(
204
+ F0,
205
+ F1,
206
+ handle_points,
207
+ handle_points_init,
208
+ track_dist=5,
209
+ ):
210
+ # handle_points: (num_points, 2)
211
+ # NOTE:
212
+ # 1. all row and col are reversed
213
+ # 2. handle_points in (y, x), not (x, y)
214
+
215
+ # reverse row and col
216
+ handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1)
217
+ handle_points_init = torch.stack([handle_points_init[:, 1], handle_points_init[:, 0]], dim=-1)
218
+
219
+ with torch.no_grad():
220
+ _, _, max_r, max_c = F0.shape
221
+
222
+ for i in range(len(handle_points)):
223
+ pi0, pi = handle_points_init[i], handle_points[i]
224
+ f0 = F0[:, :, int(pi0[0]), int(pi0[1])]
225
+
226
+ r1, r2 = max(0, int(pi[0]) - track_dist), min(max_r, int(pi[0]) + track_dist + 1)
227
+ c1, c2 = max(0, int(pi[1]) - track_dist), min(max_c, int(pi[1]) + track_dist + 1)
228
+ F1_neighbor = F1[:, :, r1:r2, c1:c2]
229
+ all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
230
+ all_dist = all_dist.squeeze(dim=0)
231
+ row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])
232
+ # handle_points[i][0] = pi[0] - track_dist + row
233
+ # handle_points[i][1] = pi[1] - track_dist + col
234
+ handle_points[i][0] = r1 + row
235
+ handle_points[i][1] = c1 + col
236
+
237
+ handle_points = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1) # (num_points, 2)
238
+
239
+ return handle_points
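A short usage sketch for the helpers above (the image paths and the 14-frame count are assumptions): `sift_match` extracts anchor correspondences between the two input frames, and `interpolate_trajectory` expands them into the per-frame point tracks used as conditioning.

from PIL import Image
from models_diffusers.sift_match import sift_match, interpolate_trajectory

frame_a = Image.open("start_frame.png").convert("RGB")   # hypothetical input paths
frame_b = Image.open("end_frame.png").convert("RGB")

# Returns a (2, topk, 2) tensor of matched keypoints (start frame, end frame),
# or None when no match survives the ratio test.
anchors = sift_match(frame_a, frame_b, thr=0.5, topk=5, method="max_dist",
                     output_path="sift_matches.png")

if anchors is not None:
    # Densify the two-frame matches into per-frame trajectories: (num_frames, topk, 2).
    trajectory = interpolate_trajectory(anchors, num_frames=14)
    print(trajectory.shape)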
models_diffusers/transformer_temporal.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.utils import BaseOutput
22
+ from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
23
+ # from diffusers.models.attention import BasicTransformerBlock
24
+ from models_diffusers.attention import BasicTransformerBlock
25
+
26
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.models.resnet import AlphaBlender
29
+
30
+
31
+ @dataclass
32
+ class TransformerTemporalModelOutput(BaseOutput):
33
+ """
34
+ The output of [`TransformerTemporalModel`].
35
+
36
+ Args:
37
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
38
+ The hidden states output conditioned on `encoder_hidden_states` input.
39
+ """
40
+
41
+ sample: torch.FloatTensor
42
+
43
+
44
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
45
+ """
46
+ A Transformer model for video-like data.
47
+
48
+ Parameters:
49
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
50
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
51
+ in_channels (`int`, *optional*):
52
+ The number of channels in the input and output (specify if the input is **continuous**).
53
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
54
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
55
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
58
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
59
+ This is fixed during training since it is used to learn a number of position embeddings.
60
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
61
+ Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
62
+ activation functions.
63
+ norm_elementwise_affine (`bool`, *optional*):
64
+ Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
65
+ double_self_attention (`bool`, *optional*):
66
+ Configure if each `TransformerBlock` should contain two self-attention layers.
67
+ positional_embeddings: (`str`, *optional*):
68
+ The type of positional embeddings to apply to the sequence input before passing use.
69
+ num_positional_embeddings: (`int`, *optional*):
70
+ The maximum length of the sequence over which to apply positional embeddings.
71
+ """
72
+
73
+ @register_to_config
74
+ def __init__(
75
+ self,
76
+ num_attention_heads: int = 16,
77
+ attention_head_dim: int = 88,
78
+ in_channels: Optional[int] = None,
79
+ out_channels: Optional[int] = None,
80
+ num_layers: int = 1,
81
+ dropout: float = 0.0,
82
+ norm_num_groups: int = 32,
83
+ cross_attention_dim: Optional[int] = None,
84
+ attention_bias: bool = False,
85
+ sample_size: Optional[int] = None,
86
+ activation_fn: str = "geglu",
87
+ norm_elementwise_affine: bool = True,
88
+ double_self_attention: bool = True,
89
+ positional_embeddings: Optional[str] = None,
90
+ num_positional_embeddings: Optional[int] = None,
91
+ ):
92
+ super().__init__()
93
+ self.num_attention_heads = num_attention_heads
94
+ self.attention_head_dim = attention_head_dim
95
+ inner_dim = num_attention_heads * attention_head_dim
96
+
97
+ self.in_channels = in_channels
98
+
99
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
100
+ self.proj_in = nn.Linear(in_channels, inner_dim)
101
+
102
+ # 3. Define transformers blocks
103
+ self.transformer_blocks = nn.ModuleList(
104
+ [
105
+ BasicTransformerBlock(
106
+ inner_dim,
107
+ num_attention_heads,
108
+ attention_head_dim,
109
+ dropout=dropout,
110
+ cross_attention_dim=cross_attention_dim,
111
+ activation_fn=activation_fn,
112
+ attention_bias=attention_bias,
113
+ double_self_attention=double_self_attention,
114
+ norm_elementwise_affine=norm_elementwise_affine,
115
+ positional_embeddings=positional_embeddings,
116
+ num_positional_embeddings=num_positional_embeddings,
117
+ )
118
+ for d in range(num_layers)
119
+ ]
120
+ )
121
+
122
+ self.proj_out = nn.Linear(inner_dim, in_channels)
123
+
124
+ def forward(
125
+ self,
126
+ hidden_states: torch.FloatTensor,
127
+ encoder_hidden_states: Optional[torch.LongTensor] = None,
128
+ timestep: Optional[torch.LongTensor] = None,
129
+ class_labels: torch.LongTensor = None,
130
+ num_frames: int = 1,
131
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
132
+ return_dict: bool = True,
133
+ ) -> TransformerTemporalModelOutput:
134
+ """
135
+ The [`TransformerTemporal`] forward method.
136
+
137
+ Args:
138
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
139
+ Input hidden_states.
140
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
141
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
142
+ self-attention.
143
+ timestep ( `torch.LongTensor`, *optional*):
144
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
145
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
146
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
147
+ `AdaLayerZeroNorm`.
148
+ num_frames (`int`, *optional*, defaults to 1):
149
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
150
+ cross_attention_kwargs (`dict`, *optional*):
151
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
152
+ `self.processor` in
153
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
154
+ return_dict (`bool`, *optional*, defaults to `True`):
155
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
156
+ tuple.
157
+
158
+ Returns:
159
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
160
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
161
+ returned, otherwise a `tuple` where the first element is the sample tensor.
162
+ """
163
+ # 1. Input
164
+ batch_frames, channel, height, width = hidden_states.shape
165
+ batch_size = batch_frames // num_frames
166
+
167
+ residual = hidden_states
168
+
169
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
170
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
171
+
172
+ hidden_states = self.norm(hidden_states)
173
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
174
+
175
+ hidden_states = self.proj_in(hidden_states)
176
+
177
+ # 2. Blocks
178
+ for block in self.transformer_blocks:
179
+ hidden_states = block(
180
+ hidden_states,
181
+ encoder_hidden_states=encoder_hidden_states,
182
+ timestep=timestep,
183
+ cross_attention_kwargs=cross_attention_kwargs,
184
+ class_labels=class_labels,
185
+ )
186
+
187
+ # 3. Output
188
+ hidden_states = self.proj_out(hidden_states)
189
+ hidden_states = (
190
+ hidden_states[None, None, :]
191
+ .reshape(batch_size, height, width, num_frames, channel)
192
+ .permute(0, 3, 4, 1, 2)
193
+ .contiguous()
194
+ )
195
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
196
+
197
+ output = hidden_states + residual
198
+
199
+ if not return_dict:
200
+ return (output,)
201
+
202
+ return TransformerTemporalModelOutput(sample=output)
203
+
204
+
205
+ class TransformerSpatioTemporalModel(nn.Module):
206
+ """
207
+ A Transformer model for video-like data.
208
+
209
+ Parameters:
210
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
211
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
212
+ in_channels (`int`, *optional*):
213
+ The number of channels in the input and output (specify if the input is **continuous**).
214
+ out_channels (`int`, *optional*):
215
+ The number of channels in the output (specify if the input is **continuous**).
216
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
217
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ num_attention_heads: int = 16,
223
+ attention_head_dim: int = 88,
224
+ in_channels: int = 320,
225
+ out_channels: Optional[int] = None,
226
+ num_layers: int = 1,
227
+ cross_attention_dim: Optional[int] = None,
228
+ ):
229
+ super().__init__()
230
+ self.num_attention_heads = num_attention_heads
231
+ self.attention_head_dim = attention_head_dim
232
+
233
+ inner_dim = num_attention_heads * attention_head_dim
234
+ self.inner_dim = inner_dim
235
+
236
+ # 2. Define input layers
237
+ self.in_channels = in_channels
238
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
239
+ self.proj_in = nn.Linear(in_channels, inner_dim)
240
+
241
+ # 3. Define transformers blocks
242
+ self.transformer_blocks = nn.ModuleList(
243
+ [
244
+ BasicTransformerBlock(
245
+ inner_dim,
246
+ num_attention_heads,
247
+ attention_head_dim,
248
+ cross_attention_dim=cross_attention_dim,
249
+ )
250
+ for d in range(num_layers)
251
+ ]
252
+ )
253
+
254
+ time_mix_inner_dim = inner_dim
255
+ self.temporal_transformer_blocks = nn.ModuleList(
256
+ [
257
+ TemporalBasicTransformerBlock(
258
+ inner_dim,
259
+ time_mix_inner_dim,
260
+ num_attention_heads,
261
+ attention_head_dim,
262
+ cross_attention_dim=cross_attention_dim,
263
+ )
264
+ for _ in range(num_layers)
265
+ ]
266
+ )
267
+
268
+ time_embed_dim = in_channels * 4
269
+ self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
270
+ self.time_proj = Timesteps(in_channels, True, 0)
271
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
272
+
273
+ # 4. Define output layers
274
+ self.out_channels = in_channels if out_channels is None else out_channels
275
+ # TODO: should use out_channels for continuous projections
276
+ self.proj_out = nn.Linear(inner_dim, in_channels)
277
+
278
+ self.gradient_checkpointing = False
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ encoder_hidden_states: Optional[torch.Tensor] = None,
284
+ image_only_indicator: Optional[torch.Tensor] = None,
285
+ return_dict: bool = True,
286
+ ):
287
+ """
288
+ Args:
289
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
290
+ Input hidden_states.
291
+ num_frames (`int`):
292
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
293
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
294
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
295
+ self-attention.
296
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
297
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
298
+ images, 0 indicates that the input contains video frames.
299
+ return_dict (`bool`, *optional*, defaults to `True`):
300
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
301
+ tuple.
302
+
303
+ Returns:
304
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
305
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
306
+ returned, otherwise a `tuple` where the first element is the sample tensor.
307
+ """
308
+ # 1. Input
309
+ batch_frames, _, height, width = hidden_states.shape
310
+ num_frames = image_only_indicator.shape[-1]
311
+ batch_size = batch_frames // num_frames
312
+
313
+ time_context = encoder_hidden_states
314
+ time_context_first_timestep = time_context[None, :].reshape(
315
+ batch_size, num_frames, -1, time_context.shape[-1]
316
+ )[:, 0]
317
+ time_context = time_context_first_timestep[None, :].broadcast_to(
318
+ # height * width, batch_size, 1, time_context.shape[-1]
319
+ height * width, batch_size, -1, time_context.shape[-1]
320
+ )
321
+ # time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
322
+ time_context = time_context.reshape(height * width * batch_size, -1, time_context.shape[-1])
323
+
324
+ residual = hidden_states
325
+
326
+ hidden_states = self.norm(hidden_states)
327
+ inner_dim = hidden_states.shape[1]
328
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
329
+ hidden_states = self.proj_in(hidden_states)
330
+
331
+ num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
332
+ num_frames_emb = num_frames_emb.repeat(batch_size, 1)
333
+ num_frames_emb = num_frames_emb.reshape(-1)
334
+ t_emb = self.time_proj(num_frames_emb)
335
+
336
+ # `Timesteps` does not contain any weights and will always return f32 tensors
337
+ # but time_embedding might actually be running in fp16. so we need to cast here.
338
+ # there might be better ways to encapsulate this.
339
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
340
+
341
+ emb = self.time_pos_embed(t_emb)
342
+ emb = emb[:, None, :]
343
+
344
+ # 2. Blocks
345
+ for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
346
+ if self.training and self.gradient_checkpointing:
347
+ hidden_states = torch.utils.checkpoint.checkpoint(
348
+ block,
349
+ hidden_states,
350
+ None,
351
+ encoder_hidden_states,
352
+ None,
353
+ use_reentrant=False,
354
+ )
355
+ else:
356
+ hidden_states = block(
357
+ hidden_states,
358
+ encoder_hidden_states=encoder_hidden_states,
359
+ )
360
+
361
+ hidden_states_mix = hidden_states
362
+ hidden_states_mix = hidden_states_mix + emb
363
+
364
+ hidden_states_mix = temporal_block(
365
+ hidden_states_mix,
366
+ num_frames=num_frames,
367
+ encoder_hidden_states=time_context,
368
+ )
369
+ hidden_states = self.time_mixer(
370
+ x_spatial=hidden_states,
371
+ x_temporal=hidden_states_mix,
372
+ image_only_indicator=image_only_indicator,
373
+ )
374
+
375
+ # 3. Output
376
+ hidden_states = self.proj_out(hidden_states)
377
+ hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
378
+
379
+ output = hidden_states + residual
380
+
381
+ if not return_dict:
382
+ return (output,)
383
+
384
+ return TransformerTemporalModelOutput(sample=output)
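As a rough usage sketch for the file above (not part of the diff): the docstrings describe inputs flattened over frames, so `hidden_states` has shape `(batch * num_frames, channels, height, width)` and `image_only_indicator` has shape `(batch, num_frames)`. Only the import path comes from this commit; the sizes and `cross_attention_dim` below are illustrative assumptions.

import torch
from models_diffusers.transformer_temporal import TransformerSpatioTemporalModel

batch_size, num_frames, height, width = 1, 4, 8, 8

model = TransformerSpatioTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,     # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=1024,  # illustrative; must match the conditioning width below
)

# frames are flattened into the batch dimension: (batch * num_frames, C, H, W)
hidden_states = torch.randn(batch_size * num_frames, 320, height, width)
# per-frame conditioning embeddings: (batch * num_frames, seq_len, cross_attention_dim)
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)
# 1 marks image-only samples, 0 marks video frames; consumed by the AlphaBlender time mixer
image_only_indicator = torch.zeros(batch_size, num_frames)

sample = model(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
).sample  # same shape as hidden_states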
models_diffusers/unet_3d_blocks.py ADDED
@@ -0,0 +1,2405 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from diffusers.utils import is_torch_version
21
+ from diffusers.utils.torch_utils import apply_freeu
22
+ # from diffusers.models.attention import Attention
23
+ from models_diffusers.attention_processor import Attention
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import (
26
+ Downsample2D,
27
+ ResnetBlock2D,
28
+ SpatioTemporalResBlock,
29
+ TemporalConvLayer,
30
+ Upsample2D,
31
+ )
32
+ from diffusers.models.transformer_2d import Transformer2DModel
33
+ from .transformer_temporal import (
34
+ TransformerSpatioTemporalModel,
35
+ TransformerTemporalModel,
36
+ )
37
+
38
+ from einops import rearrange
39
+
40
+
41
+ def get_down_block(
42
+ down_block_type: str,
43
+ num_layers: int,
44
+ in_channels: int,
45
+ out_channels: int,
46
+ temb_channels: int,
47
+ add_downsample: bool,
48
+ resnet_eps: float,
49
+ resnet_act_fn: str,
50
+ num_attention_heads: int,
51
+ resnet_groups: Optional[int] = None,
52
+ cross_attention_dim: Optional[int] = None,
53
+ downsample_padding: Optional[int] = None,
54
+ dual_cross_attention: bool = False,
55
+ use_linear_projection: bool = True,
56
+ only_cross_attention: bool = False,
57
+ upcast_attention: bool = False,
58
+ resnet_time_scale_shift: str = "default",
59
+ temporal_num_attention_heads: int = 8,
60
+ temporal_max_seq_length: int = 32,
61
+ transformer_layers_per_block: int = 1,
62
+ ) -> Union[
63
+ "DownBlock3D",
64
+ "CrossAttnDownBlock3D",
65
+ "DownBlockMotion",
66
+ "CrossAttnDownBlockMotion",
67
+ "DownBlockSpatioTemporal",
68
+ "CrossAttnDownBlockSpatioTemporal",
69
+ ]:
70
+ if down_block_type == "DownBlock3D":
71
+ return DownBlock3D(
72
+ num_layers=num_layers,
73
+ in_channels=in_channels,
74
+ out_channels=out_channels,
75
+ temb_channels=temb_channels,
76
+ add_downsample=add_downsample,
77
+ resnet_eps=resnet_eps,
78
+ resnet_act_fn=resnet_act_fn,
79
+ resnet_groups=resnet_groups,
80
+ downsample_padding=downsample_padding,
81
+ resnet_time_scale_shift=resnet_time_scale_shift,
82
+ )
83
+ elif down_block_type == "CrossAttnDownBlock3D":
84
+ if cross_attention_dim is None:
85
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
86
+ return CrossAttnDownBlock3D(
87
+ num_layers=num_layers,
88
+ in_channels=in_channels,
89
+ out_channels=out_channels,
90
+ temb_channels=temb_channels,
91
+ add_downsample=add_downsample,
92
+ resnet_eps=resnet_eps,
93
+ resnet_act_fn=resnet_act_fn,
94
+ resnet_groups=resnet_groups,
95
+ downsample_padding=downsample_padding,
96
+ cross_attention_dim=cross_attention_dim,
97
+ num_attention_heads=num_attention_heads,
98
+ dual_cross_attention=dual_cross_attention,
99
+ use_linear_projection=use_linear_projection,
100
+ only_cross_attention=only_cross_attention,
101
+ upcast_attention=upcast_attention,
102
+ resnet_time_scale_shift=resnet_time_scale_shift,
103
+ )
104
+ if down_block_type == "DownBlockMotion":
105
+ return DownBlockMotion(
106
+ num_layers=num_layers,
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ temb_channels=temb_channels,
110
+ add_downsample=add_downsample,
111
+ resnet_eps=resnet_eps,
112
+ resnet_act_fn=resnet_act_fn,
113
+ resnet_groups=resnet_groups,
114
+ downsample_padding=downsample_padding,
115
+ resnet_time_scale_shift=resnet_time_scale_shift,
116
+ temporal_num_attention_heads=temporal_num_attention_heads,
117
+ temporal_max_seq_length=temporal_max_seq_length,
118
+ )
119
+ elif down_block_type == "CrossAttnDownBlockMotion":
120
+ if cross_attention_dim is None:
121
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
122
+ return CrossAttnDownBlockMotion(
123
+ num_layers=num_layers,
124
+ in_channels=in_channels,
125
+ out_channels=out_channels,
126
+ temb_channels=temb_channels,
127
+ add_downsample=add_downsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ downsample_padding=downsample_padding,
132
+ cross_attention_dim=cross_attention_dim,
133
+ num_attention_heads=num_attention_heads,
134
+ dual_cross_attention=dual_cross_attention,
135
+ use_linear_projection=use_linear_projection,
136
+ only_cross_attention=only_cross_attention,
137
+ upcast_attention=upcast_attention,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ temporal_num_attention_heads=temporal_num_attention_heads,
140
+ temporal_max_seq_length=temporal_max_seq_length,
141
+ )
142
+ elif down_block_type == "DownBlockSpatioTemporal":
143
+ # added for SDV
144
+ return DownBlockSpatioTemporal(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ temb_channels=temb_channels,
149
+ add_downsample=add_downsample,
150
+ )
151
+ elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
152
+ # added for SDV
153
+ if cross_attention_dim is None:
154
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
155
+ return CrossAttnDownBlockSpatioTemporal(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ temb_channels=temb_channels,
159
+ num_layers=num_layers,
160
+ transformer_layers_per_block=transformer_layers_per_block,
161
+ add_downsample=add_downsample,
162
+ cross_attention_dim=cross_attention_dim,
163
+ num_attention_heads=num_attention_heads,
164
+ )
165
+
166
+ raise ValueError(f"{down_block_type} does not exist.")
167
+
168
+
169
+ def get_up_block(
170
+ up_block_type: str,
171
+ num_layers: int,
172
+ in_channels: int,
173
+ out_channels: int,
174
+ prev_output_channel: int,
175
+ temb_channels: int,
176
+ add_upsample: bool,
177
+ resnet_eps: float,
178
+ resnet_act_fn: str,
179
+ num_attention_heads: int,
180
+ resolution_idx: Optional[int] = None,
181
+ resnet_groups: Optional[int] = None,
182
+ cross_attention_dim: Optional[int] = None,
183
+ dual_cross_attention: bool = False,
184
+ use_linear_projection: bool = True,
185
+ only_cross_attention: bool = False,
186
+ upcast_attention: bool = False,
187
+ resnet_time_scale_shift: str = "default",
188
+ temporal_num_attention_heads: int = 8,
189
+ temporal_cross_attention_dim: Optional[int] = None,
190
+ temporal_max_seq_length: int = 32,
191
+ transformer_layers_per_block: int = 1,
192
+ dropout: float = 0.0,
193
+ ) -> Union[
194
+ "UpBlock3D",
195
+ "CrossAttnUpBlock3D",
196
+ "UpBlockMotion",
197
+ "CrossAttnUpBlockMotion",
198
+ "UpBlockSpatioTemporal",
199
+ "CrossAttnUpBlockSpatioTemporal",
200
+ ]:
201
+ if up_block_type == "UpBlock3D":
202
+ return UpBlock3D(
203
+ num_layers=num_layers,
204
+ in_channels=in_channels,
205
+ out_channels=out_channels,
206
+ prev_output_channel=prev_output_channel,
207
+ temb_channels=temb_channels,
208
+ add_upsample=add_upsample,
209
+ resnet_eps=resnet_eps,
210
+ resnet_act_fn=resnet_act_fn,
211
+ resnet_groups=resnet_groups,
212
+ resnet_time_scale_shift=resnet_time_scale_shift,
213
+ resolution_idx=resolution_idx,
214
+ )
215
+ elif up_block_type == "CrossAttnUpBlock3D":
216
+ if cross_attention_dim is None:
217
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
218
+ return CrossAttnUpBlock3D(
219
+ num_layers=num_layers,
220
+ in_channels=in_channels,
221
+ out_channels=out_channels,
222
+ prev_output_channel=prev_output_channel,
223
+ temb_channels=temb_channels,
224
+ add_upsample=add_upsample,
225
+ resnet_eps=resnet_eps,
226
+ resnet_act_fn=resnet_act_fn,
227
+ resnet_groups=resnet_groups,
228
+ cross_attention_dim=cross_attention_dim,
229
+ num_attention_heads=num_attention_heads,
230
+ dual_cross_attention=dual_cross_attention,
231
+ use_linear_projection=use_linear_projection,
232
+ only_cross_attention=only_cross_attention,
233
+ upcast_attention=upcast_attention,
234
+ resnet_time_scale_shift=resnet_time_scale_shift,
235
+ resolution_idx=resolution_idx,
236
+ )
237
+ if up_block_type == "UpBlockMotion":
238
+ return UpBlockMotion(
239
+ num_layers=num_layers,
240
+ in_channels=in_channels,
241
+ out_channels=out_channels,
242
+ prev_output_channel=prev_output_channel,
243
+ temb_channels=temb_channels,
244
+ add_upsample=add_upsample,
245
+ resnet_eps=resnet_eps,
246
+ resnet_act_fn=resnet_act_fn,
247
+ resnet_groups=resnet_groups,
248
+ resnet_time_scale_shift=resnet_time_scale_shift,
249
+ resolution_idx=resolution_idx,
250
+ temporal_num_attention_heads=temporal_num_attention_heads,
251
+ temporal_max_seq_length=temporal_max_seq_length,
252
+ )
253
+ elif up_block_type == "CrossAttnUpBlockMotion":
254
+ if cross_attention_dim is None:
255
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
256
+ return CrossAttnUpBlockMotion(
257
+ num_layers=num_layers,
258
+ in_channels=in_channels,
259
+ out_channels=out_channels,
260
+ prev_output_channel=prev_output_channel,
261
+ temb_channels=temb_channels,
262
+ add_upsample=add_upsample,
263
+ resnet_eps=resnet_eps,
264
+ resnet_act_fn=resnet_act_fn,
265
+ resnet_groups=resnet_groups,
266
+ cross_attention_dim=cross_attention_dim,
267
+ num_attention_heads=num_attention_heads,
268
+ dual_cross_attention=dual_cross_attention,
269
+ use_linear_projection=use_linear_projection,
270
+ only_cross_attention=only_cross_attention,
271
+ upcast_attention=upcast_attention,
272
+ resnet_time_scale_shift=resnet_time_scale_shift,
273
+ resolution_idx=resolution_idx,
274
+ temporal_num_attention_heads=temporal_num_attention_heads,
275
+ temporal_max_seq_length=temporal_max_seq_length,
276
+ )
277
+ elif up_block_type == "UpBlockSpatioTemporal":
278
+ # added for SDV
279
+ return UpBlockSpatioTemporal(
280
+ num_layers=num_layers,
281
+ in_channels=in_channels,
282
+ out_channels=out_channels,
283
+ prev_output_channel=prev_output_channel,
284
+ temb_channels=temb_channels,
285
+ resolution_idx=resolution_idx,
286
+ add_upsample=add_upsample,
287
+ )
288
+ elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
289
+ # added for SDV
290
+ if cross_attention_dim is None:
291
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
292
+ return CrossAttnUpBlockSpatioTemporal(
293
+ in_channels=in_channels,
294
+ out_channels=out_channels,
295
+ prev_output_channel=prev_output_channel,
296
+ temb_channels=temb_channels,
297
+ num_layers=num_layers,
298
+ transformer_layers_per_block=transformer_layers_per_block,
299
+ add_upsample=add_upsample,
300
+ cross_attention_dim=cross_attention_dim,
301
+ num_attention_heads=num_attention_heads,
302
+ resolution_idx=resolution_idx,
303
+ )
304
+
305
+ raise ValueError(f"{up_block_type} does not exist.")
306
+
307
+
308
+ class UNetMidBlock3DCrossAttn(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channels: int,
312
+ temb_channels: int,
313
+ dropout: float = 0.0,
314
+ num_layers: int = 1,
315
+ resnet_eps: float = 1e-6,
316
+ resnet_time_scale_shift: str = "default",
317
+ resnet_act_fn: str = "swish",
318
+ resnet_groups: int = 32,
319
+ resnet_pre_norm: bool = True,
320
+ num_attention_heads: int = 1,
321
+ output_scale_factor: float = 1.0,
322
+ cross_attention_dim: int = 1280,
323
+ dual_cross_attention: bool = False,
324
+ use_linear_projection: bool = True,
325
+ upcast_attention: bool = False,
326
+ ):
327
+ super().__init__()
328
+
329
+ self.has_cross_attention = True
330
+ self.num_attention_heads = num_attention_heads
331
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
332
+
333
+ # there is always at least one resnet
334
+ resnets = [
335
+ ResnetBlock2D(
336
+ in_channels=in_channels,
337
+ out_channels=in_channels,
338
+ temb_channels=temb_channels,
339
+ eps=resnet_eps,
340
+ groups=resnet_groups,
341
+ dropout=dropout,
342
+ time_embedding_norm=resnet_time_scale_shift,
343
+ non_linearity=resnet_act_fn,
344
+ output_scale_factor=output_scale_factor,
345
+ pre_norm=resnet_pre_norm,
346
+ )
347
+ ]
348
+ temp_convs = [
349
+ TemporalConvLayer(
350
+ in_channels,
351
+ in_channels,
352
+ dropout=0.1,
353
+ norm_num_groups=resnet_groups,
354
+ )
355
+ ]
356
+ attentions = []
357
+ temp_attentions = []
358
+
359
+ for _ in range(num_layers):
360
+ attentions.append(
361
+ Transformer2DModel(
362
+ in_channels // num_attention_heads,
363
+ num_attention_heads,
364
+ in_channels=in_channels,
365
+ num_layers=1,
366
+ cross_attention_dim=cross_attention_dim,
367
+ norm_num_groups=resnet_groups,
368
+ use_linear_projection=use_linear_projection,
369
+ upcast_attention=upcast_attention,
370
+ )
371
+ )
372
+ temp_attentions.append(
373
+ TransformerTemporalModel(
374
+ in_channels // num_attention_heads,
375
+ num_attention_heads,
376
+ in_channels=in_channels,
377
+ num_layers=1,
378
+ cross_attention_dim=cross_attention_dim,
379
+ norm_num_groups=resnet_groups,
380
+ )
381
+ )
382
+ resnets.append(
383
+ ResnetBlock2D(
384
+ in_channels=in_channels,
385
+ out_channels=in_channels,
386
+ temb_channels=temb_channels,
387
+ eps=resnet_eps,
388
+ groups=resnet_groups,
389
+ dropout=dropout,
390
+ time_embedding_norm=resnet_time_scale_shift,
391
+ non_linearity=resnet_act_fn,
392
+ output_scale_factor=output_scale_factor,
393
+ pre_norm=resnet_pre_norm,
394
+ )
395
+ )
396
+ temp_convs.append(
397
+ TemporalConvLayer(
398
+ in_channels,
399
+ in_channels,
400
+ dropout=0.1,
401
+ norm_num_groups=resnet_groups,
402
+ )
403
+ )
404
+
405
+ self.resnets = nn.ModuleList(resnets)
406
+ self.temp_convs = nn.ModuleList(temp_convs)
407
+ self.attentions = nn.ModuleList(attentions)
408
+ self.temp_attentions = nn.ModuleList(temp_attentions)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.FloatTensor,
413
+ temb: Optional[torch.FloatTensor] = None,
414
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
415
+ attention_mask: Optional[torch.FloatTensor] = None,
416
+ num_frames: int = 1,
417
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
418
+ ) -> torch.FloatTensor:
419
+ hidden_states = self.resnets[0](hidden_states, temb)
420
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
421
+ for attn, temp_attn, resnet, temp_conv in zip(
422
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
423
+ ):
424
+ hidden_states = attn(
425
+ hidden_states,
426
+ encoder_hidden_states=encoder_hidden_states,
427
+ cross_attention_kwargs=cross_attention_kwargs,
428
+ return_dict=False,
429
+ )[0]
430
+ hidden_states = temp_attn(
431
+ hidden_states,
432
+ num_frames=num_frames,
433
+ cross_attention_kwargs=cross_attention_kwargs,
434
+ return_dict=False,
435
+ )[0]
436
+ hidden_states = resnet(hidden_states, temb)
437
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
438
+
439
+ return hidden_states
440
+
441
+
442
+ class CrossAttnDownBlock3D(nn.Module):
443
+ def __init__(
444
+ self,
445
+ in_channels: int,
446
+ out_channels: int,
447
+ temb_channels: int,
448
+ dropout: float = 0.0,
449
+ num_layers: int = 1,
450
+ resnet_eps: float = 1e-6,
451
+ resnet_time_scale_shift: str = "default",
452
+ resnet_act_fn: str = "swish",
453
+ resnet_groups: int = 32,
454
+ resnet_pre_norm: bool = True,
455
+ num_attention_heads: int = 1,
456
+ cross_attention_dim: int = 1280,
457
+ output_scale_factor: float = 1.0,
458
+ downsample_padding: int = 1,
459
+ add_downsample: bool = True,
460
+ dual_cross_attention: bool = False,
461
+ use_linear_projection: bool = False,
462
+ only_cross_attention: bool = False,
463
+ upcast_attention: bool = False,
464
+ ):
465
+ super().__init__()
466
+ resnets = []
467
+ attentions = []
468
+ temp_attentions = []
469
+ temp_convs = []
470
+
471
+ self.has_cross_attention = True
472
+ self.num_attention_heads = num_attention_heads
473
+
474
+ for i in range(num_layers):
475
+ in_channels = in_channels if i == 0 else out_channels
476
+ resnets.append(
477
+ ResnetBlock2D(
478
+ in_channels=in_channels,
479
+ out_channels=out_channels,
480
+ temb_channels=temb_channels,
481
+ eps=resnet_eps,
482
+ groups=resnet_groups,
483
+ dropout=dropout,
484
+ time_embedding_norm=resnet_time_scale_shift,
485
+ non_linearity=resnet_act_fn,
486
+ output_scale_factor=output_scale_factor,
487
+ pre_norm=resnet_pre_norm,
488
+ )
489
+ )
490
+ temp_convs.append(
491
+ TemporalConvLayer(
492
+ out_channels,
493
+ out_channels,
494
+ dropout=0.1,
495
+ norm_num_groups=resnet_groups,
496
+ )
497
+ )
498
+ attentions.append(
499
+ Transformer2DModel(
500
+ out_channels // num_attention_heads,
501
+ num_attention_heads,
502
+ in_channels=out_channels,
503
+ num_layers=1,
504
+ cross_attention_dim=cross_attention_dim,
505
+ norm_num_groups=resnet_groups,
506
+ use_linear_projection=use_linear_projection,
507
+ only_cross_attention=only_cross_attention,
508
+ upcast_attention=upcast_attention,
509
+ )
510
+ )
511
+ temp_attentions.append(
512
+ TransformerTemporalModel(
513
+ out_channels // num_attention_heads,
514
+ num_attention_heads,
515
+ in_channels=out_channels,
516
+ num_layers=1,
517
+ cross_attention_dim=cross_attention_dim,
518
+ norm_num_groups=resnet_groups,
519
+ )
520
+ )
521
+ self.resnets = nn.ModuleList(resnets)
522
+ self.temp_convs = nn.ModuleList(temp_convs)
523
+ self.attentions = nn.ModuleList(attentions)
524
+ self.temp_attentions = nn.ModuleList(temp_attentions)
525
+
526
+ if add_downsample:
527
+ self.downsamplers = nn.ModuleList(
528
+ [
529
+ Downsample2D(
530
+ out_channels,
531
+ use_conv=True,
532
+ out_channels=out_channels,
533
+ padding=downsample_padding,
534
+ name="op",
535
+ )
536
+ ]
537
+ )
538
+ else:
539
+ self.downsamplers = None
540
+
541
+ self.gradient_checkpointing = False
542
+
543
+ def forward(
544
+ self,
545
+ hidden_states: torch.FloatTensor,
546
+ temb: Optional[torch.FloatTensor] = None,
547
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
548
+ attention_mask: Optional[torch.FloatTensor] = None,
549
+ num_frames: int = 1,
550
+ cross_attention_kwargs: Dict[str, Any] = None,
551
+ ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
552
+ # TODO(Patrick, William) - attention mask is not used
553
+ output_states = ()
554
+
555
+ for resnet, temp_conv, attn, temp_attn in zip(
556
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
557
+ ):
558
+ hidden_states = resnet(hidden_states, temb)
559
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
560
+ hidden_states = attn(
561
+ hidden_states,
562
+ encoder_hidden_states=encoder_hidden_states,
563
+ cross_attention_kwargs=cross_attention_kwargs,
564
+ return_dict=False,
565
+ )[0]
566
+ hidden_states = temp_attn(
567
+ hidden_states,
568
+ num_frames=num_frames,
569
+ cross_attention_kwargs=cross_attention_kwargs,
570
+ return_dict=False,
571
+ )[0]
572
+
573
+ output_states += (hidden_states,)
574
+
575
+ if self.downsamplers is not None:
576
+ for downsampler in self.downsamplers:
577
+ hidden_states = downsampler(hidden_states)
578
+
579
+ output_states += (hidden_states,)
580
+
581
+ return hidden_states, output_states
582
+
583
+
584
+ class DownBlock3D(nn.Module):
585
+ def __init__(
586
+ self,
587
+ in_channels: int,
588
+ out_channels: int,
589
+ temb_channels: int,
590
+ dropout: float = 0.0,
591
+ num_layers: int = 1,
592
+ resnet_eps: float = 1e-6,
593
+ resnet_time_scale_shift: str = "default",
594
+ resnet_act_fn: str = "swish",
595
+ resnet_groups: int = 32,
596
+ resnet_pre_norm: bool = True,
597
+ output_scale_factor: float = 1.0,
598
+ add_downsample: bool = True,
599
+ downsample_padding: int = 1,
600
+ ):
601
+ super().__init__()
602
+ resnets = []
603
+ temp_convs = []
604
+
605
+ for i in range(num_layers):
606
+ in_channels = in_channels if i == 0 else out_channels
607
+ resnets.append(
608
+ ResnetBlock2D(
609
+ in_channels=in_channels,
610
+ out_channels=out_channels,
611
+ temb_channels=temb_channels,
612
+ eps=resnet_eps,
613
+ groups=resnet_groups,
614
+ dropout=dropout,
615
+ time_embedding_norm=resnet_time_scale_shift,
616
+ non_linearity=resnet_act_fn,
617
+ output_scale_factor=output_scale_factor,
618
+ pre_norm=resnet_pre_norm,
619
+ )
620
+ )
621
+ temp_convs.append(
622
+ TemporalConvLayer(
623
+ out_channels,
624
+ out_channels,
625
+ dropout=0.1,
626
+ norm_num_groups=resnet_groups,
627
+ )
628
+ )
629
+
630
+ self.resnets = nn.ModuleList(resnets)
631
+ self.temp_convs = nn.ModuleList(temp_convs)
632
+
633
+ if add_downsample:
634
+ self.downsamplers = nn.ModuleList(
635
+ [
636
+ Downsample2D(
637
+ out_channels,
638
+ use_conv=True,
639
+ out_channels=out_channels,
640
+ padding=downsample_padding,
641
+ name="op",
642
+ )
643
+ ]
644
+ )
645
+ else:
646
+ self.downsamplers = None
647
+
648
+ self.gradient_checkpointing = False
649
+
650
+ def forward(
651
+ self,
652
+ hidden_states: torch.FloatTensor,
653
+ temb: Optional[torch.FloatTensor] = None,
654
+ num_frames: int = 1,
655
+ ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
656
+ output_states = ()
657
+
658
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
659
+ hidden_states = resnet(hidden_states, temb)
660
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
661
+
662
+ output_states += (hidden_states,)
663
+
664
+ if self.downsamplers is not None:
665
+ for downsampler in self.downsamplers:
666
+ hidden_states = downsampler(hidden_states)
667
+
668
+ output_states += (hidden_states,)
669
+
670
+ return hidden_states, output_states
671
+
672
+
673
+ class CrossAttnUpBlock3D(nn.Module):
674
+ def __init__(
675
+ self,
676
+ in_channels: int,
677
+ out_channels: int,
678
+ prev_output_channel: int,
679
+ temb_channels: int,
680
+ dropout: float = 0.0,
681
+ num_layers: int = 1,
682
+ resnet_eps: float = 1e-6,
683
+ resnet_time_scale_shift: str = "default",
684
+ resnet_act_fn: str = "swish",
685
+ resnet_groups: int = 32,
686
+ resnet_pre_norm: bool = True,
687
+ num_attention_heads: int = 1,
688
+ cross_attention_dim: int = 1280,
689
+ output_scale_factor: float = 1.0,
690
+ add_upsample: bool = True,
691
+ dual_cross_attention: bool = False,
692
+ use_linear_projection: bool = False,
693
+ only_cross_attention: bool = False,
694
+ upcast_attention: bool = False,
695
+ resolution_idx: Optional[int] = None,
696
+ ):
697
+ super().__init__()
698
+ resnets = []
699
+ temp_convs = []
700
+ attentions = []
701
+ temp_attentions = []
702
+
703
+ self.has_cross_attention = True
704
+ self.num_attention_heads = num_attention_heads
705
+
706
+ for i in range(num_layers):
707
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
708
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
709
+
710
+ resnets.append(
711
+ ResnetBlock2D(
712
+ in_channels=resnet_in_channels + res_skip_channels,
713
+ out_channels=out_channels,
714
+ temb_channels=temb_channels,
715
+ eps=resnet_eps,
716
+ groups=resnet_groups,
717
+ dropout=dropout,
718
+ time_embedding_norm=resnet_time_scale_shift,
719
+ non_linearity=resnet_act_fn,
720
+ output_scale_factor=output_scale_factor,
721
+ pre_norm=resnet_pre_norm,
722
+ )
723
+ )
724
+ temp_convs.append(
725
+ TemporalConvLayer(
726
+ out_channels,
727
+ out_channels,
728
+ dropout=0.1,
729
+ norm_num_groups=resnet_groups,
730
+ )
731
+ )
732
+ attentions.append(
733
+ Transformer2DModel(
734
+ out_channels // num_attention_heads,
735
+ num_attention_heads,
736
+ in_channels=out_channels,
737
+ num_layers=1,
738
+ cross_attention_dim=cross_attention_dim,
739
+ norm_num_groups=resnet_groups,
740
+ use_linear_projection=use_linear_projection,
741
+ only_cross_attention=only_cross_attention,
742
+ upcast_attention=upcast_attention,
743
+ )
744
+ )
745
+ temp_attentions.append(
746
+ TransformerTemporalModel(
747
+ out_channels // num_attention_heads,
748
+ num_attention_heads,
749
+ in_channels=out_channels,
750
+ num_layers=1,
751
+ cross_attention_dim=cross_attention_dim,
752
+ norm_num_groups=resnet_groups,
753
+ )
754
+ )
755
+ self.resnets = nn.ModuleList(resnets)
756
+ self.temp_convs = nn.ModuleList(temp_convs)
757
+ self.attentions = nn.ModuleList(attentions)
758
+ self.temp_attentions = nn.ModuleList(temp_attentions)
759
+
760
+ if add_upsample:
761
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
762
+ else:
763
+ self.upsamplers = None
764
+
765
+ self.gradient_checkpointing = False
766
+ self.resolution_idx = resolution_idx
767
+
768
+ def forward(
769
+ self,
770
+ hidden_states: torch.FloatTensor,
771
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
772
+ temb: Optional[torch.FloatTensor] = None,
773
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
774
+ upsample_size: Optional[int] = None,
775
+ attention_mask: Optional[torch.FloatTensor] = None,
776
+ num_frames: int = 1,
777
+ cross_attention_kwargs: Dict[str, Any] = None,
778
+ ) -> torch.FloatTensor:
779
+ is_freeu_enabled = (
780
+ getattr(self, "s1", None)
781
+ and getattr(self, "s2", None)
782
+ and getattr(self, "b1", None)
783
+ and getattr(self, "b2", None)
784
+ )
785
+
786
+ # TODO(Patrick, William) - attention mask is not used
787
+ for resnet, temp_conv, attn, temp_attn in zip(
788
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
789
+ ):
790
+ # pop res hidden states
791
+ res_hidden_states = res_hidden_states_tuple[-1]
792
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
793
+
794
+ # FreeU: Only operate on the first two stages
795
+ if is_freeu_enabled:
796
+ hidden_states, res_hidden_states = apply_freeu(
797
+ self.resolution_idx,
798
+ hidden_states,
799
+ res_hidden_states,
800
+ s1=self.s1,
801
+ s2=self.s2,
802
+ b1=self.b1,
803
+ b2=self.b2,
804
+ )
805
+
806
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
807
+
808
+ hidden_states = resnet(hidden_states, temb)
809
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
810
+ hidden_states = attn(
811
+ hidden_states,
812
+ encoder_hidden_states=encoder_hidden_states,
813
+ cross_attention_kwargs=cross_attention_kwargs,
814
+ return_dict=False,
815
+ )[0]
816
+ hidden_states = temp_attn(
817
+ hidden_states,
818
+ num_frames=num_frames,
819
+ cross_attention_kwargs=cross_attention_kwargs,
820
+ return_dict=False,
821
+ )[0]
822
+
823
+ if self.upsamplers is not None:
824
+ for upsampler in self.upsamplers:
825
+ hidden_states = upsampler(hidden_states, upsample_size)
826
+
827
+ return hidden_states
828
+
829
+
830
+ class UpBlock3D(nn.Module):
831
+ def __init__(
832
+ self,
833
+ in_channels: int,
834
+ prev_output_channel: int,
835
+ out_channels: int,
836
+ temb_channels: int,
837
+ dropout: float = 0.0,
838
+ num_layers: int = 1,
839
+ resnet_eps: float = 1e-6,
840
+ resnet_time_scale_shift: str = "default",
841
+ resnet_act_fn: str = "swish",
842
+ resnet_groups: int = 32,
843
+ resnet_pre_norm: bool = True,
844
+ output_scale_factor: float = 1.0,
845
+ add_upsample: bool = True,
846
+ resolution_idx: Optional[int] = None,
847
+ ):
848
+ super().__init__()
849
+ resnets = []
850
+ temp_convs = []
851
+
852
+ for i in range(num_layers):
853
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
854
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
855
+
856
+ resnets.append(
857
+ ResnetBlock2D(
858
+ in_channels=resnet_in_channels + res_skip_channels,
859
+ out_channels=out_channels,
860
+ temb_channels=temb_channels,
861
+ eps=resnet_eps,
862
+ groups=resnet_groups,
863
+ dropout=dropout,
864
+ time_embedding_norm=resnet_time_scale_shift,
865
+ non_linearity=resnet_act_fn,
866
+ output_scale_factor=output_scale_factor,
867
+ pre_norm=resnet_pre_norm,
868
+ )
869
+ )
870
+ temp_convs.append(
871
+ TemporalConvLayer(
872
+ out_channels,
873
+ out_channels,
874
+ dropout=0.1,
875
+ norm_num_groups=resnet_groups,
876
+ )
877
+ )
878
+
879
+ self.resnets = nn.ModuleList(resnets)
880
+ self.temp_convs = nn.ModuleList(temp_convs)
881
+
882
+ if add_upsample:
883
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
884
+ else:
885
+ self.upsamplers = None
886
+
887
+ self.gradient_checkpointing = False
888
+ self.resolution_idx = resolution_idx
889
+
890
+ def forward(
891
+ self,
892
+ hidden_states: torch.FloatTensor,
893
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
894
+ temb: Optional[torch.FloatTensor] = None,
895
+ upsample_size: Optional[int] = None,
896
+ num_frames: int = 1,
897
+ ) -> torch.FloatTensor:
898
+ is_freeu_enabled = (
899
+ getattr(self, "s1", None)
900
+ and getattr(self, "s2", None)
901
+ and getattr(self, "b1", None)
902
+ and getattr(self, "b2", None)
903
+ )
904
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
905
+ # pop res hidden states
906
+ res_hidden_states = res_hidden_states_tuple[-1]
907
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
908
+
909
+ # FreeU: Only operate on the first two stages
910
+ if is_freeu_enabled:
911
+ hidden_states, res_hidden_states = apply_freeu(
912
+ self.resolution_idx,
913
+ hidden_states,
914
+ res_hidden_states,
915
+ s1=self.s1,
916
+ s2=self.s2,
917
+ b1=self.b1,
918
+ b2=self.b2,
919
+ )
920
+
921
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
922
+
923
+ hidden_states = resnet(hidden_states, temb)
924
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
925
+
926
+ if self.upsamplers is not None:
927
+ for upsampler in self.upsamplers:
928
+ hidden_states = upsampler(hidden_states, upsample_size)
929
+
930
+ return hidden_states
931
+
932
+
933
+ class DownBlockMotion(nn.Module):
934
+ def __init__(
935
+ self,
936
+ in_channels: int,
937
+ out_channels: int,
938
+ temb_channels: int,
939
+ dropout: float = 0.0,
940
+ num_layers: int = 1,
941
+ resnet_eps: float = 1e-6,
942
+ resnet_time_scale_shift: str = "default",
943
+ resnet_act_fn: str = "swish",
944
+ resnet_groups: int = 32,
945
+ resnet_pre_norm: bool = True,
946
+ output_scale_factor: float = 1.0,
947
+ add_downsample: bool = True,
948
+ downsample_padding: int = 1,
949
+ temporal_num_attention_heads: int = 1,
950
+ temporal_cross_attention_dim: Optional[int] = None,
951
+ temporal_max_seq_length: int = 32,
952
+ ):
953
+ super().__init__()
954
+ resnets = []
955
+ motion_modules = []
956
+
957
+ for i in range(num_layers):
958
+ in_channels = in_channels if i == 0 else out_channels
959
+ resnets.append(
960
+ ResnetBlock2D(
961
+ in_channels=in_channels,
962
+ out_channels=out_channels,
963
+ temb_channels=temb_channels,
964
+ eps=resnet_eps,
965
+ groups=resnet_groups,
966
+ dropout=dropout,
967
+ time_embedding_norm=resnet_time_scale_shift,
968
+ non_linearity=resnet_act_fn,
969
+ output_scale_factor=output_scale_factor,
970
+ pre_norm=resnet_pre_norm,
971
+ )
972
+ )
973
+ motion_modules.append(
974
+ TransformerTemporalModel(
975
+ num_attention_heads=temporal_num_attention_heads,
976
+ in_channels=out_channels,
977
+ norm_num_groups=resnet_groups,
978
+ cross_attention_dim=temporal_cross_attention_dim,
979
+ attention_bias=False,
980
+ activation_fn="geglu",
981
+ positional_embeddings="sinusoidal",
982
+ num_positional_embeddings=temporal_max_seq_length,
983
+ attention_head_dim=out_channels // temporal_num_attention_heads,
984
+ )
985
+ )
986
+
987
+ self.resnets = nn.ModuleList(resnets)
988
+ self.motion_modules = nn.ModuleList(motion_modules)
989
+
990
+ if add_downsample:
991
+ self.downsamplers = nn.ModuleList(
992
+ [
993
+ Downsample2D(
994
+ out_channels,
995
+ use_conv=True,
996
+ out_channels=out_channels,
997
+ padding=downsample_padding,
998
+ name="op",
999
+ )
1000
+ ]
1001
+ )
1002
+ else:
1003
+ self.downsamplers = None
1004
+
1005
+ self.gradient_checkpointing = False
1006
+
1007
+ def forward(
1008
+ self,
1009
+ hidden_states: torch.FloatTensor,
1010
+ temb: Optional[torch.FloatTensor] = None,
1011
+ scale: float = 1.0,
1012
+ num_frames: int = 1,
1013
+ ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1014
+ output_states = ()
1015
+
1016
+ blocks = zip(self.resnets, self.motion_modules)
1017
+ for resnet, motion_module in blocks:
1018
+ if self.training and self.gradient_checkpointing:
1019
+
1020
+ def create_custom_forward(module):
1021
+ def custom_forward(*inputs):
1022
+ return module(*inputs)
1023
+
1024
+ return custom_forward
1025
+
1026
+ if is_torch_version(">=", "1.11.0"):
1027
+ hidden_states = torch.utils.checkpoint.checkpoint(
1028
+ create_custom_forward(resnet),
1029
+ hidden_states,
1030
+ temb,
1031
+ use_reentrant=False,
1032
+ )
1033
+ else:
1034
+ hidden_states = torch.utils.checkpoint.checkpoint(
1035
+ create_custom_forward(resnet), hidden_states, temb, scale
1036
+ )
1037
+ hidden_states = torch.utils.checkpoint.checkpoint(
1038
+ create_custom_forward(motion_module),
1039
+ hidden_states.requires_grad_(),
1040
+ temb,
1041
+ num_frames,
1042
+ )
1043
+
1044
+ else:
1045
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1046
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1047
+
1048
+ output_states = output_states + (hidden_states,)
1049
+
1050
+ if self.downsamplers is not None:
1051
+ for downsampler in self.downsamplers:
1052
+ hidden_states = downsampler(hidden_states, scale=scale)
1053
+
1054
+ output_states = output_states + (hidden_states,)
1055
+
1056
+ return hidden_states, output_states
1057
+
1058
+
1059
+ class CrossAttnDownBlockMotion(nn.Module):
1060
+ def __init__(
1061
+ self,
1062
+ in_channels: int,
1063
+ out_channels: int,
1064
+ temb_channels: int,
1065
+ dropout: float = 0.0,
1066
+ num_layers: int = 1,
1067
+ transformer_layers_per_block: int = 1,
1068
+ resnet_eps: float = 1e-6,
1069
+ resnet_time_scale_shift: str = "default",
1070
+ resnet_act_fn: str = "swish",
1071
+ resnet_groups: int = 32,
1072
+ resnet_pre_norm: bool = True,
1073
+ num_attention_heads: int = 1,
1074
+ cross_attention_dim: int = 1280,
1075
+ output_scale_factor: float = 1.0,
1076
+ downsample_padding: int = 1,
1077
+ add_downsample: bool = True,
1078
+ dual_cross_attention: bool = False,
1079
+ use_linear_projection: bool = False,
1080
+ only_cross_attention: bool = False,
1081
+ upcast_attention: bool = False,
1082
+ attention_type: str = "default",
1083
+ temporal_cross_attention_dim: Optional[int] = None,
1084
+ temporal_num_attention_heads: int = 8,
1085
+ temporal_max_seq_length: int = 32,
1086
+ ):
1087
+ super().__init__()
1088
+ resnets = []
1089
+ attentions = []
1090
+ motion_modules = []
1091
+
1092
+ self.has_cross_attention = True
1093
+ self.num_attention_heads = num_attention_heads
1094
+
1095
+ for i in range(num_layers):
1096
+ in_channels = in_channels if i == 0 else out_channels
1097
+ resnets.append(
1098
+ ResnetBlock2D(
1099
+ in_channels=in_channels,
1100
+ out_channels=out_channels,
1101
+ temb_channels=temb_channels,
1102
+ eps=resnet_eps,
1103
+ groups=resnet_groups,
1104
+ dropout=dropout,
1105
+ time_embedding_norm=resnet_time_scale_shift,
1106
+ non_linearity=resnet_act_fn,
1107
+ output_scale_factor=output_scale_factor,
1108
+ pre_norm=resnet_pre_norm,
1109
+ )
1110
+ )
1111
+
1112
+ if not dual_cross_attention:
1113
+ attentions.append(
1114
+ Transformer2DModel(
1115
+ num_attention_heads,
1116
+ out_channels // num_attention_heads,
1117
+ in_channels=out_channels,
1118
+ num_layers=transformer_layers_per_block,
1119
+ cross_attention_dim=cross_attention_dim,
1120
+ norm_num_groups=resnet_groups,
1121
+ use_linear_projection=use_linear_projection,
1122
+ only_cross_attention=only_cross_attention,
1123
+ upcast_attention=upcast_attention,
1124
+ attention_type=attention_type,
1125
+ )
1126
+ )
1127
+ else:
1128
+ attentions.append(
1129
+ DualTransformer2DModel(
1130
+ num_attention_heads,
1131
+ out_channels // num_attention_heads,
1132
+ in_channels=out_channels,
1133
+ num_layers=1,
1134
+ cross_attention_dim=cross_attention_dim,
1135
+ norm_num_groups=resnet_groups,
1136
+ )
1137
+ )
1138
+
1139
+ motion_modules.append(
1140
+ TransformerTemporalModel(
1141
+ num_attention_heads=temporal_num_attention_heads,
1142
+ in_channels=out_channels,
1143
+ norm_num_groups=resnet_groups,
1144
+ cross_attention_dim=temporal_cross_attention_dim,
1145
+ attention_bias=False,
1146
+ activation_fn="geglu",
1147
+ positional_embeddings="sinusoidal",
1148
+ num_positional_embeddings=temporal_max_seq_length,
1149
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1150
+ )
1151
+ )
1152
+
1153
+ self.attentions = nn.ModuleList(attentions)
1154
+ self.resnets = nn.ModuleList(resnets)
1155
+ self.motion_modules = nn.ModuleList(motion_modules)
1156
+
1157
+ if add_downsample:
1158
+ self.downsamplers = nn.ModuleList(
1159
+ [
1160
+ Downsample2D(
1161
+ out_channels,
1162
+ use_conv=True,
1163
+ out_channels=out_channels,
1164
+ padding=downsample_padding,
1165
+ name="op",
1166
+ )
1167
+ ]
1168
+ )
1169
+ else:
1170
+ self.downsamplers = None
1171
+
1172
+ self.gradient_checkpointing = False
1173
+
1174
+ def forward(
1175
+ self,
1176
+ hidden_states: torch.FloatTensor,
1177
+ temb: Optional[torch.FloatTensor] = None,
1178
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1179
+ attention_mask: Optional[torch.FloatTensor] = None,
1180
+ num_frames: int = 1,
1181
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1182
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1183
+ additional_residuals: Optional[torch.FloatTensor] = None,
1184
+ ):
1185
+ output_states = ()
1186
+
1187
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1188
+
1189
+ blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
1190
+ for i, (resnet, attn, motion_module) in enumerate(blocks):
1191
+ if self.training and self.gradient_checkpointing:
1192
+
1193
+ def create_custom_forward(module, return_dict=None):
1194
+ def custom_forward(*inputs):
1195
+ if return_dict is not None:
1196
+ return module(*inputs, return_dict=return_dict)
1197
+ else:
1198
+ return module(*inputs)
1199
+
1200
+ return custom_forward
1201
+
1202
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1203
+ hidden_states = torch.utils.checkpoint.checkpoint(
1204
+ create_custom_forward(resnet),
1205
+ hidden_states,
1206
+ temb,
1207
+ **ckpt_kwargs,
1208
+ )
1209
+ hidden_states = attn(
1210
+ hidden_states,
1211
+ encoder_hidden_states=encoder_hidden_states,
1212
+ cross_attention_kwargs=cross_attention_kwargs,
1213
+ attention_mask=attention_mask,
1214
+ encoder_attention_mask=encoder_attention_mask,
1215
+ return_dict=False,
1216
+ )[0]
1217
+ else:
1218
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1219
+ hidden_states = attn(
1220
+ hidden_states,
1221
+ encoder_hidden_states=encoder_hidden_states,
1222
+ cross_attention_kwargs=cross_attention_kwargs,
1223
+ attention_mask=attention_mask,
1224
+ encoder_attention_mask=encoder_attention_mask,
1225
+ return_dict=False,
1226
+ )[0]
1227
+ hidden_states = motion_module(
1228
+ hidden_states,
1229
+ num_frames=num_frames,
1230
+ )[0]
1231
+
1232
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
1233
+ if i == len(blocks) - 1 and additional_residuals is not None:
1234
+ hidden_states = hidden_states + additional_residuals
1235
+
1236
+ output_states = output_states + (hidden_states,)
1237
+
1238
+ if self.downsamplers is not None:
1239
+ for downsampler in self.downsamplers:
1240
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
1241
+
1242
+ output_states = output_states + (hidden_states,)
1243
+
1244
+ return hidden_states, output_states
1245
+
1246
+
1247
+ class CrossAttnUpBlockMotion(nn.Module):
1248
+ def __init__(
1249
+ self,
1250
+ in_channels: int,
1251
+ out_channels: int,
1252
+ prev_output_channel: int,
1253
+ temb_channels: int,
1254
+ resolution_idx: Optional[int] = None,
1255
+ dropout: float = 0.0,
1256
+ num_layers: int = 1,
1257
+ transformer_layers_per_block: int = 1,
1258
+ resnet_eps: float = 1e-6,
1259
+ resnet_time_scale_shift: str = "default",
1260
+ resnet_act_fn: str = "swish",
1261
+ resnet_groups: int = 32,
1262
+ resnet_pre_norm: bool = True,
1263
+ num_attention_heads: int = 1,
1264
+ cross_attention_dim: int = 1280,
1265
+ output_scale_factor: float = 1.0,
1266
+ add_upsample: bool = True,
1267
+ dual_cross_attention: bool = False,
1268
+ use_linear_projection: bool = False,
1269
+ only_cross_attention: bool = False,
1270
+ upcast_attention: bool = False,
1271
+ attention_type: str = "default",
1272
+ temporal_cross_attention_dim: Optional[int] = None,
1273
+ temporal_num_attention_heads: int = 8,
1274
+ temporal_max_seq_length: int = 32,
1275
+ ):
1276
+ super().__init__()
1277
+ resnets = []
1278
+ attentions = []
1279
+ motion_modules = []
1280
+
1281
+ self.has_cross_attention = True
1282
+ self.num_attention_heads = num_attention_heads
1283
+
1284
+ for i in range(num_layers):
1285
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1286
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1287
+
1288
+ resnets.append(
1289
+ ResnetBlock2D(
1290
+ in_channels=resnet_in_channels + res_skip_channels,
1291
+ out_channels=out_channels,
1292
+ temb_channels=temb_channels,
1293
+ eps=resnet_eps,
1294
+ groups=resnet_groups,
1295
+ dropout=dropout,
1296
+ time_embedding_norm=resnet_time_scale_shift,
1297
+ non_linearity=resnet_act_fn,
1298
+ output_scale_factor=output_scale_factor,
1299
+ pre_norm=resnet_pre_norm,
1300
+ )
1301
+ )
1302
+
1303
+ if not dual_cross_attention:
1304
+ attentions.append(
1305
+ Transformer2DModel(
1306
+ num_attention_heads,
1307
+ out_channels // num_attention_heads,
1308
+ in_channels=out_channels,
1309
+ num_layers=transformer_layers_per_block,
1310
+ cross_attention_dim=cross_attention_dim,
1311
+ norm_num_groups=resnet_groups,
1312
+ use_linear_projection=use_linear_projection,
1313
+ only_cross_attention=only_cross_attention,
1314
+ upcast_attention=upcast_attention,
1315
+ attention_type=attention_type,
1316
+ )
1317
+ )
1318
+ else:
1319
+ attentions.append(
1320
+ DualTransformer2DModel(
1321
+ num_attention_heads,
1322
+ out_channels // num_attention_heads,
1323
+ in_channels=out_channels,
1324
+ num_layers=1,
1325
+ cross_attention_dim=cross_attention_dim,
1326
+ norm_num_groups=resnet_groups,
1327
+ )
1328
+ )
1329
+ motion_modules.append(
1330
+ TransformerTemporalModel(
1331
+ num_attention_heads=temporal_num_attention_heads,
1332
+ in_channels=out_channels,
1333
+ norm_num_groups=resnet_groups,
1334
+ cross_attention_dim=temporal_cross_attention_dim,
1335
+ attention_bias=False,
1336
+ activation_fn="geglu",
1337
+ positional_embeddings="sinusoidal",
1338
+ num_positional_embeddings=temporal_max_seq_length,
1339
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1340
+ )
1341
+ )
1342
+
1343
+ self.attentions = nn.ModuleList(attentions)
1344
+ self.resnets = nn.ModuleList(resnets)
1345
+ self.motion_modules = nn.ModuleList(motion_modules)
1346
+
1347
+ if add_upsample:
1348
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1349
+ else:
1350
+ self.upsamplers = None
1351
+
1352
+ self.gradient_checkpointing = False
1353
+ self.resolution_idx = resolution_idx
1354
+
1355
+ def forward(
1356
+ self,
1357
+ hidden_states: torch.FloatTensor,
1358
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1359
+ temb: Optional[torch.FloatTensor] = None,
1360
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1361
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1362
+ upsample_size: Optional[int] = None,
1363
+ attention_mask: Optional[torch.FloatTensor] = None,
1364
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1365
+ num_frames: int = 1,
1366
+ ) -> torch.FloatTensor:
1367
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1368
+ is_freeu_enabled = (
1369
+ getattr(self, "s1", None)
1370
+ and getattr(self, "s2", None)
1371
+ and getattr(self, "b1", None)
1372
+ and getattr(self, "b2", None)
1373
+ )
1374
+
1375
+ blocks = zip(self.resnets, self.attentions, self.motion_modules)
1376
+ for resnet, attn, motion_module in blocks:
1377
+ # pop res hidden states
1378
+ res_hidden_states = res_hidden_states_tuple[-1]
1379
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1380
+
1381
+ # FreeU: Only operate on the first two stages
1382
+ if is_freeu_enabled:
1383
+ hidden_states, res_hidden_states = apply_freeu(
1384
+ self.resolution_idx,
1385
+ hidden_states,
1386
+ res_hidden_states,
1387
+ s1=self.s1,
1388
+ s2=self.s2,
1389
+ b1=self.b1,
1390
+ b2=self.b2,
1391
+ )
1392
+
1393
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1394
+
1395
+ if self.training and self.gradient_checkpointing:
1396
+
1397
+ def create_custom_forward(module, return_dict=None):
1398
+ def custom_forward(*inputs):
1399
+ if return_dict is not None:
1400
+ return module(*inputs, return_dict=return_dict)
1401
+ else:
1402
+ return module(*inputs)
1403
+
1404
+ return custom_forward
1405
+
1406
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1407
+ hidden_states = torch.utils.checkpoint.checkpoint(
1408
+ create_custom_forward(resnet),
1409
+ hidden_states,
1410
+ temb,
1411
+ **ckpt_kwargs,
1412
+ )
1413
+ hidden_states = attn(
1414
+ hidden_states,
1415
+ encoder_hidden_states=encoder_hidden_states,
1416
+ cross_attention_kwargs=cross_attention_kwargs,
1417
+ attention_mask=attention_mask,
1418
+ encoder_attention_mask=encoder_attention_mask,
1419
+ return_dict=False,
1420
+ )[0]
1421
+ else:
1422
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1423
+ hidden_states = attn(
1424
+ hidden_states,
1425
+ encoder_hidden_states=encoder_hidden_states,
1426
+ cross_attention_kwargs=cross_attention_kwargs,
1427
+ attention_mask=attention_mask,
1428
+ encoder_attention_mask=encoder_attention_mask,
1429
+ return_dict=False,
1430
+ )[0]
1431
+ hidden_states = motion_module(
1432
+ hidden_states,
1433
+ num_frames=num_frames,
1434
+ )[0]
1435
+
1436
+ if self.upsamplers is not None:
1437
+ for upsampler in self.upsamplers:
1438
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
1439
+
1440
+ return hidden_states
1441
+
1442
+
1443
+ class UpBlockMotion(nn.Module):
1444
+ def __init__(
1445
+ self,
1446
+ in_channels: int,
1447
+ prev_output_channel: int,
1448
+ out_channels: int,
1449
+ temb_channels: int,
1450
+ resolution_idx: Optional[int] = None,
1451
+ dropout: float = 0.0,
1452
+ num_layers: int = 1,
1453
+ resnet_eps: float = 1e-6,
1454
+ resnet_time_scale_shift: str = "default",
1455
+ resnet_act_fn: str = "swish",
1456
+ resnet_groups: int = 32,
1457
+ resnet_pre_norm: bool = True,
1458
+ output_scale_factor: float = 1.0,
1459
+ add_upsample: bool = True,
1460
+ temporal_norm_num_groups: int = 32,
1461
+ temporal_cross_attention_dim: Optional[int] = None,
1462
+ temporal_num_attention_heads: int = 8,
1463
+ temporal_max_seq_length: int = 32,
1464
+ ):
1465
+ super().__init__()
1466
+ resnets = []
1467
+ motion_modules = []
1468
+
1469
+ for i in range(num_layers):
1470
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1471
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1472
+
1473
+ resnets.append(
1474
+ ResnetBlock2D(
1475
+ in_channels=resnet_in_channels + res_skip_channels,
1476
+ out_channels=out_channels,
1477
+ temb_channels=temb_channels,
1478
+ eps=resnet_eps,
1479
+ groups=resnet_groups,
1480
+ dropout=dropout,
1481
+ time_embedding_norm=resnet_time_scale_shift,
1482
+ non_linearity=resnet_act_fn,
1483
+ output_scale_factor=output_scale_factor,
1484
+ pre_norm=resnet_pre_norm,
1485
+ )
1486
+ )
1487
+
1488
+ motion_modules.append(
1489
+ TransformerTemporalModel(
1490
+ num_attention_heads=temporal_num_attention_heads,
1491
+ in_channels=out_channels,
1492
+ norm_num_groups=temporal_norm_num_groups,
1493
+ cross_attention_dim=temporal_cross_attention_dim,
1494
+ attention_bias=False,
1495
+ activation_fn="geglu",
1496
+ positional_embeddings="sinusoidal",
1497
+ num_positional_embeddings=temporal_max_seq_length,
1498
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1499
+ )
1500
+ )
1501
+
1502
+ self.resnets = nn.ModuleList(resnets)
1503
+ self.motion_modules = nn.ModuleList(motion_modules)
1504
+
1505
+ if add_upsample:
1506
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1507
+ else:
1508
+ self.upsamplers = None
1509
+
1510
+ self.gradient_checkpointing = False
1511
+ self.resolution_idx = resolution_idx
1512
+
1513
+ def forward(
1514
+ self,
1515
+ hidden_states: torch.FloatTensor,
1516
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1517
+ temb: Optional[torch.FloatTensor] = None,
1518
+ upsample_size=None,
1519
+ scale: float = 1.0,
1520
+ num_frames: int = 1,
1521
+ ) -> torch.FloatTensor:
1522
+ is_freeu_enabled = (
1523
+ getattr(self, "s1", None)
1524
+ and getattr(self, "s2", None)
1525
+ and getattr(self, "b1", None)
1526
+ and getattr(self, "b2", None)
1527
+ )
1528
+
1529
+ blocks = zip(self.resnets, self.motion_modules)
1530
+
1531
+ for resnet, motion_module in blocks:
1532
+ # pop res hidden states
1533
+ res_hidden_states = res_hidden_states_tuple[-1]
1534
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1535
+
1536
+ # FreeU: Only operate on the first two stages
1537
+ if is_freeu_enabled:
1538
+ hidden_states, res_hidden_states = apply_freeu(
1539
+ self.resolution_idx,
1540
+ hidden_states,
1541
+ res_hidden_states,
1542
+ s1=self.s1,
1543
+ s2=self.s2,
1544
+ b1=self.b1,
1545
+ b2=self.b2,
1546
+ )
1547
+
1548
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1549
+
1550
+ if self.training and self.gradient_checkpointing:
1551
+
1552
+ def create_custom_forward(module):
1553
+ def custom_forward(*inputs):
1554
+ return module(*inputs)
1555
+
1556
+ return custom_forward
1557
+
1558
+ if is_torch_version(">=", "1.11.0"):
1559
+ hidden_states = torch.utils.checkpoint.checkpoint(
1560
+ create_custom_forward(resnet),
1561
+ hidden_states,
1562
+ temb,
1563
+ use_reentrant=False,
1564
+ )
1565
+ else:
1566
+ hidden_states = torch.utils.checkpoint.checkpoint(
1567
+ create_custom_forward(resnet), hidden_states, temb
1568
+ )
1569
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1574
+
1575
+ else:
1576
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1577
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1578
+
1579
+ if self.upsamplers is not None:
1580
+ for upsampler in self.upsamplers:
1581
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1582
+
1583
+ return hidden_states
1584
+
1585
+
1586
+ class UNetMidBlockCrossAttnMotion(nn.Module):
1587
+ def __init__(
1588
+ self,
1589
+ in_channels: int,
1590
+ temb_channels: int,
1591
+ dropout: float = 0.0,
1592
+ num_layers: int = 1,
1593
+ transformer_layers_per_block: int = 1,
1594
+ resnet_eps: float = 1e-6,
1595
+ resnet_time_scale_shift: str = "default",
1596
+ resnet_act_fn: str = "swish",
1597
+ resnet_groups: int = 32,
1598
+ resnet_pre_norm: bool = True,
1599
+ num_attention_heads: int = 1,
1600
+ output_scale_factor: float = 1.0,
1601
+ cross_attention_dim: int = 1280,
1602
+ dual_cross_attention: bool = False,
1603
+ use_linear_projection: bool = False,
1604
+ upcast_attention: bool = False,
1605
+ attention_type: str = "default",
1606
+ temporal_num_attention_heads: int = 1,
1607
+ temporal_cross_attention_dim: Optional[int] = None,
1608
+ temporal_max_seq_length: int = 32,
1609
+ ):
1610
+ super().__init__()
1611
+
1612
+ self.has_cross_attention = True
1613
+ self.num_attention_heads = num_attention_heads
1614
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
1615
+
1616
+ # there is always at least one resnet
1617
+ resnets = [
1618
+ ResnetBlock2D(
1619
+ in_channels=in_channels,
1620
+ out_channels=in_channels,
1621
+ temb_channels=temb_channels,
1622
+ eps=resnet_eps,
1623
+ groups=resnet_groups,
1624
+ dropout=dropout,
1625
+ time_embedding_norm=resnet_time_scale_shift,
1626
+ non_linearity=resnet_act_fn,
1627
+ output_scale_factor=output_scale_factor,
1628
+ pre_norm=resnet_pre_norm,
1629
+ )
1630
+ ]
1631
+ attentions = []
1632
+ motion_modules = []
1633
+
1634
+ for _ in range(num_layers):
1635
+ if not dual_cross_attention:
1636
+ attentions.append(
1637
+ Transformer2DModel(
1638
+ num_attention_heads,
1639
+ in_channels // num_attention_heads,
1640
+ in_channels=in_channels,
1641
+ num_layers=transformer_layers_per_block,
1642
+ cross_attention_dim=cross_attention_dim,
1643
+ norm_num_groups=resnet_groups,
1644
+ use_linear_projection=use_linear_projection,
1645
+ upcast_attention=upcast_attention,
1646
+ attention_type=attention_type,
1647
+ )
1648
+ )
1649
+ else:
1650
+ attentions.append(
1651
+ DualTransformer2DModel(
1652
+ num_attention_heads,
1653
+ in_channels // num_attention_heads,
1654
+ in_channels=in_channels,
1655
+ num_layers=1,
1656
+ cross_attention_dim=cross_attention_dim,
1657
+ norm_num_groups=resnet_groups,
1658
+ )
1659
+ )
1660
+ resnets.append(
1661
+ ResnetBlock2D(
1662
+ in_channels=in_channels,
1663
+ out_channels=in_channels,
1664
+ temb_channels=temb_channels,
1665
+ eps=resnet_eps,
1666
+ groups=resnet_groups,
1667
+ dropout=dropout,
1668
+ time_embedding_norm=resnet_time_scale_shift,
1669
+ non_linearity=resnet_act_fn,
1670
+ output_scale_factor=output_scale_factor,
1671
+ pre_norm=resnet_pre_norm,
1672
+ )
1673
+ )
1674
+ motion_modules.append(
1675
+ TransformerTemporalModel(
1676
+ num_attention_heads=temporal_num_attention_heads,
1677
+ attention_head_dim=in_channels // temporal_num_attention_heads,
1678
+ in_channels=in_channels,
1679
+ norm_num_groups=resnet_groups,
1680
+ cross_attention_dim=temporal_cross_attention_dim,
1681
+ attention_bias=False,
1682
+ positional_embeddings="sinusoidal",
1683
+ num_positional_embeddings=temporal_max_seq_length,
1684
+ activation_fn="geglu",
1685
+ )
1686
+ )
1687
+
1688
+ self.attentions = nn.ModuleList(attentions)
1689
+ self.resnets = nn.ModuleList(resnets)
1690
+ self.motion_modules = nn.ModuleList(motion_modules)
1691
+
1692
+ self.gradient_checkpointing = False
1693
+
1694
+ def forward(
1695
+ self,
1696
+ hidden_states: torch.FloatTensor,
1697
+ temb: Optional[torch.FloatTensor] = None,
1698
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1699
+ attention_mask: Optional[torch.FloatTensor] = None,
1700
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1701
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1702
+ num_frames: int = 1,
1703
+ ) -> torch.FloatTensor:
1704
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1705
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
1706
+
1707
+ blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
1708
+ for attn, resnet, motion_module in blocks:
1709
+ if self.training and self.gradient_checkpointing:
1710
+
1711
+ def create_custom_forward(module, return_dict=None):
1712
+ def custom_forward(*inputs):
1713
+ if return_dict is not None:
1714
+ return module(*inputs, return_dict=return_dict)
1715
+ else:
1716
+ return module(*inputs)
1717
+
1718
+ return custom_forward
1719
+
1720
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1721
+ hidden_states = attn(
1722
+ hidden_states,
1723
+ encoder_hidden_states=encoder_hidden_states,
1724
+ cross_attention_kwargs=cross_attention_kwargs,
1725
+ attention_mask=attention_mask,
1726
+ encoder_attention_mask=encoder_attention_mask,
1727
+ return_dict=False,
1728
+ )[0]
1729
+ hidden_states = torch.utils.checkpoint.checkpoint(
1730
+ create_custom_forward(motion_module),
1731
+ hidden_states,
1732
+ temb,
1733
+ **ckpt_kwargs,
1734
+ )
1735
+ hidden_states = torch.utils.checkpoint.checkpoint(
1736
+ create_custom_forward(resnet),
1737
+ hidden_states,
1738
+ temb,
1739
+ **ckpt_kwargs,
1740
+ )
1741
+ else:
1742
+ hidden_states = attn(
1743
+ hidden_states,
1744
+ encoder_hidden_states=encoder_hidden_states,
1745
+ cross_attention_kwargs=cross_attention_kwargs,
1746
+ attention_mask=attention_mask,
1747
+ encoder_attention_mask=encoder_attention_mask,
1748
+ return_dict=False,
1749
+ )[0]
1750
+ hidden_states = motion_module(
1751
+ hidden_states,
1752
+ num_frames=num_frames,
1753
+ )[0]
1754
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1755
+
1756
+ return hidden_states
1757
+
1758
+
1759
+ class MidBlockTemporalDecoder(nn.Module):
1760
+ def __init__(
1761
+ self,
1762
+ in_channels: int,
1763
+ out_channels: int,
1764
+ attention_head_dim: int = 512,
1765
+ num_layers: int = 1,
1766
+ upcast_attention: bool = False,
1767
+ ):
1768
+ super().__init__()
1769
+
1770
+ resnets = []
1771
+ attentions = []
1772
+ for i in range(num_layers):
1773
+ input_channels = in_channels if i == 0 else out_channels
1774
+ resnets.append(
1775
+ SpatioTemporalResBlock(
1776
+ in_channels=input_channels,
1777
+ out_channels=out_channels,
1778
+ temb_channels=None,
1779
+ eps=1e-6,
1780
+ temporal_eps=1e-5,
1781
+ merge_factor=0.0,
1782
+ merge_strategy="learned",
1783
+ switch_spatial_to_temporal_mix=True,
1784
+ )
1785
+ )
1786
+
1787
+ attentions.append(
1788
+ Attention(
1789
+ query_dim=in_channels,
1790
+ heads=in_channels // attention_head_dim,
1791
+ dim_head=attention_head_dim,
1792
+ eps=1e-6,
1793
+ upcast_attention=upcast_attention,
1794
+ norm_num_groups=32,
1795
+ bias=True,
1796
+ residual_connection=True,
1797
+ )
1798
+ )
1799
+
1800
+ self.attentions = nn.ModuleList(attentions)
1801
+ self.resnets = nn.ModuleList(resnets)
1802
+
1803
+ def forward(
1804
+ self,
1805
+ hidden_states: torch.FloatTensor,
1806
+ image_only_indicator: torch.FloatTensor,
1807
+ ):
1808
+ hidden_states = self.resnets[0](
1809
+ hidden_states,
1810
+ image_only_indicator=image_only_indicator,
1811
+ )
1812
+ for resnet, attn in zip(self.resnets[1:], self.attentions):
1813
+ hidden_states = attn(hidden_states)
1814
+ hidden_states = resnet(
1815
+ hidden_states,
1816
+ image_only_indicator=image_only_indicator,
1817
+ )
1818
+
1819
+ return hidden_states
1820
+
1821
+
1822
+ class UpBlockTemporalDecoder(nn.Module):
1823
+ def __init__(
1824
+ self,
1825
+ in_channels: int,
1826
+ out_channels: int,
1827
+ num_layers: int = 1,
1828
+ add_upsample: bool = True,
1829
+ ):
1830
+ super().__init__()
1831
+ resnets = []
1832
+ for i in range(num_layers):
1833
+ input_channels = in_channels if i == 0 else out_channels
1834
+
1835
+ resnets.append(
1836
+ SpatioTemporalResBlock(
1837
+ in_channels=input_channels,
1838
+ out_channels=out_channels,
1839
+ temb_channels=None,
1840
+ eps=1e-6,
1841
+ temporal_eps=1e-5,
1842
+ merge_factor=0.0,
1843
+ merge_strategy="learned",
1844
+ switch_spatial_to_temporal_mix=True,
1845
+ )
1846
+ )
1847
+ self.resnets = nn.ModuleList(resnets)
1848
+
1849
+ if add_upsample:
1850
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1851
+ else:
1852
+ self.upsamplers = None
1853
+
1854
+ def forward(
1855
+ self,
1856
+ hidden_states: torch.FloatTensor,
1857
+ image_only_indicator: torch.FloatTensor,
1858
+ ) -> torch.FloatTensor:
1859
+ for resnet in self.resnets:
1860
+ hidden_states = resnet(
1861
+ hidden_states,
1862
+ image_only_indicator=image_only_indicator,
1863
+ )
1864
+
1865
+ if self.upsamplers is not None:
1866
+ for upsampler in self.upsamplers:
1867
+ hidden_states = upsampler(hidden_states)
1868
+
1869
+ return hidden_states
1870
+
1871
+
1872
+ class UNetMidBlockSpatioTemporal(nn.Module):
1873
+ def __init__(
1874
+ self,
1875
+ in_channels: int,
1876
+ temb_channels: int,
1877
+ num_layers: int = 1,
1878
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
1879
+ num_attention_heads: int = 1,
1880
+ cross_attention_dim: int = 1280,
1881
+ ):
1882
+ super().__init__()
1883
+
1884
+ self.has_cross_attention = True
1885
+ self.num_attention_heads = num_attention_heads
1886
+
1887
+ # support for variable transformer layers per block
1888
+ if isinstance(transformer_layers_per_block, int):
1889
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
1890
+
1891
+ # there is always at least one resnet
1892
+ resnets = [
1893
+ SpatioTemporalResBlock(
1894
+ in_channels=in_channels,
1895
+ out_channels=in_channels,
1896
+ temb_channels=temb_channels,
1897
+ eps=1e-5,
1898
+ )
1899
+ ]
1900
+ attentions = []
1901
+
1902
+ for i in range(num_layers):
1903
+ attentions.append(
1904
+ TransformerSpatioTemporalModel(
1905
+ num_attention_heads,
1906
+ in_channels // num_attention_heads,
1907
+ in_channels=in_channels,
1908
+ num_layers=transformer_layers_per_block[i],
1909
+ cross_attention_dim=cross_attention_dim,
1910
+ )
1911
+ )
1912
+
1913
+ resnets.append(
1914
+ SpatioTemporalResBlock(
1915
+ in_channels=in_channels,
1916
+ out_channels=in_channels,
1917
+ temb_channels=temb_channels,
1918
+ eps=1e-5,
1919
+ )
1920
+ )
1921
+
1922
+ self.attentions = nn.ModuleList(attentions)
1923
+ self.resnets = nn.ModuleList(resnets)
1924
+
1925
+ self.gradient_checkpointing = False
1926
+
1927
+ def forward(
1928
+ self,
1929
+ hidden_states: torch.FloatTensor,
1930
+ temb: Optional[torch.FloatTensor] = None,
1931
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1932
+ image_only_indicator: Optional[torch.Tensor] = None,
1933
+ ) -> torch.FloatTensor:
1934
+ hidden_states = self.resnets[0](
1935
+ hidden_states,
1936
+ temb,
1937
+ image_only_indicator=image_only_indicator,
1938
+ )
1939
+
1940
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
1941
+ if self.training and self.gradient_checkpointing: # TODO
1942
+
1943
+ def create_custom_forward(module, return_dict=None):
1944
+ def custom_forward(*inputs):
1945
+ if return_dict is not None:
1946
+ return module(*inputs, return_dict=return_dict)
1947
+ else:
1948
+ return module(*inputs)
1949
+
1950
+ return custom_forward
1951
+
1952
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1953
+ hidden_states = attn(
1954
+ hidden_states,
1955
+ encoder_hidden_states=encoder_hidden_states,
1956
+ image_only_indicator=image_only_indicator,
1957
+ return_dict=False,
1958
+ )[0]
1959
+ hidden_states = torch.utils.checkpoint.checkpoint(
1960
+ create_custom_forward(resnet),
1961
+ hidden_states,
1962
+ temb,
1963
+ image_only_indicator,
1964
+ **ckpt_kwargs,
1965
+ )
1966
+ else:
1967
+ hidden_states = attn(
1968
+ hidden_states,
1969
+ encoder_hidden_states=encoder_hidden_states,
1970
+ image_only_indicator=image_only_indicator,
1971
+ return_dict=False,
1972
+ )[0]
1973
+ hidden_states = resnet(
1974
+ hidden_states,
1975
+ temb,
1976
+ image_only_indicator=image_only_indicator,
1977
+ )
1978
+
1979
+ return hidden_states
1980
+
1981
+
1982
+ class DownBlockSpatioTemporal(nn.Module):
1983
+ def __init__(
1984
+ self,
1985
+ in_channels: int,
1986
+ out_channels: int,
1987
+ temb_channels: int,
1988
+ num_layers: int = 1,
1989
+ add_downsample: bool = True,
1990
+ ):
1991
+ super().__init__()
1992
+ resnets = []
1993
+
1994
+ for i in range(num_layers):
1995
+ in_channels = in_channels if i == 0 else out_channels
1996
+ resnets.append(
1997
+ SpatioTemporalResBlock(
1998
+ in_channels=in_channels,
1999
+ out_channels=out_channels,
2000
+ temb_channels=temb_channels,
2001
+ eps=1e-5,
2002
+ )
2003
+ )
2004
+
2005
+ self.resnets = nn.ModuleList(resnets)
2006
+
2007
+ if add_downsample:
2008
+ self.downsamplers = nn.ModuleList(
2009
+ [
2010
+ Downsample2D(
2011
+ out_channels,
2012
+ use_conv=True,
2013
+ out_channels=out_channels,
2014
+ name="op",
2015
+ )
2016
+ ]
2017
+ )
2018
+ else:
2019
+ self.downsamplers = None
2020
+
2021
+ self.gradient_checkpointing = False
2022
+
2023
+ def forward(
2024
+ self,
2025
+ hidden_states: torch.FloatTensor,
2026
+ temb: Optional[torch.FloatTensor] = None,
2027
+ image_only_indicator: Optional[torch.Tensor] = None,
2028
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2029
+ output_states = ()
2030
+ for resnet in self.resnets:
2031
+ if self.training and self.gradient_checkpointing:
2032
+
2033
+ def create_custom_forward(module):
2034
+ def custom_forward(*inputs):
2035
+ return module(*inputs)
2036
+
2037
+ return custom_forward
2038
+
2039
+ if is_torch_version(">=", "1.11.0"):
2040
+ hidden_states = torch.utils.checkpoint.checkpoint(
2041
+ create_custom_forward(resnet),
2042
+ hidden_states,
2043
+ temb,
2044
+ image_only_indicator,
2045
+ use_reentrant=False,
2046
+ )
2047
+ else:
2048
+ hidden_states = torch.utils.checkpoint.checkpoint(
2049
+ create_custom_forward(resnet),
2050
+ hidden_states,
2051
+ temb,
2052
+ image_only_indicator,
2053
+ )
2054
+ else:
2055
+ hidden_states = resnet(
2056
+ hidden_states,
2057
+ temb,
2058
+ image_only_indicator=image_only_indicator,
2059
+ )
2060
+
2061
+ output_states = output_states + (hidden_states,)
2062
+
2063
+ if self.downsamplers is not None:
2064
+ for downsampler in self.downsamplers:
2065
+ hidden_states = downsampler(hidden_states)
2066
+
2067
+ output_states = output_states + (hidden_states,)
2068
+
2069
+ return hidden_states, output_states
2070
+
2071
+
2072
+ class CrossAttnDownBlockSpatioTemporal(nn.Module):
2073
+ def __init__(
2074
+ self,
2075
+ in_channels: int,
2076
+ out_channels: int,
2077
+ temb_channels: int,
2078
+ num_layers: int = 1,
2079
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
2080
+ num_attention_heads: int = 1,
2081
+ cross_attention_dim: int = 1280,
2082
+ add_downsample: bool = True,
2083
+ ):
2084
+ super().__init__()
2085
+ resnets = []
2086
+ attentions = []
2087
+
2088
+ self.has_cross_attention = True
2089
+ self.num_attention_heads = num_attention_heads
2090
+ if isinstance(transformer_layers_per_block, int):
2091
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
2092
+
2093
+ for i in range(num_layers):
2094
+ in_channels = in_channels if i == 0 else out_channels
2095
+ resnets.append(
2096
+ SpatioTemporalResBlock(
2097
+ in_channels=in_channels,
2098
+ out_channels=out_channels,
2099
+ temb_channels=temb_channels,
2100
+ eps=1e-6,
2101
+ )
2102
+ )
2103
+ attentions.append(
2104
+ TransformerSpatioTemporalModel(
2105
+ num_attention_heads,
2106
+ out_channels // num_attention_heads,
2107
+ in_channels=out_channels,
2108
+ num_layers=transformer_layers_per_block[i],
2109
+ cross_attention_dim=cross_attention_dim,
2110
+ )
2111
+ )
2112
+
2113
+ self.attentions = nn.ModuleList(attentions)
2114
+ self.resnets = nn.ModuleList(resnets)
2115
+
2116
+ if add_downsample:
2117
+ self.downsamplers = nn.ModuleList(
2118
+ [
2119
+ Downsample2D(
2120
+ out_channels,
2121
+ use_conv=True,
2122
+ out_channels=out_channels,
2123
+ padding=1,
2124
+ name="op",
2125
+ )
2126
+ ]
2127
+ )
2128
+ else:
2129
+ self.downsamplers = None
2130
+
2131
+ self.gradient_checkpointing = False
2132
+
2133
+ def forward(
2134
+ self,
2135
+ hidden_states: torch.FloatTensor,
2136
+ temb: Optional[torch.FloatTensor] = None,
2137
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
2138
+ image_only_indicator: Optional[torch.Tensor] = None,
2139
+ additional_residuals: Optional[torch.FloatTensor] = None,
2140
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2141
+ output_states = ()
2142
+
2143
+ blocks = list(zip(self.resnets, self.attentions))
2144
+ for block_idx, (resnet, attn) in enumerate(blocks):
2145
+ if self.training and self.gradient_checkpointing: # TODO
2146
+
2147
+ def create_custom_forward(module, return_dict=None):
2148
+ def custom_forward(*inputs):
2149
+ if return_dict is not None:
2150
+ return module(*inputs, return_dict=return_dict)
2151
+ else:
2152
+ return module(*inputs)
2153
+
2154
+ return custom_forward
2155
+
2156
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2157
+ hidden_states = torch.utils.checkpoint.checkpoint(
2158
+ create_custom_forward(resnet),
2159
+ hidden_states,
2160
+ temb,
2161
+ image_only_indicator,
2162
+ **ckpt_kwargs,
2163
+ )
2164
+
2165
+ hidden_states = attn(
2166
+ hidden_states,
2167
+ encoder_hidden_states=encoder_hidden_states,
2168
+ image_only_indicator=image_only_indicator,
2169
+ return_dict=False,
2170
+ )[0]
2171
+ else:
2172
+ hidden_states = resnet(
2173
+ hidden_states,
2174
+ temb,
2175
+ image_only_indicator=image_only_indicator,
2176
+ )
2177
+ hidden_states = attn(
2178
+ hidden_states,
2179
+ encoder_hidden_states=encoder_hidden_states,
2180
+ image_only_indicator=image_only_indicator,
2181
+ return_dict=False,
2182
+ )[0]
2183
+
2184
+ output_states = output_states + (hidden_states,)
2185
+
2186
+ # NOTE: add the additional residuals (e.g. ControlNet features) to the output of the last (resnet, attn) pair
2187
+ if block_idx == len(blocks) - 1 and additional_residuals is not None:
2188
+ if hidden_states.dim() == 5:
2189
+ additional_residuals = rearrange(additional_residuals, '(b f) c h w -> b c f h w', b=hidden_states.shape[0])
2190
+ hidden_states = hidden_states + additional_residuals
2191
+
2192
+ if self.downsamplers is not None:
2193
+ for downsampler in self.downsamplers:
2194
+ hidden_states = downsampler(hidden_states)
2195
+
2196
+ output_states = output_states + (hidden_states,)
2197
+
2198
+ return hidden_states, output_states
2199
+
2200
+
2201
+ class UpBlockSpatioTemporal(nn.Module):
2202
+ def __init__(
2203
+ self,
2204
+ in_channels: int,
2205
+ prev_output_channel: int,
2206
+ out_channels: int,
2207
+ temb_channels: int,
2208
+ resolution_idx: Optional[int] = None,
2209
+ num_layers: int = 1,
2210
+ resnet_eps: float = 1e-6,
2211
+ add_upsample: bool = True,
2212
+ ):
2213
+ super().__init__()
2214
+ resnets = []
2215
+
2216
+ for i in range(num_layers):
2217
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
2218
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
2219
+
2220
+ resnets.append(
2221
+ SpatioTemporalResBlock(
2222
+ in_channels=resnet_in_channels + res_skip_channels,
2223
+ out_channels=out_channels,
2224
+ temb_channels=temb_channels,
2225
+ eps=resnet_eps,
2226
+ )
2227
+ )
2228
+
2229
+ self.resnets = nn.ModuleList(resnets)
2230
+
2231
+ if add_upsample:
2232
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
2233
+ else:
2234
+ self.upsamplers = None
2235
+
2236
+ self.gradient_checkpointing = False
2237
+ self.resolution_idx = resolution_idx
2238
+
2239
+ def forward(
2240
+ self,
2241
+ hidden_states: torch.FloatTensor,
2242
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2243
+ temb: Optional[torch.FloatTensor] = None,
2244
+ image_only_indicator: Optional[torch.Tensor] = None,
2245
+ ) -> torch.FloatTensor:
2246
+ for resnet in self.resnets:
2247
+ # pop res hidden states
2248
+ res_hidden_states = res_hidden_states_tuple[-1]
2249
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2250
+
2251
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2252
+
2253
+ if self.training and self.gradient_checkpointing:
2254
+
2255
+ def create_custom_forward(module):
2256
+ def custom_forward(*inputs):
2257
+ return module(*inputs)
2258
+
2259
+ return custom_forward
2260
+
2261
+ if is_torch_version(">=", "1.11.0"):
2262
+ hidden_states = torch.utils.checkpoint.checkpoint(
2263
+ create_custom_forward(resnet),
2264
+ hidden_states,
2265
+ temb,
2266
+ image_only_indicator,
2267
+ use_reentrant=False,
2268
+ )
2269
+ else:
2270
+ hidden_states = torch.utils.checkpoint.checkpoint(
2271
+ create_custom_forward(resnet),
2272
+ hidden_states,
2273
+ temb,
2274
+ image_only_indicator,
2275
+ )
2276
+ else:
2277
+ hidden_states = resnet(
2278
+ hidden_states,
2279
+ temb,
2280
+ image_only_indicator=image_only_indicator,
2281
+ )
2282
+
2283
+ if self.upsamplers is not None:
2284
+ for upsampler in self.upsamplers:
2285
+ hidden_states = upsampler(hidden_states)
2286
+
2287
+ return hidden_states
2288
+
2289
+
2290
+ class CrossAttnUpBlockSpatioTemporal(nn.Module):
2291
+ def __init__(
2292
+ self,
2293
+ in_channels: int,
2294
+ out_channels: int,
2295
+ prev_output_channel: int,
2296
+ temb_channels: int,
2297
+ resolution_idx: Optional[int] = None,
2298
+ num_layers: int = 1,
2299
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
2300
+ resnet_eps: float = 1e-6,
2301
+ num_attention_heads: int = 1,
2302
+ cross_attention_dim: int = 1280,
2303
+ add_upsample: bool = True,
2304
+ ):
2305
+ super().__init__()
2306
+ resnets = []
2307
+ attentions = []
2308
+
2309
+ self.has_cross_attention = True
2310
+ self.num_attention_heads = num_attention_heads
2311
+
2312
+ if isinstance(transformer_layers_per_block, int):
2313
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
2314
+
2315
+ for i in range(num_layers):
2316
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
2317
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
2318
+
2319
+ resnets.append(
2320
+ SpatioTemporalResBlock(
2321
+ in_channels=resnet_in_channels + res_skip_channels,
2322
+ out_channels=out_channels,
2323
+ temb_channels=temb_channels,
2324
+ eps=resnet_eps,
2325
+ )
2326
+ )
2327
+ attentions.append(
2328
+ TransformerSpatioTemporalModel(
2329
+ num_attention_heads,
2330
+ out_channels // num_attention_heads,
2331
+ in_channels=out_channels,
2332
+ num_layers=transformer_layers_per_block[i],
2333
+ cross_attention_dim=cross_attention_dim,
2334
+ )
2335
+ )
2336
+
2337
+ self.attentions = nn.ModuleList(attentions)
2338
+ self.resnets = nn.ModuleList(resnets)
2339
+
2340
+ if add_upsample:
2341
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
2342
+ else:
2343
+ self.upsamplers = None
2344
+
2345
+ self.gradient_checkpointing = False
2346
+ self.resolution_idx = resolution_idx
2347
+
2348
+ def forward(
2349
+ self,
2350
+ hidden_states: torch.FloatTensor,
2351
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2352
+ temb: Optional[torch.FloatTensor] = None,
2353
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
2354
+ image_only_indicator: Optional[torch.Tensor] = None,
2355
+ ) -> torch.FloatTensor:
2356
+ for resnet, attn in zip(self.resnets, self.attentions):
2357
+ # pop res hidden states
2358
+ res_hidden_states = res_hidden_states_tuple[-1]
2359
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2360
+
2361
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2362
+
2363
+ if self.training and self.gradient_checkpointing: # TODO
2364
+
2365
+ def create_custom_forward(module, return_dict=None):
2366
+ def custom_forward(*inputs):
2367
+ if return_dict is not None:
2368
+ return module(*inputs, return_dict=return_dict)
2369
+ else:
2370
+ return module(*inputs)
2371
+
2372
+ return custom_forward
2373
+
2374
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2375
+ hidden_states = torch.utils.checkpoint.checkpoint(
2376
+ create_custom_forward(resnet),
2377
+ hidden_states,
2378
+ temb,
2379
+ image_only_indicator,
2380
+ **ckpt_kwargs,
2381
+ )
2382
+ hidden_states = attn(
2383
+ hidden_states,
2384
+ encoder_hidden_states=encoder_hidden_states,
2385
+ image_only_indicator=image_only_indicator,
2386
+ return_dict=False,
2387
+ )[0]
2388
+ else:
2389
+ hidden_states = resnet(
2390
+ hidden_states,
2391
+ temb,
2392
+ image_only_indicator=image_only_indicator,
2393
+ )
2394
+ hidden_states = attn(
2395
+ hidden_states,
2396
+ encoder_hidden_states=encoder_hidden_states,
2397
+ image_only_indicator=image_only_indicator,
2398
+ return_dict=False,
2399
+ )[0]
2400
+
2401
+ if self.upsamplers is not None:
2402
+ for upsampler in self.upsamplers:
2403
+ hidden_states = upsampler(hidden_states)
2404
+
2405
+ return hidden_states
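Note (illustrative, not part of the uploaded files): the `# NOTE` branch in `CrossAttnDownBlockSpatioTemporal.forward` above injects `additional_residuals` (e.g. ControlNet features passed into the UNet) into the output of the last (resnet, attn) pair. When the hidden states are 5-D it first regroups the residuals from a flattened `(batch * frames, C, H, W)` layout into `(batch, C, frames, H, W)` with `einops.rearrange`. The sketch below only demonstrates that shape handling; the tensor sizes and variable names are made-up assumptions, not values taken from this checkpoint.

import torch
from einops import rearrange

# hypothetical sizes, kept small on purpose
batch, frames, channels, height, width = 2, 3, 8, 4, 4

# 5-D hidden states, laid out as (batch, channels, frames, height, width) as in the block above
hidden_states = torch.randn(batch, channels, frames, height, width)

# residuals typically arrive with batch and frames flattened into one axis
additional_residuals = torch.randn(batch * frames, channels, height, width)

if hidden_states.dim() == 5:
    # regroup the flattened (batch * frames) axis so the residuals align with the hidden states
    additional_residuals = rearrange(
        additional_residuals, "(b f) c h w -> b c f h w", b=hidden_states.shape[0]
    )

out = hidden_states + additional_residuals
print(out.shape)  # torch.Size([2, 8, 3, 4, 4])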
models_diffusers/unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,978 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+ from einops import rearrange
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.utils import BaseOutput, logging
11
+ # from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
12
+ from models_diffusers.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
13
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ # from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
16
+ from models_diffusers.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
17
+
18
+
19
+ import inspect
20
+ import itertools
21
+ import os
22
+ import re
23
+ from collections import OrderedDict
24
+ from functools import partial
25
+ from typing import Any, Callable, List, Optional, Tuple, Union
26
+
27
+ from diffusers import __version__
28
+ from diffusers.utils import (
29
+ CONFIG_NAME,
30
+ DIFFUSERS_CACHE,
31
+ FLAX_WEIGHTS_NAME,
32
+ HF_HUB_OFFLINE,
33
+ MIN_PEFT_VERSION,
34
+ SAFETENSORS_WEIGHTS_NAME,
35
+ WEIGHTS_NAME,
36
+ _add_variant,
37
+ _get_model_file,
38
+ check_peft_version,
39
+ deprecate,
40
+ is_accelerate_available,
41
+ is_torch_version,
42
+ logging,
43
+ )
44
+ from diffusers.utils.hub_utils import PushToHubMixin
45
+ from diffusers.models.modeling_utils import load_model_dict_into_meta, load_state_dict
46
+
47
+ if is_torch_version(">=", "1.9.0"):
48
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
49
+ else:
50
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
51
+
52
+ if is_accelerate_available():
53
+ import accelerate
54
+ from accelerate.utils import set_module_tensor_to_device
55
+ from accelerate.utils.versions import is_torch_version
56
+
57
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
58
+
59
+
60
+ @dataclass
61
+ class UNetSpatioTemporalConditionOutput(BaseOutput):
62
+ """
63
+ The output of [`UNetSpatioTemporalConditionModel`].
64
+
65
+ Args:
66
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
67
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
68
+ """
69
+
70
+ sample: torch.FloatTensor = None
71
+ intermediate_features: Optional[Tuple[torch.FloatTensor]] = None
72
+
73
+
74
+ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
75
+ r"""
76
+ A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep and returns a sample
77
+ shaped output.
78
+
79
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
80
+ for all models (such as downloading or saving).
81
+
82
+ Parameters:
83
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
84
+ Height and width of input/output sample.
85
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
86
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
87
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
88
+ The tuple of downsample blocks to use.
89
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
90
+ The tuple of upsample blocks to use.
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ addition_time_embed_dim: (`int`, defaults to 256):
94
+ Dimension to encode the additional time ids.
95
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
96
+ The dimension of the projection of encoded `added_time_ids`.
97
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
98
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
99
+ The dimension of the cross attention features.
100
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
101
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
102
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
103
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
104
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
105
+ The number of attention heads.
106
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
107
+ """
108
+
109
+ _supports_gradient_checkpointing = True
110
+
111
+ @register_to_config
112
+ def __init__(
113
+ self,
114
+ sample_size: Optional[int] = None,
115
+ in_channels: int = 8,
116
+ out_channels: int = 4,
117
+ down_block_types: Tuple[str] = (
118
+ "CrossAttnDownBlockSpatioTemporal",
119
+ "CrossAttnDownBlockSpatioTemporal",
120
+ "CrossAttnDownBlockSpatioTemporal",
121
+ "DownBlockSpatioTemporal",
122
+ ),
123
+ up_block_types: Tuple[str] = (
124
+ "UpBlockSpatioTemporal",
125
+ "CrossAttnUpBlockSpatioTemporal",
126
+ "CrossAttnUpBlockSpatioTemporal",
127
+ "CrossAttnUpBlockSpatioTemporal",
128
+ ),
129
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
130
+ addition_time_embed_dim: int = 256,
131
+ projection_class_embeddings_input_dim: int = 768,
132
+ layers_per_block: Union[int, Tuple[int]] = 2,
133
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
134
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
135
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
136
+ num_frames: int = 25,
137
+ ):
138
+ super().__init__()
139
+
140
+ self.sample_size = sample_size
141
+
142
+ # Check inputs
143
+ if len(down_block_types) != len(up_block_types):
144
+ raise ValueError(
145
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
146
+ )
147
+
148
+ if len(block_out_channels) != len(down_block_types):
149
+ raise ValueError(
150
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
151
+ )
152
+
153
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
154
+ raise ValueError(
155
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
156
+ )
157
+
158
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
159
+ raise ValueError(
160
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
161
+ )
162
+
163
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
164
+ raise ValueError(
165
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
166
+ )
167
+
168
+ self.mask_token = nn.Parameter(torch.randn(1, 1, 4, 1, 1))
169
+
170
+ # input
171
+ self.conv_in = nn.Conv2d(
172
+ in_channels,
173
+ block_out_channels[0],
174
+ kernel_size=3,
175
+ padding=1,
176
+ )
177
+
178
+ # time
179
+ time_embed_dim = block_out_channels[0] * 4
180
+
181
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
182
+ timestep_input_dim = block_out_channels[0]
183
+
184
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
185
+
186
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
187
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
188
+
189
+ self.down_blocks = nn.ModuleList([])
190
+ self.up_blocks = nn.ModuleList([])
191
+
192
+ if isinstance(num_attention_heads, int):
193
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
194
+
195
+ if isinstance(cross_attention_dim, int):
196
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
197
+
198
+ if isinstance(layers_per_block, int):
199
+ layers_per_block = [layers_per_block] * len(down_block_types)
200
+
201
+ if isinstance(transformer_layers_per_block, int):
202
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
203
+
204
+ blocks_time_embed_dim = time_embed_dim
205
+
206
+ # down
207
+ output_channel = block_out_channels[0]
208
+ for i, down_block_type in enumerate(down_block_types):
209
+ input_channel = output_channel
210
+ output_channel = block_out_channels[i]
211
+ is_final_block = i == len(block_out_channels) - 1
212
+
213
+ down_block = get_down_block(
214
+ down_block_type,
215
+ num_layers=layers_per_block[i],
216
+ transformer_layers_per_block=transformer_layers_per_block[i],
217
+ in_channels=input_channel,
218
+ out_channels=output_channel,
219
+ temb_channels=blocks_time_embed_dim,
220
+ add_downsample=not is_final_block,
221
+ resnet_eps=1e-5,
222
+ cross_attention_dim=cross_attention_dim[i],
223
+ num_attention_heads=num_attention_heads[i],
224
+ resnet_act_fn="silu",
225
+ )
226
+ self.down_blocks.append(down_block)
227
+
228
+ # mid
229
+ self.mid_block = UNetMidBlockSpatioTemporal(
230
+ block_out_channels[-1],
231
+ temb_channels=blocks_time_embed_dim,
232
+ transformer_layers_per_block=transformer_layers_per_block[-1],
233
+ cross_attention_dim=cross_attention_dim[-1],
234
+ num_attention_heads=num_attention_heads[-1],
235
+ )
236
+
237
+ # count how many layers upsample the images
238
+ self.num_upsamplers = 0
239
+
240
+ # up
241
+ reversed_block_out_channels = list(reversed(block_out_channels))
242
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
243
+ reversed_layers_per_block = list(reversed(layers_per_block))
244
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
245
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
246
+
247
+ output_channel = reversed_block_out_channels[0]
248
+ for i, up_block_type in enumerate(up_block_types):
249
+ is_final_block = i == len(block_out_channels) - 1
250
+
251
+ prev_output_channel = output_channel
252
+ output_channel = reversed_block_out_channels[i]
253
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
254
+
255
+ # add upsample block for all BUT final layer
256
+ if not is_final_block:
257
+ add_upsample = True
258
+ self.num_upsamplers += 1
259
+ else:
260
+ add_upsample = False
261
+
262
+ up_block = get_up_block(
263
+ up_block_type,
264
+ num_layers=reversed_layers_per_block[i] + 1,
265
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
266
+ in_channels=input_channel,
267
+ out_channels=output_channel,
268
+ prev_output_channel=prev_output_channel,
269
+ temb_channels=blocks_time_embed_dim,
270
+ add_upsample=add_upsample,
271
+ resnet_eps=1e-5,
272
+ resolution_idx=i,
273
+ cross_attention_dim=reversed_cross_attention_dim[i],
274
+ num_attention_heads=reversed_num_attention_heads[i],
275
+ resnet_act_fn="silu",
276
+ )
277
+ self.up_blocks.append(up_block)
278
+ prev_output_channel = output_channel
279
+
280
+ # out
281
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
282
+ self.conv_act = nn.SiLU()
283
+
284
+ self.conv_out = nn.Conv2d(
285
+ block_out_channels[0],
286
+ out_channels,
287
+ kernel_size=3,
288
+ padding=1,
289
+ )
290
+
291
+ @property
292
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
293
+ r"""
294
+ Returns:
295
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
296
+ indexed by its weight name.
297
+ """
298
+ # set recursively
299
+ processors = {}
300
+
301
+ def fn_recursive_add_processors(
302
+ name: str,
303
+ module: torch.nn.Module,
304
+ processors: Dict[str, AttentionProcessor],
305
+ ):
306
+ if hasattr(module, "get_processor"):
307
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
308
+
309
+ for sub_name, child in module.named_children():
310
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
311
+
312
+ return processors
313
+
314
+ for name, module in self.named_children():
315
+ fn_recursive_add_processors(name, module, processors)
316
+
317
+ return processors
318
+
319
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
320
+ r"""
321
+ Sets the attention processor to use to compute attention.
322
+
323
+ Parameters:
324
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
325
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
326
+ for **all** `Attention` layers.
327
+
328
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
329
+ processor. This is strongly recommended when setting trainable attention processors.
330
+
331
+ """
332
+ count = len(self.attn_processors.keys())
333
+
334
+ if isinstance(processor, dict) and len(processor) != count:
335
+ raise ValueError(
336
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
337
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
338
+ )
339
+
340
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
341
+ if hasattr(module, "set_processor"):
342
+ if not isinstance(processor, dict):
343
+ module.set_processor(processor)
344
+ else:
345
+ module.set_processor(processor.pop(f"{name}.processor"))
346
+
347
+ for sub_name, child in module.named_children():
348
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
349
+
350
+ for name, module in self.named_children():
351
+ fn_recursive_attn_processor(name, module, processor)
352
+
353
+ def set_default_attn_processor(self):
354
+ """
355
+ Disables custom attention processors and sets the default attention implementation.
356
+ """
357
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
358
+ processor = AttnProcessor()
359
+ else:
360
+ raise ValueError(
361
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
362
+ )
363
+
364
+ self.set_attn_processor(processor)
365
+
366
+ def _set_gradient_checkpointing(self, module, value=False):
367
+ if hasattr(module, "gradient_checkpointing"):
368
+ module.gradient_checkpointing = value
369
+
370
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
371
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
372
+ """
373
+ Enables [feed forward
374
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
375
+
376
+ Parameters:
377
+ chunk_size (`int`, *optional*):
378
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
379
+ over each tensor of dim=`dim`.
380
+ dim (`int`, *optional*, defaults to `0`):
381
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
382
+ or dim=1 (sequence length).
383
+ """
384
+ if dim not in [0, 1]:
385
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
386
+
387
+ # By default chunk size is 1
388
+ chunk_size = chunk_size or 1
389
+
390
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
391
+ if hasattr(module, "set_chunk_feed_forward"):
392
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
393
+
394
+ for child in module.children():
395
+ fn_recursive_feed_forward(child, chunk_size, dim)
396
+
397
+ for module in self.children():
398
+ fn_recursive_feed_forward(module, chunk_size, dim)
399
+
400
+ def forward(
401
+ self,
402
+ sample: torch.FloatTensor,
403
+ timestep: Union[torch.Tensor, float, int],
404
+ encoder_hidden_states: torch.Tensor,
405
+ added_time_ids: torch.Tensor,
406
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, # for t2i-adapter or controlnet
407
+ mid_block_additional_residual: Optional[torch.Tensor] = None, # for controlnet
408
+ return_dict: bool = True,
409
+ # return_intermediate_features: bool = False,
410
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
411
+ r"""
412
+ The [`UNetSpatioTemporalConditionModel`] forward method.
413
+
414
+ Args:
415
+ sample (`torch.FloatTensor`):
416
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
417
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
418
+ encoder_hidden_states (`torch.FloatTensor`):
419
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
420
+ added_time_ids: (`torch.FloatTensor`):
421
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
422
+ embeddings and added to the time embeddings.
423
+ return_dict (`bool`, *optional*, defaults to `True`):
424
+ Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
425
+ tuple.
426
+ Returns:
427
+ [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
428
+ If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
429
+ a `tuple` is returned where the first element is the sample tensor.
430
+ """
431
+ # 1. time
432
+ timesteps = timestep
433
+ if not torch.is_tensor(timesteps):
434
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
435
+ # This would be a good case for the `match` statement (Python 3.10+)
436
+ is_mps = sample.device.type == "mps"
437
+ if isinstance(timestep, float):
438
+ dtype = torch.float32 if is_mps else torch.float64
439
+ else:
440
+ dtype = torch.int32 if is_mps else torch.int64
441
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
442
+ elif len(timesteps.shape) == 0:
443
+ timesteps = timesteps[None].to(sample.device)
444
+
445
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
446
+ batch_size, num_frames = sample.shape[:2]
447
+ timesteps = timesteps.expand(batch_size)
448
+
449
+ t_emb = self.time_proj(timesteps)
450
+
451
+ # `Timesteps` does not contain any weights and will always return f32 tensors
452
+ # but time_embedding might actually be running in fp16. so we need to cast here.
453
+ # there might be better ways to encapsulate this.
454
+ t_emb = t_emb.to(dtype=sample.dtype)
455
+
456
+ emb = self.time_embedding(t_emb)
457
+
458
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
459
+ time_embeds = time_embeds.reshape((batch_size, -1))
460
+ time_embeds = time_embeds.to(emb.dtype)
461
+ aug_emb = self.add_embedding(time_embeds)
462
+ emb = emb + aug_emb
463
+
464
+ # Flatten the batch and frames dimensions
465
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
466
+ sample = sample.flatten(0, 1)
467
+ # Repeat the embeddings num_video_frames times
468
+ # emb: [batch, channels] -> [batch * frames, channels]
469
+ emb = emb.repeat_interleave(num_frames, dim=0)
470
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
471
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
472
+
473
+ # 2. pre-process
474
+ sample = self.conv_in(sample)
475
+
476
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
477
+
478
+ is_adapter = is_controlnet = False
479
+ if (down_block_additional_residuals is not None):
480
+ if (mid_block_additional_residual is not None):
481
+ is_controlnet = True
482
+ else:
483
+ is_adapter = True
484
+
485
+ down_block_res_samples = (sample,)
486
+ for block_idx, downsample_block in enumerate(self.down_blocks):
487
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
488
+ # print('has_cross_attention', type(downsample_block))
489
+ # models_diffusers.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal
490
+
491
+ additional_residuals = {}
492
+ if is_adapter and len(down_block_additional_residuals) > 0:
493
+ additional_residuals['additional_residuals'] = down_block_additional_residuals.pop(0)
494
+
495
+ sample, res_samples = downsample_block(
496
+ hidden_states=sample,
497
+ temb=emb,
498
+ encoder_hidden_states=encoder_hidden_states,
499
+ image_only_indicator=image_only_indicator,
500
+ **additional_residuals,
501
+ )
502
+ else:
503
+ # print('no_cross_attention', type(downsample_block))
504
+ # models_diffusers.unet_3d_blocks.DownBlockSpatioTemporal
505
+
506
+ sample, res_samples = downsample_block(
507
+ hidden_states=sample,
508
+ temb=emb,
509
+ image_only_indicator=image_only_indicator,
510
+ )
511
+
512
+ if is_adapter and len(down_block_additional_residuals) > 0:
513
+ additional_residuals = down_block_additional_residuals.pop(0)
514
+ if sample.dim() == 5:
515
+ additional_residuals = rearrange(additional_residuals, '(b f) c h w -> b c f h w', b=sample.shape[0])
516
+ sample = sample + additional_residuals
517
+
518
+ down_block_res_samples += res_samples
519
+
520
+ if is_controlnet:
521
+ new_down_block_res_samples = ()
522
+
523
+ for down_block_res_sample, down_block_additional_residual in zip(down_block_res_samples, down_block_additional_residuals):
524
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
525
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
526
+
527
+ down_block_res_samples = new_down_block_res_samples
528
+
529
+ # 4. mid
530
+ sample = self.mid_block(
531
+ hidden_states=sample,
532
+ temb=emb,
533
+ encoder_hidden_states=encoder_hidden_states,
534
+ image_only_indicator=image_only_indicator,
535
+ )
536
+
537
+ if is_controlnet:
538
+ sample = sample + mid_block_additional_residual
539
+
540
+ # if return_intermediate_features:
541
+ intermediate_features = []
542
+
543
+ # 5. up
544
+ for block_idx, upsample_block in enumerate(self.up_blocks):
545
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
546
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
547
+
548
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
549
+ sample = upsample_block(
550
+ hidden_states=sample,
551
+ temb=emb,
552
+ res_hidden_states_tuple=res_samples,
553
+ encoder_hidden_states=encoder_hidden_states,
554
+ image_only_indicator=image_only_indicator,
555
+ )
556
+ else:
557
+ sample = upsample_block(
558
+ hidden_states=sample,
559
+ temb=emb,
560
+ res_hidden_states_tuple=res_samples,
561
+ image_only_indicator=image_only_indicator,
562
+ )
563
+
564
+ # if return_intermediate_features:
565
+ intermediate_features.append(sample)
566
+
567
+ # 6. post-process
568
+ sample = self.conv_norm_out(sample)
569
+ sample = self.conv_act(sample)
570
+ sample = self.conv_out(sample)
571
+
572
+ # 7. Reshape back to original shape
573
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
574
+
575
+ if not return_dict:
576
+ return (sample, intermediate_features)
577
+
578
+ return UNetSpatioTemporalConditionOutput(
579
+ sample=sample,
580
+ intermediate_features=intermediate_features,
581
+ )
582
+
583
+ @classmethod
584
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], custom_resume=False, **kwargs):
585
+ r"""
586
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
587
+
588
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
589
+ train the model, set it back in training mode with `model.train()`.
590
+
591
+ Parameters:
592
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
593
+ Can be either:
594
+
595
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
596
+ the Hub.
597
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
598
+ with [`~ModelMixin.save_pretrained`].
599
+
600
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
601
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
602
+ is not used.
603
+ torch_dtype (`str` or `torch.dtype`, *optional*):
604
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
605
+ dtype is automatically derived from the model's weights.
606
+ force_download (`bool`, *optional*, defaults to `False`):
607
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
608
+ cached versions if they exist.
609
+ resume_download (`bool`, *optional*, defaults to `False`):
610
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
611
+ incompletely downloaded files are deleted.
612
+ proxies (`Dict[str, str]`, *optional*):
613
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
614
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
615
+ output_loading_info (`bool`, *optional*, defaults to `False`):
616
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
617
+ local_files_only(`bool`, *optional*, defaults to `False`):
618
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
619
+ won't be downloaded from the Hub.
620
+ use_auth_token (`str` or *bool*, *optional*):
621
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
622
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
623
+ revision (`str`, *optional*, defaults to `"main"`):
624
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
625
+ allowed by Git.
626
+ from_flax (`bool`, *optional*, defaults to `False`):
627
+ Load the model weights from a Flax checkpoint save file.
628
+ subfolder (`str`, *optional*, defaults to `""`):
629
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
630
+ mirror (`str`, *optional*):
631
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
632
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
633
+ information.
634
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
635
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
636
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
637
+ same device.
638
+
639
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
640
+ more information about each option see [designing a device
641
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
642
+ max_memory (`Dict`, *optional*):
643
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
644
+ each GPU and the available CPU RAM if unset.
645
+ offload_folder (`str` or `os.PathLike`, *optional*):
646
+ The path to offload weights if `device_map` contains the value `"disk"`.
647
+ offload_state_dict (`bool`, *optional*):
648
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
649
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
650
+ when there is some disk offload.
651
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
652
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
653
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
654
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
655
+ argument to `True` will raise an error.
656
+ variant (`str`, *optional*):
657
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
658
+ loading `from_flax`.
659
+ use_safetensors (`bool`, *optional*, defaults to `None`):
660
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
661
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
662
+ weights. If set to `False`, `safetensors` weights are not loaded.
663
+
664
+ <Tip>
665
+
666
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
667
+ `huggingface-cli login`. You can also activate the special
668
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
669
+ firewalled environment.
670
+
671
+ </Tip>
672
+
673
+ Example:
674
+
675
+ ```py
676
+ from diffusers import UNet2DConditionModel
677
+
678
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
679
+ ```
680
+
681
+ If you get the error message below, you need to finetune the weights for your downstream task:
682
+
683
+ ```bash
684
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
685
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
686
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
687
+ ```
688
+ """
689
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
690
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
691
+ force_download = kwargs.pop("force_download", False)
692
+ from_flax = kwargs.pop("from_flax", False)
693
+ resume_download = kwargs.pop("resume_download", False)
694
+ proxies = kwargs.pop("proxies", None)
695
+ output_loading_info = kwargs.pop("output_loading_info", False)
696
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
697
+ use_auth_token = kwargs.pop("use_auth_token", None)
698
+ revision = kwargs.pop("revision", None)
699
+ torch_dtype = kwargs.pop("torch_dtype", None)
700
+ subfolder = kwargs.pop("subfolder", None)
701
+ device_map = kwargs.pop("device_map", None)
702
+ max_memory = kwargs.pop("max_memory", None)
703
+ offload_folder = kwargs.pop("offload_folder", None)
704
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
705
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
706
+ variant = kwargs.pop("variant", None)
707
+ use_safetensors = kwargs.pop("use_safetensors", None)
708
+
709
+ allow_pickle = False
710
+ if use_safetensors is None:
711
+ use_safetensors = True
712
+ allow_pickle = True
713
+
714
+ if low_cpu_mem_usage and not is_accelerate_available():
715
+ low_cpu_mem_usage = False
716
+ logger.warning(
717
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
718
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
719
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
720
+ " install accelerate\n```\n."
721
+ )
722
+
723
+ if device_map is not None and not is_accelerate_available():
724
+ raise NotImplementedError(
725
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
726
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
727
+ )
728
+
729
+ # Check if we can handle device_map and dispatching the weights
730
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
731
+ raise NotImplementedError(
732
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
733
+ " `device_map=None`."
734
+ )
735
+
736
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
737
+ raise NotImplementedError(
738
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
739
+ " `low_cpu_mem_usage=False`."
740
+ )
741
+
742
+ if low_cpu_mem_usage is False and device_map is not None:
743
+ raise ValueError(
744
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
745
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
746
+ )
747
+
748
+ # Load config if we don't provide a configuration
749
+ config_path = pretrained_model_name_or_path
750
+
751
+ user_agent = {
752
+ "diffusers": __version__,
753
+ "file_type": "model",
754
+ "framework": "pytorch",
755
+ }
756
+
757
+ # load config
758
+ config, unused_kwargs, commit_hash = cls.load_config(
759
+ config_path,
760
+ cache_dir=cache_dir,
761
+ return_unused_kwargs=True,
762
+ return_commit_hash=True,
763
+ force_download=force_download,
764
+ resume_download=resume_download,
765
+ proxies=proxies,
766
+ local_files_only=local_files_only,
767
+ use_auth_token=use_auth_token,
768
+ revision=revision,
769
+ subfolder=subfolder,
770
+ device_map=device_map,
771
+ max_memory=max_memory,
772
+ offload_folder=offload_folder,
773
+ offload_state_dict=offload_state_dict,
774
+ user_agent=user_agent,
775
+ **kwargs,
776
+ )
777
+
778
+ if not custom_resume:
779
+ # NOTE: update in_channels, for additional mask concatenation
780
+ config['in_channels'] = config['in_channels'] + 1
781
+
782
+ # load model
783
+ model_file = None
784
+ if from_flax:
785
+ model_file = _get_model_file(
786
+ pretrained_model_name_or_path,
787
+ weights_name=FLAX_WEIGHTS_NAME,
788
+ cache_dir=cache_dir,
789
+ force_download=force_download,
790
+ resume_download=resume_download,
791
+ proxies=proxies,
792
+ local_files_only=local_files_only,
793
+ use_auth_token=use_auth_token,
794
+ revision=revision,
795
+ subfolder=subfolder,
796
+ user_agent=user_agent,
797
+ commit_hash=commit_hash,
798
+ )
799
+ model = cls.from_config(config, **unused_kwargs)
800
+
801
+ # Convert the weights
802
+ from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
803
+
804
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
805
+ else:
806
+ if use_safetensors:
807
+ try:
808
+ model_file = _get_model_file(
809
+ pretrained_model_name_or_path,
810
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
811
+ cache_dir=cache_dir,
812
+ force_download=force_download,
813
+ resume_download=resume_download,
814
+ proxies=proxies,
815
+ local_files_only=local_files_only,
816
+ use_auth_token=use_auth_token,
817
+ revision=revision,
818
+ subfolder=subfolder,
819
+ user_agent=user_agent,
820
+ commit_hash=commit_hash,
821
+ )
822
+ except IOError as e:
823
+ if not allow_pickle:
824
+ raise e
825
+ pass
826
+ if model_file is None:
827
+ model_file = _get_model_file(
828
+ pretrained_model_name_or_path,
829
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
830
+ cache_dir=cache_dir,
831
+ force_download=force_download,
832
+ resume_download=resume_download,
833
+ proxies=proxies,
834
+ local_files_only=local_files_only,
835
+ use_auth_token=use_auth_token,
836
+ revision=revision,
837
+ subfolder=subfolder,
838
+ user_agent=user_agent,
839
+ commit_hash=commit_hash,
840
+ )
841
+
842
+ if low_cpu_mem_usage:
843
+ # Instantiate model with empty weights
844
+ with accelerate.init_empty_weights():
845
+ model = cls.from_config(config, **unused_kwargs)
846
+
847
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
848
+ if device_map is None:
849
+ param_device = "cpu"
850
+ state_dict = load_state_dict(model_file, variant=variant)
851
+
852
+ if not custom_resume:
853
+ # NOTE update conv_in_weight
854
+ conv_in_weight = state_dict['conv_in.weight']
855
+ assert conv_in_weight.shape == (320, 8, 3, 3)
856
+ conv_in_weight_new = torch.randn(320, 9, 3, 3).to(conv_in_weight.device).to(conv_in_weight.dtype)
857
+ conv_in_weight_new[:, :8, :, :] = conv_in_weight
858
+ state_dict['conv_in.weight'] = conv_in_weight_new
859
+
860
+ # NOTE add mask_token
861
+ mask_token = torch.randn(1, 1, 4, 1, 1).to(conv_in_weight.device).to(conv_in_weight.dtype)
862
+ state_dict["mask_token"] = mask_token
863
+
864
+ model._convert_deprecated_attention_blocks(state_dict)
865
+ # move the params from meta device to cpu
866
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
867
+ if len(missing_keys) > 0:
868
+ raise ValueError(
869
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
870
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
871
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
872
+ " those weights or else make sure your checkpoint file is correct."
873
+ )
874
+
875
+ unexpected_keys = load_model_dict_into_meta(
876
+ model,
877
+ state_dict,
878
+ device=param_device,
879
+ dtype=torch_dtype,
880
+ model_name_or_path=pretrained_model_name_or_path,
881
+ )
882
+
883
+ if cls._keys_to_ignore_on_load_unexpected is not None:
884
+ for pat in cls._keys_to_ignore_on_load_unexpected:
885
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
886
+
887
+ if len(unexpected_keys) > 0:
888
+ logger.warn(
889
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
890
+ )
891
+
892
+ else: # else let accelerate handle loading and dispatching.
893
+ # Load weights and dispatch according to the device_map
894
+ # by default the device_map is None and the weights are loaded on the CPU
895
+ try:
896
+ accelerate.load_checkpoint_and_dispatch(
897
+ model,
898
+ model_file,
899
+ device_map,
900
+ max_memory=max_memory,
901
+ offload_folder=offload_folder,
902
+ offload_state_dict=offload_state_dict,
903
+ dtype=torch_dtype,
904
+ )
905
+ except AttributeError as e:
906
+ # When using accelerate loading, we do not have the ability to load the state
907
+ # dict and rename the weight names manually. Additionally, accelerate skips
908
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
909
+ # (which look like they should be private variables?), so we can't use the standard hooks
910
+ # to rename parameters on load. We need to mimic the original weight names so the correct
911
+ # attributes are available. After we have loaded the weights, we convert the deprecated
912
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
913
+ # the weights so we don't have to do this again.
914
+
915
+ if "'Attention' object has no attribute" in str(e):
916
+ logger.warn(
917
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
918
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
919
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
920
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
921
+ " please also re-upload it or open a PR on the original repository."
922
+ )
923
+ model._temp_convert_self_to_deprecated_attention_blocks()
924
+ accelerate.load_checkpoint_and_dispatch(
925
+ model,
926
+ model_file,
927
+ device_map,
928
+ max_memory=max_memory,
929
+ offload_folder=offload_folder,
930
+ offload_state_dict=offload_state_dict,
931
+ dtype=torch_dtype,
932
+ )
933
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
934
+ else:
935
+ raise e
936
+
937
+ loading_info = {
938
+ "missing_keys": [],
939
+ "unexpected_keys": [],
940
+ "mismatched_keys": [],
941
+ "error_msgs": [],
942
+ }
943
+ else:
944
+ model = cls.from_config(config, **unused_kwargs)
945
+
946
+ state_dict = load_state_dict(model_file, variant=variant)
947
+ model._convert_deprecated_attention_blocks(state_dict)
948
+
949
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
950
+ model,
951
+ state_dict,
952
+ model_file,
953
+ pretrained_model_name_or_path,
954
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
955
+ )
956
+
957
+ loading_info = {
958
+ "missing_keys": missing_keys,
959
+ "unexpected_keys": unexpected_keys,
960
+ "mismatched_keys": mismatched_keys,
961
+ "error_msgs": error_msgs,
962
+ }
963
+
964
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
965
+ raise ValueError(
966
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
967
+ )
968
+ elif torch_dtype is not None:
969
+ model = model.to(torch_dtype)
970
+
971
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
972
+
973
+ # Set model in evaluation mode to deactivate DropOut modules by default
974
+ model.eval()
975
+ if output_loading_info:
976
+ return model, loading_info
977
+
978
+ return model
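Note (not part of the diff above): a minimal usage sketch of the customized `from_pretrained`. With `custom_resume=False`, the override widens the UNet's `conv_in` from 8 to 9 input channels and adds a `mask_token` before loading, so a stock SVD checkpoint can be adapted in place; with `custom_resume=True`, it expects a checkpoint that already contains those expanded weights. The checkpoint id, subfolder, and dtype below are assumptions for illustration.

```py
# Hedged sketch; the checkpoint names and kwargs are assumptions, not part of this repo.
import torch
from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel

# custom_resume=False: start from a stock SVD UNet. The loader bumps
# config["in_channels"] from 8 to 9, pads conv_in.weight with one extra
# (randomly initialized) input channel, and adds a mask_token to the state dict.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",  # assumed base checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",
    custom_resume=False,
)
print(unet.conv_in.weight.shape)  # expected: torch.Size([320, 9, 3, 3])

# custom_resume=True: resume from weights that already contain the expanded
# conv_in and mask_token, so no state-dict surgery is applied.
# unet = UNetSpatioTemporalConditionModel.from_pretrained(
#     "checkpoints/framer_512x320", subfolder="unet", custom_resume=True)  # assumed local path
```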
models_diffusers/utils.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ import numpy as np
9
+ import cv2
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def gen_gaussian_heatmap(imgSize=200):
16
+ circle_img = np.zeros((imgSize, imgSize), np.float32)
17
+ circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
18
+
19
+ isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
20
+
21
+ # Gaussian map
22
+ for i in range(imgSize):
23
+ for j in range(imgSize):
24
+ isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
25
+ -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
26
+
27
+ isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
28
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
29
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
30
+
31
+ # isotropicGrayscaleImage = cv2.resize(isotropicGrayscaleImage, (40, 40))
32
+ return isotropicGrayscaleImage
33
+
34
+
35
+ def draw_heatmap(img, center_coordinate, heatmap_template, side, width, height):
36
+ x1 = max(center_coordinate[0] - side, 1)
37
+ x2 = min(center_coordinate[0] + side, width - 1)
38
+ y1 = max(center_coordinate[1] - side, 1)
39
+ y2 = min(center_coordinate[1] + side, height - 1)
40
+ x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
41
+
42
+ if (x2 - x1) < 1 or (y2 - y1) < 1:
43
+ print(center_coordinate, "x1, x2, y1, y2", x1, x2, y1, y2)
44
+ return img
45
+
46
+ need_map = cv2.resize(heatmap_template, (x2-x1, y2-y1))
47
+
48
+ img[y1:y2,x1:x2] = need_map
49
+
50
+ return img
51
+
52
+
53
+ def generate_gassian_heatmap(pred_tracks, pred_visibility=None, image_size=None, side=20):
54
+ width, height = image_size
55
+ num_frames, num_points = pred_tracks.shape[:2]
56
+
57
+ point_index_list = [point_idx for point_idx in range(num_points)]
58
+ heatmap_template = gen_gaussian_heatmap()
59
+
60
+
61
+ image_list = []
62
+ for frame_idx in range(num_frames):
63
+
64
+ img = np.zeros((height, width), np.float32)
65
+ for point_idx in point_index_list:
66
+ px, py = pred_tracks[frame_idx, point_idx]
67
+
68
+ if px < 0 or py < 0 or px >= width or py >= height:
69
+ if (frame_idx == 0) or (frame_idx == num_frames - 1):
70
+ print(frame_idx, point_idx, px, py)
71
+ continue
72
+
73
+ if pred_visibility is not None:
74
+ if (not pred_visibility[frame_idx, point_idx]):
75
+ continue
76
+
77
+ img = draw_heatmap(img, (px, py), heatmap_template, side, width, height)
78
+
79
+ img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_GRAY2RGB)
80
+ img = torch.from_numpy(img).permute(2, 0, 1).contiguous()
81
+ image_list.append(img)
82
+
83
+ video_gaussion_map = torch.stack(image_list, dim=0)
84
+
85
+ return video_gaussion_map
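Note (not part of the diff above): a small usage sketch of `generate_gassian_heatmap`. The frame count, canvas size, and the synthetic one-point trajectory below are made-up values for illustration; the function returns a uint8 tensor of shape `(num_frames, 3, height, width)`, the heatmap video that the pipeline later feeds to the ControlNet as its condition.

```py
# Hedged sketch with a synthetic single-point trajectory (all values assumed).
import torch
from models_diffusers.utils import generate_gassian_heatmap

num_frames, width, height = 16, 512, 320

# One point sweeping left-to-right along the horizontal mid-line.
xs = torch.linspace(50.0, 450.0, num_frames)
ys = torch.full((num_frames,), 160.0)
pred_tracks = torch.stack([xs, ys], dim=-1).unsqueeze(1)  # (frames, points=1, 2)

heatmaps = generate_gassian_heatmap(pred_tracks, image_size=(width, height), side=20)
print(heatmaps.shape, heatmaps.dtype)  # torch.Size([16, 3, 320, 512]) torch.uint8
```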
pipelines/pipeline_stable_video_diffusion_interp_control.py ADDED
@@ -0,0 +1,854 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ # from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
27
+ from diffusers.models import AutoencoderKLTemporalDecoder
28
+ from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
29
+ from diffusers.schedulers import EulerDiscreteScheduler
30
+ from diffusers.utils import BaseOutput, logging
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33
+
34
+ from models_diffusers.controlnet_svd import ControlNetSVDModel
35
+ # from cotracker.predictor import CoTrackerPredictor, sample_trajectories, generate_gassian_heatmap
36
+ from models_diffusers.utils import generate_gassian_heatmap
37
+
38
+ from einops import rearrange
39
+ from models_diffusers.sift_match import point_tracking, interpolate_trajectory
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ def _append_dims(x, target_dims):
46
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
47
+ dims_to_append = target_dims - x.ndim
48
+ if dims_to_append < 0:
49
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
50
+ return x[(...,) + (None,) * dims_to_append]
51
+
52
+
53
+ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
54
+ # Based on:
55
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
56
+
57
+ batch_size, channels, num_frames, height, width = video.shape
58
+ outputs = []
59
+ for batch_idx in range(batch_size):
60
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
61
+ batch_output = processor.postprocess(batch_vid, output_type)
62
+
63
+ outputs.append(batch_output)
64
+
65
+ return outputs
66
+
67
+
68
+ @dataclass
69
+ class StableVideoDiffusionInterpControlPipelineOutput(BaseOutput):
70
+ r"""
71
+ Output class for zero-shot text-to-video pipeline.
72
+
73
+ Args:
74
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
75
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
76
+ num_channels)`.
77
+ """
78
+
79
+ frames: Union[List[PIL.Image.Image], np.ndarray]
80
+
81
+
82
+ class StableVideoDiffusionInterpControlPipeline(DiffusionPipeline):
83
+ r"""
84
+ Pipeline to generate video from an input image using Stable Video Diffusion.
85
+
86
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
87
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
88
+
89
+ Args:
90
+ vae ([`AutoencoderKL`]):
91
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
92
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
93
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
94
+ unet ([`UNetSpatioTemporalConditionModel`]):
95
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
96
+ scheduler ([`EulerDiscreteScheduler`]):
97
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
98
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
99
+ A `CLIPImageProcessor` to extract features from generated images.
100
+ """
101
+
102
+ model_cpu_offload_seq = "image_encoder->unet->vae"
103
+ _callback_tensor_inputs = ["latents"]
104
+
105
+ def __init__(
106
+ self,
107
+ vae: AutoencoderKLTemporalDecoder,
108
+ image_encoder: CLIPVisionModelWithProjection,
109
+ unet: UNetSpatioTemporalConditionModel,
110
+ scheduler: EulerDiscreteScheduler,
111
+ feature_extractor: CLIPImageProcessor,
112
+ controlnet: Optional[ControlNetSVDModel] = None,
113
+ pose_encoder: Optional[torch.nn.Module] = None,
114
+ ):
115
+ super().__init__()
116
+
117
+ self.register_modules(
118
+ vae=vae,
119
+ image_encoder=image_encoder,
120
+ unet=unet,
121
+ scheduler=scheduler,
122
+ feature_extractor=feature_extractor,
123
+ controlnet=controlnet,
124
+ pose_encoder=pose_encoder,
125
+ )
126
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
127
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
128
+
129
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
130
+ dtype = next(self.image_encoder.parameters()).dtype
131
+
132
+ if not isinstance(image, torch.Tensor):
133
+ image = self.image_processor.pil_to_numpy(image)
134
+ image = self.image_processor.numpy_to_pt(image)
135
+
136
+ # We normalize the image before resizing to match with the original implementation.
137
+ # Then we unnormalize it after resizing.
138
+ image = image * 2.0 - 1.0
139
+ image = _resize_with_antialiasing(image, (224, 224))
140
+ image = (image + 1.0) / 2.0
141
+
142
+ # Normalize the image with for CLIP input
143
+ image = self.feature_extractor(
144
+ images=image,
145
+ do_normalize=True,
146
+ do_center_crop=False,
147
+ do_resize=False,
148
+ do_rescale=False,
149
+ return_tensors="pt",
150
+ ).pixel_values
151
+
152
+ image = image.to(device=device, dtype=dtype)
153
+ image_embeddings = self.image_encoder(image).image_embeds
154
+ image_embeddings = image_embeddings.unsqueeze(1)
155
+
156
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
157
+ bs_embed, seq_len, _ = image_embeddings.shape
158
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
159
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
160
+
161
+ if do_classifier_free_guidance:
162
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
163
+
164
+ # For classifier free guidance, we need to do two forward passes.
165
+ # Here we concatenate the unconditional and text embeddings into a single batch
166
+ # to avoid doing two forward passes
167
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
168
+
169
+ return image_embeddings
170
+
171
+ def _encode_vae_image(
172
+ self,
173
+ image: torch.Tensor,
174
+ device,
175
+ num_videos_per_prompt,
176
+ do_classifier_free_guidance,
177
+ ):
178
+ image = image.to(device=device)
179
+ image_latents = self.vae.encode(image).latent_dist.mode()
180
+
181
+ if do_classifier_free_guidance:
182
+ negative_image_latents = torch.zeros_like(image_latents)
183
+
184
+ # For classifier free guidance, we need to do two forward passes.
185
+ # Here we concatenate the unconditional and text embeddings into a single batch
186
+ # to avoid doing two forward passes
187
+ image_latents = torch.cat([negative_image_latents, image_latents])
188
+
189
+ # duplicate image_latents for each generation per prompt, using mps friendly method
190
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
191
+
192
+ return image_latents
193
+
194
+ def _get_add_time_ids(
195
+ self,
196
+ fps,
197
+ motion_bucket_id,
198
+ noise_aug_strength,
199
+ dtype,
200
+ batch_size,
201
+ num_videos_per_prompt,
202
+ do_classifier_free_guidance,
203
+ ):
204
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
205
+
206
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
207
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
208
+
209
+ if expected_add_embed_dim != passed_add_embed_dim:
210
+ raise ValueError(
211
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
212
+ )
213
+
214
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
215
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
216
+
217
+ if do_classifier_free_guidance:
218
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
219
+
220
+ return add_time_ids
221
+
222
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
223
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
224
+ latents = latents.flatten(0, 1)
225
+
226
+ latents = 1 / self.vae.config.scaling_factor * latents
227
+
228
+ accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
229
+
230
+ # decode decode_chunk_size frames at a time to avoid OOM
231
+ frames = []
232
+ for i in range(0, latents.shape[0], decode_chunk_size):
233
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
234
+ decode_kwargs = {}
235
+ if accepts_num_frames:
236
+ # we only pass num_frames_in if it's expected
237
+ decode_kwargs["num_frames"] = num_frames_in
238
+
239
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
240
+ frames.append(frame)
241
+ frames = torch.cat(frames, dim=0)
242
+
243
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
244
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
245
+
246
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
247
+ frames = frames.float()
248
+ return frames
249
+
250
+ def check_inputs(self, image, height, width):
251
+ if (
252
+ not isinstance(image, torch.Tensor)
253
+ and not isinstance(image, PIL.Image.Image)
254
+ and not isinstance(image, list)
255
+ ):
256
+ raise ValueError(
257
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
258
+ f" {type(image)}"
259
+ )
260
+
261
+ if height % 8 != 0 or width % 8 != 0:
262
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
263
+
264
+ def prepare_latents(
265
+ self,
266
+ batch_size,
267
+ num_frames,
268
+ num_channels_latents,
269
+ height,
270
+ width,
271
+ dtype,
272
+ device,
273
+ generator,
274
+ latents=None,
275
+ ):
276
+ shape = (
277
+ batch_size,
278
+ num_frames,
279
+ num_channels_latents // 2,
280
+ height // self.vae_scale_factor,
281
+ width // self.vae_scale_factor,
282
+ )
283
+ if isinstance(generator, list) and len(generator) != batch_size:
284
+ raise ValueError(
285
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
286
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
287
+ )
288
+
289
+ if latents is None:
290
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
291
+ else:
292
+ latents = latents.to(device)
293
+
294
+ # scale the initial noise by the standard deviation required by the scheduler
295
+ latents = latents * self.scheduler.init_noise_sigma
296
+ return latents
297
+
298
+ @property
299
+ def guidance_scale(self):
300
+ return self._guidance_scale
301
+
302
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
303
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
304
+ # corresponds to doing no classifier free guidance.
305
+ @property
306
+ def do_classifier_free_guidance(self):
307
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
308
+
309
+ @property
310
+ def num_timesteps(self):
311
+ return self._num_timesteps
312
+
313
+ @torch.no_grad()
314
+ def __call__(
315
+ self,
316
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
317
+ image_end: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
318
+ # for points
319
+ with_control: bool = True,
320
+ point_tracks: Optional[torch.FloatTensor] = None,
321
+ point_embedding: Optional[torch.FloatTensor] = None,
322
+ with_id_feature: bool = False, # NOTE: whether to use the id feature for controlnet
323
+ controlnet_cond_scale: float = 1.0,
324
+ controlnet_step_range: List[float] = [0, 1],
325
+ # others
326
+ height: int = 576,
327
+ width: int = 1024,
328
+ num_frames: Optional[int] = None,
329
+ num_inference_steps: int = 25,
330
+ min_guidance_scale: float = 1.0,
331
+ max_guidance_scale: float = 3.0,
332
+ middle_max_guidance: bool = False,
333
+ fps: int = 6,
334
+ motion_bucket_id: int = 127,
335
+ noise_aug_strength: int = 0.02,
336
+ decode_chunk_size: Optional[int] = None,
337
+ num_videos_per_prompt: Optional[int] = 1,
338
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
339
+ latents: Optional[torch.FloatTensor] = None,
340
+ output_type: Optional[str] = "pil",
341
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
342
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
343
+ return_dict: bool = True,
344
+ # update track
345
+ sift_track_update: bool = False,
346
+ sift_track_update_with_time: bool = True,
347
+ sift_track_feat_idx: List[int] = [2, ],
348
+ sift_track_dist: int = 5,
349
+ sift_track_double_check_thr: float = 2,
350
+ anchor_points_flag: Optional[torch.FloatTensor] = None,
351
+ ):
352
+ r"""
353
+ The call function to the pipeline for generation.
354
+
355
+ Args:
356
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
357
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
358
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
359
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
360
+ The height in pixels of the generated image.
361
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
362
+ The width in pixels of the generated image.
363
+ num_frames (`int`, *optional*):
364
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
365
+ num_inference_steps (`int`, *optional*, defaults to 25):
366
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
367
+ expense of slower inference. This parameter is modulated by `strength`.
368
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
369
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
370
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
371
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
372
+ fps (`int`, *optional*, defaults to 7):
373
+ Frames per second. The rate at which the generated images shall be exported to a video after generation.
374
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
375
+ motion_bucket_id (`int`, *optional*, defaults to 127):
376
+ The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
377
+ noise_aug_strength (`int`, *optional*, defaults to 0.02):
378
+ The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
379
+ decode_chunk_size (`int`, *optional*):
380
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
381
+ between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
382
+ for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
383
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
384
+ The number of images to generate per prompt.
385
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
386
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
387
+ generation deterministic.
388
+ latents (`torch.FloatTensor`, *optional*):
389
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
390
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
391
+ tensor is generated by sampling using the supplied random `generator`.
392
+ output_type (`str`, *optional*, defaults to `"pil"`):
393
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
394
+ callback_on_step_end (`Callable`, *optional*):
395
+ A function that calls at the end of each denoising steps during the inference. The function is called
396
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
397
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
398
+ `callback_on_step_end_tensor_inputs`.
399
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
400
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
401
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
402
+ `._callback_tensor_inputs` attribute of your pipeline class.
403
+ return_dict (`bool`, *optional*, defaults to `True`):
404
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
405
+ plain tuple.
406
+
407
+ Returns:
408
+ [`~pipelines.stable_diffusion.StableVideoDiffusionInterpControlPipelineOutput`] or `tuple`:
409
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionInterpControlPipelineOutput`] is returned,
410
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
411
+
412
+ Examples:
413
+
414
+ ```py
415
+ from diffusers import StableVideoDiffusionPipeline
416
+ from diffusers.utils import load_image, export_to_video
417
+
418
+ pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
419
+ pipe.to("cuda")
420
+
421
+ image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
422
+ image = image.resize((1024, 576))
423
+
424
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
425
+ export_to_video(frames, "generated.mp4", fps=7)
426
+ ```
427
+ """
428
+ # 0. Default height and width to unet
429
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
430
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
431
+
432
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
433
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
434
+
435
+ # 1. Check inputs. Raise error if not correct
436
+ self.check_inputs(image, height, width)
437
+ self.check_inputs(image_end, height, width)
438
+
439
+ # 2. Define call parameters
440
+ if isinstance(image, PIL.Image.Image):
441
+ batch_size = 1
442
+ elif isinstance(image, list):
443
+ batch_size = len(image)
444
+ else:
445
+ batch_size = image.shape[0]
446
+ device = self._execution_device
447
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
448
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
449
+ # corresponds to doing no classifier free guidance.
450
+ do_classifier_free_guidance = max_guidance_scale > 1.0
451
+
452
+ # 3. Encode input image
453
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
454
+ image_end_embeddings = self._encode_image(image_end, device, num_videos_per_prompt, do_classifier_free_guidance)
455
+
456
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
457
+ # is why it is reduced here.
458
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
459
+ fps = fps - 1
460
+
461
+ # 4. Encode input image using VAE
462
+ image = self.image_processor.preprocess(image, height=height, width=width)
463
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
464
+ image = image + noise_aug_strength * noise
465
+ # also for image_end
466
+ image_end = self.image_processor.preprocess(image_end, height=height, width=width)
467
+ noise = randn_tensor(image_end.shape, generator=generator, device=image_end.device, dtype=image_end.dtype)
468
+ image_end = image_end + noise_aug_strength * noise
469
+
470
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
471
+ if needs_upcasting:
472
+ self.vae.to(dtype=torch.float32)
473
+
474
+ if with_control:
475
+ # create controlnet input
476
+ video_gaussion_map = generate_gassian_heatmap(point_tracks, image_size=(width, height))
477
+ controlnet_image = video_gaussion_map.unsqueeze(0) # (1, f, c, h, w)
478
+ controlnet_image = controlnet_image.to(device, dtype=image_embeddings.dtype)
479
+ controlnet_image = torch.cat([controlnet_image] * 2, dim=0)
480
+
481
+ point_embedding = point_embedding.to(device).to(image_embeddings.dtype) if point_embedding is not None else None
482
+ point_tracks = point_tracks.to(device).to(image_embeddings.dtype) # (f, p, 2)
483
+
484
+ assert point_tracks.shape[0] == num_frames, f"point_tracks.shape[0] != num_frames, {point_tracks.shape[0]} != {num_frames}"
485
+ # if point_tracks.shape[0] != num_frames:
486
+ # # interpolate the point_tracks to the number of frames
487
+ # point_tracks = rearrange(point_tracks[None], 'b f p c -> b p f c')
488
+ # point_tracks = torch.nn.functional.interpolate(point_tracks, size=(num_frames, point_tracks.shape[-1]), mode='bilinear', align_corners=False)[0]
489
+ # point_tracks = rearrange(point_tracks, 'p f c -> f p c')
490
+
491
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
492
+ image_latents = image_latents.to(image_embeddings.dtype)
493
+ # also for image_end
494
+ image_end_latents = self._encode_vae_image(image_end, device, num_videos_per_prompt, do_classifier_free_guidance)
495
+ image_end_latents = image_end_latents.to(image_end_embeddings.dtype)
496
+
497
+ # cast back to fp16 if needed
498
+ if needs_upcasting:
499
+ self.vae.to(dtype=torch.float16)
500
+
501
+ # Repeat the image latents for each frame so we can concatenate them with the noise
502
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
503
+ # image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
504
+
505
+ # 5. Get Added Time IDs
506
+ added_time_ids = self._get_add_time_ids(
507
+ fps,
508
+ motion_bucket_id,
509
+ noise_aug_strength,
510
+ image_embeddings.dtype,
511
+ batch_size,
512
+ num_videos_per_prompt,
513
+ do_classifier_free_guidance,
514
+ )
515
+ added_time_ids = added_time_ids.to(device)
516
+
517
+ # 4. Prepare timesteps
518
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
519
+ timesteps = self.scheduler.timesteps
520
+
521
+ # 5. Prepare latent variables
522
+ num_channels_latents = self.unet.config.in_channels
523
+ latents = self.prepare_latents(
524
+ batch_size * num_videos_per_prompt,
525
+ num_frames,
526
+ num_channels_latents,
527
+ height,
528
+ width,
529
+ image_embeddings.dtype,
530
+ device,
531
+ generator,
532
+ latents,
533
+ )
534
+
535
+ # Concatenate the `conditional_latents` with the `noisy_latents`.
536
+ # conditional_latents = conditional_latents.unsqueeze(1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
537
+ image_latents = image_latents.unsqueeze(1) # (1, 1, 4, h, w)
538
+ bsz, num_frames, _, latent_h, latent_w = latents.shape
539
+ bsz_cfg = bsz * 2
540
+ mask_token = self.unet.mask_token
541
+ conditional_latents_mask = mask_token.repeat(bsz_cfg, num_frames-2, 1, latent_h, latent_w)
542
+ image_end_latents = image_end_latents.unsqueeze(1)
543
+ image_latents = torch.cat([image_latents, conditional_latents_mask, image_end_latents], dim=1)
544
+
545
+ # Concatenate additional mask channel
546
+ mask_channel = torch.ones_like(image_latents[:, :, 0:1, :, :])
547
+ mask_channel[:, 0:1, :, :, :] = 0
548
+ mask_channel[:, -1:, :, :, :] = 0
549
+ image_latents = torch.cat([image_latents, mask_channel], dim=2)
550
+
551
+ # concate the conditions
552
+ image_embeddings = torch.cat([image_embeddings, image_end_embeddings], dim=1)
553
+
554
+ # 7. Prepare guidance scale
555
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) # (1, 14)
556
+ if middle_max_guidance:
557
+ # big in middle, small at the beginning and end
558
+ guidance_scale = torch.cat([guidance_scale, guidance_scale.flip(1)], dim=1)
559
+ # interpolate the guidance scale, from [1, 2*frames] to [1, frames]
560
+ guidance_scale = torch.nn.functional.interpolate(guidance_scale.unsqueeze(0), size=num_frames, mode='linear', align_corners=False)[0]
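+ # With middle_max_guidance the linear ramp is mirrored and resampled back to num_frames, so the per-frame
+ # guidance peaks near max_guidance_scale in the middle and stays near min_guidance_scale at both endpoints;
+ # otherwise it simply increases linearly from the first frame to the last.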
561
+
562
+
563
+ guidance_scale = guidance_scale.to(device, latents.dtype)
564
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
565
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
566
+
567
+ self._guidance_scale = guidance_scale
568
+
569
+ # 9. Denoising loop
570
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
571
+ self._num_timesteps = len(timesteps)
572
+
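+ # Bookkeeping for trajectory refinement (sift_track_update): anchor_point_dict stores, per frame and per point,
+ # the matched location once it is known. Only the first/last-frame positions from point_tracks are filled in up
+ # front; intermediate frames are populated during denoising via feature matching below.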
573
+ if with_control and sift_track_update:
574
+ num_tracks = point_tracks.shape[1]
575
+ anchor_point_dict = {}
576
+ for frame_idx in range(num_frames):
577
+ anchor_point_dict[frame_idx] = {}
578
+ for point_idx in range(num_tracks):
579
+ # add the start and end point
580
+ if frame_idx in [0, num_frames - 1]:
581
+ anchor_point_dict[frame_idx][point_idx] = point_tracks[frame_idx][point_idx]
582
+ else:
583
+ anchor_point_dict[frame_idx][point_idx] = None
584
+
585
+ with_control_global = with_control
586
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
587
+ for i, t in enumerate(timesteps):
588
+
589
+ # NOTE: apply ControlNet only within the configured fraction of the denoising schedule
590
+ if with_control_global:
591
+ if controlnet_step_range[0] <= i / num_inference_steps < controlnet_step_range[1]:
592
+ with_control = True
593
+ else:
594
+ with_control = False
595
+ # print(f"step={i / num_inference_steps}, with_control={with_control}")
596
+
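+ # From the second denoising step on, re-fit each trajectory: interpolate through all anchor frames found so far
+ # (optionally using their time stamps) to obtain an updated full-length track for every control point.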
597
+ if with_control and sift_track_update and i > 0:
598
+ # update the point tracks
599
+ track_list = []
600
+ for point_idx in range(num_tracks):
601
+ # get the anchor points
602
+ current_track = []
603
+ current_time_to_interp = []
604
+ for frame_idx in range(num_frames):
605
+ if anchor_points_flag[frame_idx][point_idx] == 1:
606
+ current_track.append(anchor_point_dict[frame_idx][point_idx].cpu())
607
+ if sift_track_update_with_time:
608
+ current_time_to_interp.append(frame_idx / (num_frames - 1))
609
+
610
+ current_track = torch.stack(current_track, dim=0).unsqueeze(1) # (f, 1, 2)
611
+ # interpolate the anchor points to obtain trajectory
612
+ current_time_to_interp = np.array(current_time_to_interp) if sift_track_update_with_time else None
613
+ current_track = interpolate_trajectory(current_track, num_frames=num_frames, t=current_time_to_interp)
614
+ track_list.append(current_track)
615
+ point_tracks = torch.concat(track_list, dim=1).to(device).to(image_embeddings.dtype) # (f, p, 2)
616
+
617
+ # expand the latents if we are doing classifier free guidance
618
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
619
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
620
+
621
+ # Concatenate image_latents over the channels dimension
622
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
623
+
624
+ down_block_res_samples = mid_block_res_sample = None
625
+ if with_control:
626
+ if i == 0:
627
+ print(f"controlnet_cond_scale: {controlnet_cond_scale}")
628
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
629
+ latent_model_input,
630
+ t,
631
+ encoder_hidden_states=image_embeddings,
632
+ controlnet_cond=controlnet_image,
633
+ added_time_ids=added_time_ids,
634
+ conditioning_scale=controlnet_cond_scale,
635
+ point_embedding=point_embedding if with_id_feature else None, # NOTE
636
+ point_tracks=point_tracks,
637
+ guess_mode=False,
638
+ return_dict=False,
639
+ )
640
+ else:
641
+ if i == 0:
642
+ print("Controlnet is not used")
643
+
644
+ kwargs = {}
645
+
646
+ outputs = self.unet(
647
+ latent_model_input,
648
+ t,
649
+ encoder_hidden_states=image_embeddings,
650
+ down_block_additional_residuals=down_block_res_samples,
651
+ mid_block_additional_residual=mid_block_res_sample,
652
+ added_time_ids=added_time_ids,
653
+ return_dict=False,
654
+ **kwargs,
655
+ )
656
+
657
+ noise_pred, intermediate_features = outputs
658
+
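+ # For point tracking, selected intermediate UNet feature maps (sift_track_feat_idx) are upsampled to the output
+ # resolution and concatenated channel-wise into a single correspondence feature map.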
659
+ if with_control and sift_track_update:
660
+ # shape: [b*f, c, h, w], b=2 for cfg
661
+ matching_features = []
662
+ for feat_idx in sift_track_feat_idx:
663
+ feat = intermediate_features[feat_idx]
664
+ feat = F.interpolate(feat, (height, width), mode='bilinear')
665
+ matching_features.append(feat)
666
+
667
+ matching_features = torch.cat(matching_features, dim=1) # [b*f, c, h, w]
668
+
669
+ # shape: [b*f, c, h, w]
670
+ # self.guidance_scale: [1, f, 1, 1, 1]
671
+ # matching_features: [b*f, c, h, w] -> rearranged to [b, f, c, h, w] below, with b=2 for cfg
672
+ assert do_classifier_free_guidance
673
+ matching_features = rearrange(matching_features, '(b f) c h w -> b f c h w', b=2)
674
+
675
+ # # strategy 1: discard the unconditional branch feature maps
676
+ # matching_features = matching_features[1].unsqueeze(dim=0) # (b, f, c, h, w), b=1
677
+ # # strategy 2: concat pos and neg branch feature maps for motion-sup and point tracking
678
+ # matching_features = torch.cat([matching_features[0], matching_features[1]], dim=1).unsqueeze(dim=0) # (b, f, 2c, h, w), b=1
679
+ # # strategy 3: concat pos and neg branch feature maps with guidance_scale consideration
680
+ # coef = self.guidance_scale / (2 * self.guidance_scale - 1.0)
681
+ # coef = coef.squeeze(dim=0)
682
+ # matching_features = torch.cat(
683
+ # [(1 - coef) * matching_features[0], coef * matching_features[1]], dim=1,
684
+ # ).unsqueeze(dim=0) # (b, f, 2c, h, w), b=1
685
+ # strategy 4: same as cfg
686
+ matching_features = matching_features[0] + self.guidance_scale.squeeze(0) * (matching_features[1] - matching_features[0])
687
+ matching_features = matching_features.unsqueeze(dim=0) # (b, f, c, h, w), b=1
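+ # strategy 4 combines the unconditional and conditional feature maps with the same extrapolation used for
+ # classifier-free guidance on the noise prediction below: uncond + scale * (cond - uncond)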
688
+
689
+ # perform point matching in intermediate frames
690
+ feature_start = matching_features[:, 0]
691
+ feature_end = matching_features[:, -1]
692
+ handle_points_start = point_tracks[0] # (f, p, 2) -> (p, 2)
693
+ handle_points_end = point_tracks[-1] # (f, p, 2) -> (p, 2)
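+ # Bi-directional check: each intermediate frame is matched forward from the first frame and backward from the
+ # last frame; a point becomes a new anchor only if the two estimates agree within sift_track_double_check_thr.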
694
+ for frame_idx in range(1, num_frames - 1):
695
+ feature_frame = matching_features[:, frame_idx]
696
+ handle_points = point_tracks[frame_idx] # (f, p, 2) -> (p, 2)
697
+ # forward matching
698
+ handle_points_forward = point_tracking(feature_start, feature_frame, handle_points, handle_points_start, sift_track_dist)
699
+ # backward matching
700
+ handle_points_backward = point_tracking(feature_end, feature_frame, handle_points, handle_points_end, sift_track_dist)
701
+
702
+ # bi-directional check
703
+ for point_idx, (point_forward, point_backward) in enumerate(zip(handle_points_forward, handle_points_backward)):
704
+ if torch.norm(point_forward - point_backward) < sift_track_double_check_thr:
705
+ # update the point
706
+ # point_tracks[frame_idx][point_idx] = (point_forward + point_backward) / 2
707
+ anchor_point_dict[frame_idx][point_idx] = (point_forward + point_backward) / 2
708
+ anchor_points_flag[frame_idx][point_idx] = 1
709
+
710
+ # perform guidance
711
+ if do_classifier_free_guidance:
712
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
713
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
714
+
715
+ # compute the previous noisy sample x_t -> x_t-1
716
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
717
+
718
+ if callback_on_step_end is not None:
719
+ callback_kwargs = {}
720
+ for k in callback_on_step_end_tensor_inputs:
721
+ callback_kwargs[k] = locals()[k]
722
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
723
+
724
+ latents = callback_outputs.pop("latents", latents)
725
+
726
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
727
+ progress_bar.update()
728
+
729
+ if not output_type == "latent":
730
+ # cast back to fp16 if needed
731
+ if needs_upcasting:
732
+ self.vae.to(dtype=torch.float16)
733
+ # self.vae.to(dtype=torch.float32)
734
+ # latents = latents.to(torch.float32)
735
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
736
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
737
+ else:
738
+ frames = latents
739
+
740
+ self.maybe_free_model_hooks()
741
+
742
+ if not return_dict:
743
+ return frames
744
+
745
+ return StableVideoDiffusionInterpControlPipelineOutput(frames=frames)
746
+
747
+
748
+ # resizing utils
749
+ # TODO: clean up later
750
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
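+ # Anti-aliased resize: blur with a Gaussian whose sigma matches the downscaling factor before interpolating.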
751
+ h, w = input.shape[-2:]
752
+ factors = (h / size[0], w / size[1])
753
+
754
+ # First, we have to determine sigma
755
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
756
+ sigmas = (
757
+ max((factors[0] - 1.0) / 2.0, 0.001),
758
+ max((factors[1] - 1.0) / 2.0, 0.001),
759
+ )
760
+
761
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
762
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
763
+ # But they do it in two passes, which gives better results. Let's try 2 sigmas for now
764
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
765
+
766
+ # Make sure it is odd
767
+ if (ks[0] % 2) == 0:
768
+ ks = ks[0] + 1, ks[1]
769
+
770
+ if (ks[1] % 2) == 0:
771
+ ks = ks[0], ks[1] + 1
772
+
773
+ input = _gaussian_blur2d(input, ks, sigmas)
774
+
775
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
776
+ return output
777
+
778
+
779
+ def _compute_padding(kernel_size):
780
+ """Compute padding tuple."""
781
+ # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
782
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
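+ # e.g. _compute_padding([3, 3]) -> [1, 1, 1, 1], i.e. (left, right, top, bottom) for torch.nn.functional.pad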
783
+ if len(kernel_size) < 2:
784
+ raise AssertionError(kernel_size)
785
+ computed = [k - 1 for k in kernel_size]
786
+
787
+ # for even kernels we need to do asymmetric padding :(
788
+ out_padding = 2 * len(kernel_size) * [0]
789
+
790
+ for i in range(len(kernel_size)):
791
+ computed_tmp = computed[-(i + 1)]
792
+
793
+ pad_front = computed_tmp // 2
794
+ pad_rear = computed_tmp - pad_front
795
+
796
+ out_padding[2 * i + 0] = pad_front
797
+ out_padding[2 * i + 1] = pad_rear
798
+
799
+ return out_padding
800
+
801
+
802
+ def _filter2d(input, kernel):
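+ # Depthwise 2-D filtering: the kernel is broadcast over channels and applied per channel via a grouped conv2d,
+ # with reflect padding so the output keeps the input's spatial size.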
803
+ # prepare kernel
804
+ b, c, h, w = input.shape
805
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
806
+
807
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
808
+
809
+ height, width = tmp_kernel.shape[-2:]
810
+
811
+ padding_shape: list[int] = _compute_padding([height, width])
812
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
813
+
814
+ # kernel and input tensor reshape to align element-wise or batch-wise params
815
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
816
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
817
+
818
+ # convolve the tensor with the kernel.
819
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
820
+
821
+ out = output.view(b, c, h, w)
822
+ return out
823
+
824
+
825
+ def _gaussian(window_size: int, sigma):
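+ # Returns a normalized 1-D Gaussian window of length window_size for each sigma (each row sums to 1).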
826
+ if isinstance(sigma, float):
827
+ sigma = torch.tensor([[sigma]])
828
+
829
+ batch_size = sigma.shape[0]
830
+
831
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
832
+
833
+ if window_size % 2 == 0:
834
+ x = x + 0.5
835
+
836
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
837
+
838
+ return gauss / gauss.sum(-1, keepdim=True)
839
+
840
+
841
+ def _gaussian_blur2d(input, kernel_size, sigma):
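+ # Separable Gaussian blur: filter along x with a 1-D kernel, then along y, which is equivalent to (and cheaper
+ # than) convolving with the full 2-D Gaussian kernel.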
842
+ if isinstance(sigma, tuple):
843
+ sigma = torch.tensor([sigma], dtype=input.dtype)
844
+ else:
845
+ sigma = sigma.to(dtype=input.dtype)
846
+
847
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
848
+ bs = sigma.shape[0]
849
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
850
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
851
+ out_x = _filter2d(input, kernel_x[..., None, :])
852
+ out = _filter2d(out_x, kernel_y[..., None])
853
+
854
+ return out
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.1+cu116
2
+ torchvision==0.14.1+cu116
3
+ diffusers==0.24.0
4
+ transformers==4.27.0
5
+ xformers==0.0.16
6
+ imageio==2.27.0
7
+ decord==0.6.0
8
+ einops
9
+ triton==2.1.0
10
+ opencv-python
11
+ av
12
+ accelerate==0.27.2