shisheng7 commited on
Commit
bd6c4af
·
1 Parent(s): f7e8357

update home

Browse files
configs/inference/inference.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ train_bs: 4
3
+ val_bs: 1
4
+ train_width: 512
5
+ train_height: 512
6
+ fps: 25
7
+ sample_rate: 16000
8
+ n_motion_frames: 2
9
+ n_sample_frames: 16
10
+ audio_margin: 2
11
+ train_meta_paths:
12
+ - "./data/inference.json"
13
+
14
+ wav2vec_config:
15
+ audio_type: "vocals" # audio vocals
16
+ model_scale: "base" # base large
17
+ features: "all" # last avg all
18
+ model_path: ./pretrained_models/chinese-wav2vec2-base
19
+ audio_separator:
20
+ model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
21
+ face_expand_ratio: 1.2
22
+
23
+ solver:
24
+ gradient_accumulation_steps: 1
25
+ mixed_precision: "no"
26
+ enable_xformers_memory_efficient_attention: True
27
+ gradient_checkpointing: True
28
+ max_train_steps: 30000
29
+ max_grad_norm: 1.0
30
+ # lr
31
+ learning_rate: 1e-5
32
+ scale_lr: False
33
+ lr_warmup_steps: 1
34
+ lr_scheduler: "constant"
35
+
36
+ # optimizer
37
+ use_8bit_adam: True
38
+ adam_beta1: 0.9
39
+ adam_beta2: 0.999
40
+ adam_weight_decay: 1.0e-2
41
+ adam_epsilon: 1.0e-8
42
+
43
+ val:
44
+ validation_steps: 1000
45
+
46
+ noise_scheduler_kwargs:
47
+ num_train_timesteps: 1000
48
+ beta_start: 0.00085
49
+ beta_end: 0.012
50
+ beta_schedule: "linear"
51
+ steps_offset: 1
52
+ clip_sample: false
53
+
54
+ unet_additional_kwargs:
55
+ use_inflated_groupnorm: true
56
+ unet_use_cross_frame_attention: false
57
+ unet_use_temporal_attention: false
58
+ use_motion_module: true
59
+ use_audio_module: true
60
+ motion_module_resolutions:
61
+ - 1
62
+ - 2
63
+ - 4
64
+ - 8
65
+ motion_module_mid_block: true
66
+ motion_module_decoder_only: false
67
+ motion_module_type: Vanilla
68
+ motion_module_kwargs:
69
+ num_attention_heads: 8
70
+ num_transformer_block: 1
71
+ attention_block_types:
72
+ - Temporal_Self
73
+ - Temporal_Self
74
+ temporal_position_encoding: true
75
+ temporal_position_encoding_max_len: 32
76
+ temporal_attention_dim_div: 1
77
+ audio_attention_dim: 768
78
+ stack_enable_blocks_name:
79
+ - "up"
80
+ - "down"
81
+ - "mid"
82
+ stack_enable_blocks_depth: [0,1,2,3]
83
+
84
+ trainable_para:
85
+ - audio_modules
86
+ - motion_modules
87
+
88
+ base_model_path: "./pretrained_models/stable-diffusion-v1-5"
89
+ vae_model_path: "./pretrained_models/sd-vae-ft-mse"
90
+ face_analysis_model_path: "./pretrained_models/face_analysis"
91
+ mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
92
+
93
+ weight_dtype: "fp16" # [fp16, fp32]
94
+ uncond_img_ratio: 0.05
95
+ uncond_audio_ratio: 0.05
96
+ uncond_ia_ratio: 0.05
97
+ start_ratio: 0.05
98
+ noise_offset: 0.05
99
+ snr_gamma: 5.0
100
+ enable_zero_snr: True
101
+ stage1_ckpt_dir: "./exp_output/stage1/"
102
+
103
+ single_inference_times: 10
104
+ inference_steps: 40
105
+ cfg_scale: 3.5
106
+
107
+ seed: 42
108
+ resume_from_checkpoint: "latest"
109
+ checkpointing_steps: 500
110
+
111
+ exp_name: "joyhallo"
112
+ output_dir: "./opts"
113
+
114
+ audio_ckpt_dir: "./pretrained_models/joyhallo/net.pth"
115
+
116
+ ref_img_path: None
117
+
118
+ audio_path: None
configs/unet/unet.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ use_audio_module: true
7
+ motion_module_resolutions:
8
+ - 1
9
+ - 2
10
+ - 4
11
+ - 8
12
+ motion_module_mid_block: true
13
+ motion_module_decoder_only: false
14
+ motion_module_type: Vanilla
15
+ motion_module_kwargs:
16
+ num_attention_heads: 8
17
+ num_transformer_block: 1
18
+ attention_block_types:
19
+ - Temporal_Self
20
+ - Temporal_Self
21
+ temporal_position_encoding: true
22
+ temporal_position_encoding_max_len: 32
23
+ temporal_attention_dim_div: 1
24
+ audio_attention_dim: 768
25
+ stack_enable_blocks_name:
26
+ - "up"
27
+ - "down"
28
+ - "mid"
29
+ stack_enable_blocks_depth: [0,1,2,3]
30
+
31
+ enable_zero_snr: true
32
+
33
+ noise_scheduler_kwargs:
34
+ beta_start: 0.00085
35
+ beta_end: 0.012
36
+ beta_schedule: "linear"
37
+ clip_sample: false
38
+ steps_offset: 1
39
+ ### Zero-SNR params
40
+ prediction_type: "v_prediction"
41
+ rescale_betas_zero_snr: True
42
+ timestep_spacing: "trailing"
43
+
44
+ sampler: DDIM
data/inference.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "video_path": "",
4
+ "mask_path": "",
5
+ "sep_mask_border": "",
6
+ "sep_mask_face": "",
7
+ "sep_mask_lip": "",
8
+ "face_emb_path": "",
9
+ "audio_path": "",
10
+ "vocals_emb_base_all": ""
11
+ }
12
+ ]
joyhallo/__init__.py ADDED
File without changes
joyhallo/animate/__init__.py ADDED
File without changes
joyhallo/animate/face_animate.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module is responsible for animating faces in videos using a combination of deep learning techniques.
3
+ It provides a pipeline for generating face animations by processing video frames and extracting face features.
4
+ The module utilizes various schedulers and utilities for efficient face animation and supports different types
5
+ of latents for more control over the animation process.
6
+
7
+ Functions and Classes:
8
+ - FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks.
9
+ - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.).
10
+ - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements.
11
+ - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers.
12
+ - decode_latents: Decodes the latents into video frames, ready for animation.
13
+
14
+ Usage:
15
+ - Import the necessary packages and classes.
16
+ - Create a FaceAnimatePipeline instance with the required components.
17
+ - Prepare the latents for the animation process.
18
+ - Use the pipeline to generate the animated video.
19
+
20
+ Note:
21
+ - This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning.
22
+ - The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases.
23
+ """
24
+
25
+ import inspect
26
+ from dataclasses import dataclass
27
+ from typing import Callable, List, Optional, Union
28
+
29
+ import numpy as np
30
+ import torch
31
+ from diffusers import (DDIMScheduler, DiffusionPipeline,
32
+ DPMSolverMultistepScheduler,
33
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
34
+ LMSDiscreteScheduler, PNDMScheduler)
35
+ from diffusers.image_processor import VaeImageProcessor
36
+ from diffusers.utils import BaseOutput
37
+ from diffusers.utils.torch_utils import randn_tensor
38
+ from einops import rearrange, repeat
39
+ from tqdm import tqdm
40
+
41
+ from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
42
+
43
+
44
+ @dataclass
45
+ class FaceAnimatePipelineOutput(BaseOutput):
46
+ """
47
+ FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline.
48
+
49
+ Attributes:
50
+ videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames.
51
+
52
+ Methods:
53
+ __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames.
54
+ """
55
+ videos: Union[torch.Tensor, np.ndarray]
56
+
57
+ class FaceAnimatePipeline(DiffusionPipeline):
58
+ """
59
+ FaceAnimatePipeline is a custom DiffusionPipeline for animating faces.
60
+
61
+ It inherits from the DiffusionPipeline class and is used to animate faces by
62
+ utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet,
63
+ a face locator, and an image processor. The pipeline is responsible for generating
64
+ and animating face latents, and decoding the latents to produce the final video output.
65
+
66
+ Attributes:
67
+ vae (VaeImageProcessor): Variational autoencoder for processing images.
68
+ reference_unet (nn.Module): Reference UNet for mutual self-attention.
69
+ denoising_unet (nn.Module): Denoising UNet for image denoising.
70
+ face_locator (nn.Module): Face locator for detecting and cropping faces.
71
+ image_proj (nn.Module): Image projector for processing images.
72
+ scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
73
+ EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
74
+ DPMSolverMultistepScheduler]): Diffusion scheduler for
75
+ controlling the noise level.
76
+
77
+ Methods:
78
+ __init__(self, vae, reference_unet, denoising_unet, face_locator,
79
+ image_proj, scheduler): Initializes the FaceAnimatePipeline
80
+ with the given components and scheduler.
81
+ prepare_latents(self, batch_size, num_channels_latents, width, height,
82
+ video_length, dtype, device, generator=None, latents=None):
83
+ Prepares the initial latents for video generation.
84
+ prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword
85
+ arguments for the scheduler step.
86
+ decode_latents(self, latents): Decodes the latents to produce the final
87
+ video output.
88
+ """
89
+ def __init__(
90
+ self,
91
+ vae,
92
+ reference_unet,
93
+ denoising_unet,
94
+ face_locator,
95
+ image_proj,
96
+ scheduler: Union[
97
+ DDIMScheduler,
98
+ PNDMScheduler,
99
+ LMSDiscreteScheduler,
100
+ EulerDiscreteScheduler,
101
+ EulerAncestralDiscreteScheduler,
102
+ DPMSolverMultistepScheduler,
103
+ ],
104
+ ) -> None:
105
+ super().__init__()
106
+
107
+ self.register_modules(
108
+ vae=vae,
109
+ reference_unet=reference_unet,
110
+ denoising_unet=denoising_unet,
111
+ face_locator=face_locator,
112
+ scheduler=scheduler,
113
+ image_proj=image_proj,
114
+ )
115
+
116
+ self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
117
+
118
+ self.ref_image_processor = VaeImageProcessor(
119
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True,
120
+ )
121
+
122
+ @property
123
+ def _execution_device(self):
124
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
125
+ return self.device
126
+ for module in self.unet.modules():
127
+ if (
128
+ hasattr(module, "_hf_hook")
129
+ and hasattr(module._hf_hook, "execution_device")
130
+ and module._hf_hook.execution_device is not None
131
+ ):
132
+ return torch.device(module._hf_hook.execution_device)
133
+ return self.device
134
+
135
+ def prepare_latents(
136
+ self,
137
+ batch_size: int, # Number of videos to generate in parallel
138
+ num_channels_latents: int, # Number of channels in the latents
139
+ width: int, # Width of the video frame
140
+ height: int, # Height of the video frame
141
+ video_length: int, # Length of the video in frames
142
+ dtype: torch.dtype, # Data type of the latents
143
+ device: torch.device, # Device to store the latents on
144
+ generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
145
+ latents: Optional[torch.Tensor] = None # Pre-generated latents (optional)
146
+ ):
147
+ """
148
+ Prepares the initial latents for video generation.
149
+
150
+ Args:
151
+ batch_size (int): Number of videos to generate in parallel.
152
+ num_channels_latents (int): Number of channels in the latents.
153
+ width (int): Width of the video frame.
154
+ height (int): Height of the video frame.
155
+ video_length (int): Length of the video in frames.
156
+ dtype (torch.dtype): Data type of the latents.
157
+ device (torch.device): Device to store the latents on.
158
+ generator (Optional[torch.Generator]): Random number generator for reproducibility.
159
+ latents (Optional[torch.Tensor]): Pre-generated latents (optional).
160
+
161
+ Returns:
162
+ latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height)
163
+ containing the initial latents for video generation.
164
+ """
165
+ shape = (
166
+ batch_size,
167
+ num_channels_latents,
168
+ video_length,
169
+ height // self.vae_scale_factor,
170
+ width // self.vae_scale_factor,
171
+ )
172
+ if isinstance(generator, list) and len(generator) != batch_size:
173
+ raise ValueError(
174
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
175
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
176
+ )
177
+
178
+ if latents is None:
179
+ latents = randn_tensor(
180
+ shape, generator=generator, device=device, dtype=dtype
181
+ )
182
+ else:
183
+ latents = latents.to(device)
184
+
185
+ # scale the initial noise by the standard deviation required by the scheduler
186
+ latents = latents * self.scheduler.init_noise_sigma
187
+ return latents
188
+
189
+ def prepare_extra_step_kwargs(self, generator, eta):
190
+ """
191
+ Prepares extra keyword arguments for the scheduler step.
192
+
193
+ Args:
194
+ generator (Optional[torch.Generator]): Random number generator for reproducibility.
195
+ eta (float): The eta (η) parameter used with the DDIMScheduler.
196
+ It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
197
+
198
+ Returns:
199
+ dict: A dictionary containing the extra keyword arguments for the scheduler step.
200
+ """
201
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
202
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
203
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
204
+ # and should be between [0, 1]
205
+
206
+ accepts_eta = "eta" in set(
207
+ inspect.signature(self.scheduler.step).parameters.keys()
208
+ )
209
+ extra_step_kwargs = {}
210
+ if accepts_eta:
211
+ extra_step_kwargs["eta"] = eta
212
+
213
+ # check if the scheduler accepts generator
214
+ accepts_generator = "generator" in set(
215
+ inspect.signature(self.scheduler.step).parameters.keys()
216
+ )
217
+ if accepts_generator:
218
+ extra_step_kwargs["generator"] = generator
219
+ return extra_step_kwargs
220
+
221
+ def decode_latents(self, latents):
222
+ """
223
+ Decode the latents to produce a video.
224
+
225
+ Parameters:
226
+ latents (torch.Tensor): The latents to be decoded.
227
+
228
+ Returns:
229
+ video (torch.Tensor): The decoded video.
230
+ video_length (int): The length of the video in frames.
231
+ """
232
+ video_length = latents.shape[2]
233
+ latents = 1 / 0.18215 * latents
234
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
235
+ # video = self.vae.decode(latents).sample
236
+ video = []
237
+ for frame_idx in tqdm(range(latents.shape[0])):
238
+ video.append(self.vae.decode(
239
+ latents[frame_idx: frame_idx + 1]).sample)
240
+ video = torch.cat(video)
241
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
242
+ video = (video / 2 + 0.5).clamp(0, 1)
243
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
244
+ video = video.cpu().float().numpy()
245
+ return video
246
+
247
+
248
+ @torch.no_grad()
249
+ def __call__(
250
+ self,
251
+ ref_image,
252
+ face_emb,
253
+ audio_tensor,
254
+ face_mask,
255
+ pixel_values_full_mask,
256
+ pixel_values_face_mask,
257
+ pixel_values_lip_mask,
258
+ width,
259
+ height,
260
+ video_length,
261
+ num_inference_steps,
262
+ guidance_scale,
263
+ num_images_per_prompt=1,
264
+ eta: float = 0.0,
265
+ motion_scale: Optional[List[torch.Tensor]] = None,
266
+ generator: Optional[Union[torch.Generator,
267
+ List[torch.Generator]]] = None,
268
+ output_type: Optional[str] = "tensor",
269
+ return_dict: bool = True,
270
+ callback: Optional[Callable[[
271
+ int, int, torch.FloatTensor], None]] = None,
272
+ callback_steps: Optional[int] = 1,
273
+ **kwargs,
274
+ ):
275
+ # Default height and width to unet
276
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
277
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
278
+
279
+ device = self._execution_device
280
+
281
+ do_classifier_free_guidance = guidance_scale > 1.0
282
+
283
+ # Prepare timesteps
284
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
285
+ timesteps = self.scheduler.timesteps
286
+
287
+ batch_size = 1
288
+
289
+ # prepare clip image embeddings
290
+ clip_image_embeds = face_emb
291
+ clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
292
+
293
+ encoder_hidden_states = self.image_proj(clip_image_embeds)
294
+ uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
295
+
296
+ if do_classifier_free_guidance:
297
+ encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
298
+
299
+ reference_control_writer = ReferenceAttentionControl(
300
+ self.reference_unet,
301
+ do_classifier_free_guidance=do_classifier_free_guidance,
302
+ mode="write",
303
+ batch_size=batch_size,
304
+ fusion_blocks="full",
305
+ )
306
+ reference_control_reader = ReferenceAttentionControl(
307
+ self.denoising_unet,
308
+ do_classifier_free_guidance=do_classifier_free_guidance,
309
+ mode="read",
310
+ batch_size=batch_size,
311
+ fusion_blocks="full",
312
+ )
313
+
314
+ num_channels_latents = self.denoising_unet.in_channels
315
+
316
+ latents = self.prepare_latents(
317
+ batch_size * num_images_per_prompt,
318
+ num_channels_latents,
319
+ width,
320
+ height,
321
+ video_length,
322
+ clip_image_embeds.dtype,
323
+ device,
324
+ generator,
325
+ )
326
+
327
+ # Prepare extra step kwargs.
328
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
329
+
330
+ # Prepare ref image latents
331
+ ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
332
+ ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height)
333
+ ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
334
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
335
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
336
+
337
+
338
+ face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
339
+ face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
340
+ face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W)
341
+ face_mask = self.face_locator(face_mask)
342
+ face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask
343
+
344
+ pixel_values_full_mask = (
345
+ [torch.cat([mask] * 2) for mask in pixel_values_full_mask]
346
+ if do_classifier_free_guidance
347
+ else pixel_values_full_mask
348
+ )
349
+ pixel_values_face_mask = (
350
+ [torch.cat([mask] * 2) for mask in pixel_values_face_mask]
351
+ if do_classifier_free_guidance
352
+ else pixel_values_face_mask
353
+ )
354
+ pixel_values_lip_mask = (
355
+ [torch.cat([mask] * 2) for mask in pixel_values_lip_mask]
356
+ if do_classifier_free_guidance
357
+ else pixel_values_lip_mask
358
+ )
359
+ pixel_values_face_mask_ = []
360
+ for mask in pixel_values_face_mask:
361
+ pixel_values_face_mask_.append(
362
+ mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
363
+ pixel_values_face_mask = pixel_values_face_mask_
364
+ pixel_values_lip_mask_ = []
365
+ for mask in pixel_values_lip_mask:
366
+ pixel_values_lip_mask_.append(
367
+ mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
368
+ pixel_values_lip_mask = pixel_values_lip_mask_
369
+ pixel_values_full_mask_ = []
370
+ for mask in pixel_values_full_mask:
371
+ pixel_values_full_mask_.append(
372
+ mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
373
+ pixel_values_full_mask = pixel_values_full_mask_
374
+
375
+
376
+ uncond_audio_tensor = torch.zeros_like(audio_tensor)
377
+ audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
378
+ audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)
379
+
380
+ # denoising loop
381
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
382
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
383
+ for i, t in enumerate(timesteps):
384
+ # Forward reference image
385
+ if i == 0:
386
+ self.reference_unet(
387
+ ref_image_latents.repeat(
388
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
389
+ ),
390
+ torch.zeros_like(t),
391
+ encoder_hidden_states=encoder_hidden_states,
392
+ return_dict=False,
393
+ )
394
+ reference_control_reader.update(reference_control_writer)
395
+
396
+ # expand the latents if we are doing classifier free guidance
397
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
398
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
399
+
400
+ noise_pred = self.denoising_unet(
401
+ latent_model_input,
402
+ t,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ mask_cond_fea=face_mask,
405
+ full_mask=pixel_values_full_mask,
406
+ face_mask=pixel_values_face_mask,
407
+ lip_mask=pixel_values_lip_mask,
408
+ audio_embedding=audio_tensor,
409
+ motion_scale=motion_scale,
410
+ return_dict=False,
411
+ )[0]
412
+
413
+ # perform guidance
414
+ if do_classifier_free_guidance:
415
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
416
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
417
+
418
+ # compute the previous noisy sample x_t -> x_t-1
419
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
420
+
421
+ # call the callback, if provided
422
+ if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
423
+ progress_bar.update()
424
+ if callback is not None and i % callback_steps == 0:
425
+ step_idx = i // getattr(self.scheduler, "order", 1)
426
+ callback(step_idx, t, latents)
427
+
428
+ reference_control_reader.clear()
429
+ reference_control_writer.clear()
430
+
431
+ # Post-processing
432
+ images = self.decode_latents(latents) # (b, c, f, h, w)
433
+
434
+ # Convert to tensor
435
+ if output_type == "tensor":
436
+ images = torch.from_numpy(images)
437
+
438
+ if not return_dict:
439
+ return images
440
+
441
+ return FaceAnimatePipelineOutput(videos=images)
joyhallo/animate/face_animate_static.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques.
3
+ It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments.
4
+ The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance.
5
+
6
+ Functions and Classes:
7
+ - StaticPipelineOutput: A class that represents the output of the animation pipeline, c
8
+ ontaining properties and methods related to the generated images.
9
+ - prepare_latents: A function that prepares the initial noise for the animation process,
10
+ scaling it according to the scheduler's requirements.
11
+ - prepare_condition: A function that processes the user-provided conditions
12
+ (e.g., facial expressions) and prepares them for use in the animation pipeline.
13
+ - decode_latents: A function that decodes the latent representations of the face animations into
14
+ their corresponding image formats.
15
+ - prepare_extra_step_kwargs: A function that prepares additional parameters for each step of
16
+ the animation process, such as the generator and eta values.
17
+
18
+ Dependencies:
19
+ - numpy: A library for numerical computing.
20
+ - torch: A machine learning library based on PyTorch.
21
+ - diffusers: A library for image-to-image diffusion models.
22
+ - transformers: A library for pre-trained transformer models.
23
+
24
+ Usage:
25
+ - To create an instance of the animation pipeline, provide the necessary components such as
26
+ the VAE, reference UNET, denoising UNET, face locator, and image processor.
27
+ - Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as
28
+ required for the animation process.
29
+ - Generate the face animations by decoding the latents and processing the conditions.
30
+
31
+ Note:
32
+ - The module is designed to work with the diffusers library, which is based on
33
+ the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765).
34
+ - The face animations generated by this module should be used for entertainment purposes
35
+ only and should respect the rights and privacy of the individuals involved.
36
+ """
37
+ import inspect
38
+ from dataclasses import dataclass
39
+ from typing import Callable, List, Optional, Union
40
+
41
+ import numpy as np
42
+ import torch
43
+ from diffusers import DiffusionPipeline
44
+ from diffusers.image_processor import VaeImageProcessor
45
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
46
+ EulerAncestralDiscreteScheduler,
47
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
48
+ PNDMScheduler)
49
+ from diffusers.utils import BaseOutput, is_accelerate_available
50
+ from diffusers.utils.torch_utils import randn_tensor
51
+ from einops import rearrange
52
+ from tqdm import tqdm
53
+ from transformers import CLIPImageProcessor
54
+
55
+ from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
56
+
57
+ if is_accelerate_available():
58
+ from accelerate import cpu_offload
59
+ else:
60
+ raise ImportError("Please install accelerate via `pip install accelerate`")
61
+
62
+
63
+ @dataclass
64
+ class StaticPipelineOutput(BaseOutput):
65
+ """
66
+ StaticPipelineOutput is a class that represents the output of the static pipeline.
67
+ It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
68
+
69
+ Attributes:
70
+ images (Union[torch.Tensor, np.ndarray]): The generated images.
71
+ """
72
+ images: Union[torch.Tensor, np.ndarray]
73
+
74
+
75
+ class StaticPipeline(DiffusionPipeline):
76
+ """
77
+ StaticPipelineOutput is a class that represents the output of the static pipeline.
78
+ It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
79
+
80
+ Attributes:
81
+ images (Union[torch.Tensor, np.ndarray]): The generated images.
82
+ """
83
+ _optional_components = []
84
+
85
+ def __init__(
86
+ self,
87
+ vae,
88
+ reference_unet,
89
+ denoising_unet,
90
+ face_locator,
91
+ imageproj,
92
+ scheduler: Union[
93
+ DDIMScheduler,
94
+ PNDMScheduler,
95
+ LMSDiscreteScheduler,
96
+ EulerDiscreteScheduler,
97
+ EulerAncestralDiscreteScheduler,
98
+ DPMSolverMultistepScheduler,
99
+ ],
100
+ ):
101
+ super().__init__()
102
+
103
+ self.register_modules(
104
+ vae=vae,
105
+ reference_unet=reference_unet,
106
+ denoising_unet=denoising_unet,
107
+ face_locator=face_locator,
108
+ scheduler=scheduler,
109
+ imageproj=imageproj,
110
+ )
111
+ self.vae_scale_factor = 2 ** (
112
+ len(self.vae.config.block_out_channels) - 1)
113
+ self.clip_image_processor = CLIPImageProcessor()
114
+ self.ref_image_processor = VaeImageProcessor(
115
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
116
+ )
117
+ self.cond_image_processor = VaeImageProcessor(
118
+ vae_scale_factor=self.vae_scale_factor,
119
+ do_convert_rgb=True,
120
+ do_normalize=False,
121
+ )
122
+
123
+ def enable_vae_slicing(self):
124
+ """
125
+ Enable VAE slicing.
126
+
127
+ This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images.
128
+ """
129
+ self.vae.enable_slicing()
130
+
131
+ def disable_vae_slicing(self):
132
+ """
133
+ Disable vae slicing.
134
+
135
+ This function disables the vae slicing for the StaticPipeline object.
136
+ It calls the `disable_slicing()` method of the vae model.
137
+ This is useful when you want to use the entire vae model for decoding latents
138
+ instead of slicing it for better performance.
139
+ """
140
+ self.vae.disable_slicing()
141
+
142
+ def enable_sequential_cpu_offload(self, gpu_id=0):
143
+ """
144
+ Offloads selected models to the GPU for increased performance.
145
+
146
+ Args:
147
+ gpu_id (int, optional): The ID of the GPU to offload models to. Defaults to 0.
148
+ """
149
+ device = torch.device(f"cuda:{gpu_id}")
150
+
151
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
152
+ if cpu_offloaded_model is not None:
153
+ cpu_offload(cpu_offloaded_model, device)
154
+
155
+ @property
156
+ def _execution_device(self):
157
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
158
+ return self.device
159
+ for module in self.unet.modules():
160
+ if (
161
+ hasattr(module, "_hf_hook")
162
+ and hasattr(module._hf_hook, "execution_device")
163
+ and module._hf_hook.execution_device is not None
164
+ ):
165
+ return torch.device(module._hf_hook.execution_device)
166
+ return self.device
167
+
168
+ def decode_latents(self, latents):
169
+ """
170
+ Decode the given latents to video frames.
171
+
172
+ Parameters:
173
+ latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width).
174
+
175
+ Returns:
176
+ video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width).
177
+ """
178
+ video_length = latents.shape[2]
179
+ latents = 1 / 0.18215 * latents
180
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
181
+ # video = self.vae.decode(latents).sample
182
+ video = []
183
+ for frame_idx in tqdm(range(latents.shape[0])):
184
+ video.append(self.vae.decode(
185
+ latents[frame_idx: frame_idx + 1]).sample)
186
+ video = torch.cat(video)
187
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
188
+ video = (video / 2 + 0.5).clamp(0, 1)
189
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
190
+ video = video.cpu().float().numpy()
191
+ return video
192
+
193
+ def prepare_extra_step_kwargs(self, generator, eta):
194
+ """
195
+ Prepare extra keyword arguments for the scheduler step.
196
+
197
+ Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler.
198
+
199
+ Args:
200
+ generator (Optional[torch.Generator]): A random number generator for reproducibility.
201
+ eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1.
202
+
203
+ Returns:
204
+ dict: A dictionary containing the extra keyword arguments for the scheduler step.
205
+ """
206
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
207
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
208
+ # and should be between [0, 1]
209
+
210
+ accepts_eta = "eta" in set(
211
+ inspect.signature(self.scheduler.step).parameters.keys()
212
+ )
213
+ extra_step_kwargs = {}
214
+ if accepts_eta:
215
+ extra_step_kwargs["eta"] = eta
216
+
217
+ # check if the scheduler accepts generator
218
+ accepts_generator = "generator" in set(
219
+ inspect.signature(self.scheduler.step).parameters.keys()
220
+ )
221
+ if accepts_generator:
222
+ extra_step_kwargs["generator"] = generator
223
+ return extra_step_kwargs
224
+
225
+ def prepare_latents(
226
+ self,
227
+ batch_size,
228
+ num_channels_latents,
229
+ width,
230
+ height,
231
+ dtype,
232
+ device,
233
+ generator,
234
+ latents=None,
235
+ ):
236
+ """
237
+ Prepares the initial latents for the diffusion pipeline.
238
+
239
+ Args:
240
+ batch_size (int): The number of images to generate in one forward pass.
241
+ num_channels_latents (int): The number of channels in the latents tensor.
242
+ width (int): The width of the latents tensor.
243
+ height (int): The height of the latents tensor.
244
+ dtype (torch.dtype): The data type of the latents tensor.
245
+ device (torch.device): The device to place the latents tensor on.
246
+ generator (Optional[torch.Generator], optional): A random number generator
247
+ for reproducibility. Defaults to None.
248
+ latents (Optional[torch.Tensor], optional): Pre-computed latents to use as
249
+ initial conditions for the diffusion process. Defaults to None.
250
+
251
+ Returns:
252
+ torch.Tensor: The prepared latents tensor.
253
+ """
254
+ shape = (
255
+ batch_size,
256
+ num_channels_latents,
257
+ height // self.vae_scale_factor,
258
+ width // self.vae_scale_factor,
259
+ )
260
+ if isinstance(generator, list) and len(generator) != batch_size:
261
+ raise ValueError(
262
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
263
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
264
+ )
265
+
266
+ if latents is None:
267
+ latents = randn_tensor(
268
+ shape, generator=generator, device=device, dtype=dtype
269
+ )
270
+ else:
271
+ latents = latents.to(device)
272
+
273
+ # scale the initial noise by the standard deviation required by the scheduler
274
+ latents = latents * self.scheduler.init_noise_sigma
275
+ return latents
276
+
277
+ def prepare_condition(
278
+ self,
279
+ cond_image,
280
+ width,
281
+ height,
282
+ device,
283
+ dtype,
284
+ do_classififer_free_guidance=False,
285
+ ):
286
+ """
287
+ Prepares the condition for the face animation pipeline.
288
+
289
+ Args:
290
+ cond_image (torch.Tensor): The conditional image tensor.
291
+ width (int): The width of the output image.
292
+ height (int): The height of the output image.
293
+ device (torch.device): The device to run the pipeline on.
294
+ dtype (torch.dtype): The data type of the tensor.
295
+ do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False.
296
+
297
+ Returns:
298
+ Tuple[torch.Tensor, torch.Tensor]: A tuple of processed condition and mask tensors.
299
+ """
300
+ image = self.cond_image_processor.preprocess(
301
+ cond_image, height=height, width=width
302
+ ).to(dtype=torch.float32)
303
+
304
+ image = image.to(device=device, dtype=dtype)
305
+
306
+ if do_classififer_free_guidance:
307
+ image = torch.cat([image] * 2)
308
+
309
+ return image
310
+
311
+ @torch.no_grad()
312
+ def __call__(
313
+ self,
314
+ ref_image,
315
+ face_mask,
316
+ width,
317
+ height,
318
+ num_inference_steps,
319
+ guidance_scale,
320
+ face_embedding,
321
+ num_images_per_prompt=1,
322
+ eta: float = 0.0,
323
+ generator: Optional[Union[torch.Generator,
324
+ List[torch.Generator]]] = None,
325
+ output_type: Optional[str] = "tensor",
326
+ return_dict: bool = True,
327
+ callback: Optional[Callable[[
328
+ int, int, torch.FloatTensor], None]] = None,
329
+ callback_steps: Optional[int] = 1,
330
+ **kwargs,
331
+ ):
332
+ # Default height and width to unet
333
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
334
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
335
+
336
+ device = self._execution_device
337
+
338
+ do_classifier_free_guidance = guidance_scale > 1.0
339
+
340
+ # Prepare timesteps
341
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
342
+ timesteps = self.scheduler.timesteps
343
+
344
+ batch_size = 1
345
+
346
+ image_prompt_embeds = self.imageproj(face_embedding)
347
+ uncond_image_prompt_embeds = self.imageproj(
348
+ torch.zeros_like(face_embedding))
349
+
350
+ if do_classifier_free_guidance:
351
+ image_prompt_embeds = torch.cat(
352
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
353
+ )
354
+
355
+ reference_control_writer = ReferenceAttentionControl(
356
+ self.reference_unet,
357
+ do_classifier_free_guidance=do_classifier_free_guidance,
358
+ mode="write",
359
+ batch_size=batch_size,
360
+ fusion_blocks="full",
361
+ )
362
+ reference_control_reader = ReferenceAttentionControl(
363
+ self.denoising_unet,
364
+ do_classifier_free_guidance=do_classifier_free_guidance,
365
+ mode="read",
366
+ batch_size=batch_size,
367
+ fusion_blocks="full",
368
+ )
369
+
370
+ num_channels_latents = self.denoising_unet.in_channels
371
+ latents = self.prepare_latents(
372
+ batch_size * num_images_per_prompt,
373
+ num_channels_latents,
374
+ width,
375
+ height,
376
+ face_embedding.dtype,
377
+ device,
378
+ generator,
379
+ )
380
+ latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
381
+ # latents_dtype = latents.dtype
382
+
383
+ # Prepare extra step kwargs.
384
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
385
+
386
+ # Prepare ref image latents
387
+ ref_image_tensor = self.ref_image_processor.preprocess(
388
+ ref_image, height=height, width=width
389
+ ) # (bs, c, width, height)
390
+ ref_image_tensor = ref_image_tensor.to(
391
+ dtype=self.vae.dtype, device=self.vae.device
392
+ )
393
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
394
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
395
+
396
+ # Prepare face mask image
397
+ face_mask_tensor = self.cond_image_processor.preprocess(
398
+ face_mask, height=height, width=width
399
+ )
400
+ face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w)
401
+ face_mask_tensor = face_mask_tensor.to(
402
+ device=device, dtype=self.face_locator.dtype
403
+ )
404
+ mask_fea = self.face_locator(face_mask_tensor)
405
+ mask_fea = (
406
+ torch.cat(
407
+ [mask_fea] * 2) if do_classifier_free_guidance else mask_fea
408
+ )
409
+
410
+ # denoising loop
411
+ num_warmup_steps = len(timesteps) - \
412
+ num_inference_steps * self.scheduler.order
413
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
414
+ for i, t in enumerate(timesteps):
415
+ # 1. Forward reference image
416
+ if i == 0:
417
+ self.reference_unet(
418
+ ref_image_latents.repeat(
419
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
420
+ ),
421
+ torch.zeros_like(t),
422
+ encoder_hidden_states=image_prompt_embeds,
423
+ return_dict=False,
424
+ )
425
+
426
+ # 2. Update reference unet feature into denosing net
427
+ reference_control_reader.update(reference_control_writer)
428
+
429
+ # 3.1 expand the latents if we are doing classifier free guidance
430
+ latent_model_input = (
431
+ torch.cat(
432
+ [latents] * 2) if do_classifier_free_guidance else latents
433
+ )
434
+ latent_model_input = self.scheduler.scale_model_input(
435
+ latent_model_input, t
436
+ )
437
+
438
+ noise_pred = self.denoising_unet(
439
+ latent_model_input,
440
+ t,
441
+ encoder_hidden_states=image_prompt_embeds,
442
+ mask_cond_fea=mask_fea,
443
+ return_dict=False,
444
+ )[0]
445
+
446
+ # perform guidance
447
+ if do_classifier_free_guidance:
448
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
449
+ noise_pred = noise_pred_uncond + guidance_scale * (
450
+ noise_pred_text - noise_pred_uncond
451
+ )
452
+
453
+ # compute the previous noisy sample x_t -> x_t-1
454
+ latents = self.scheduler.step(
455
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
456
+ )[0]
457
+
458
+ # call the callback, if provided
459
+ if i == len(timesteps) - 1 or (
460
+ (i + 1) > num_warmup_steps and (i +
461
+ 1) % self.scheduler.order == 0
462
+ ):
463
+ progress_bar.update()
464
+ if callback is not None and i % callback_steps == 0:
465
+ step_idx = i // getattr(self.scheduler, "order", 1)
466
+ callback(step_idx, t, latents)
467
+ reference_control_reader.clear()
468
+ reference_control_writer.clear()
469
+
470
+ # Post-processing
471
+ image = self.decode_latents(latents) # (b, c, 1, h, w)
472
+
473
+ # Convert to tensor
474
+ if output_type == "tensor":
475
+ image = torch.from_numpy(image)
476
+
477
+ if not return_dict:
478
+ return image
479
+
480
+ return StaticPipelineOutput(images=image)
joyhallo/datasets/__init__.py ADDED
File without changes
joyhallo/datasets/audio_processor.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This module contains the AudioProcessor class and related functions for processing audio data.
3
+ It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
4
+ and audio separation. The class is initialized with configuration parameters and can process
5
+ audio files using the provided models.
6
+ '''
7
+ import math
8
+ import os
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import torch
13
+ from audio_separator.separator import Separator
14
+ from einops import rearrange
15
+ from transformers import Wav2Vec2FeatureExtractor
16
+
17
+ from joyhallo.models.wav2vec import Wav2VecModel
18
+ from joyhallo.utils.util import resample_audio
19
+
20
+
21
+ class AudioProcessor:
22
+ """
23
+ AudioProcessor is a class that handles the processing of audio files.
24
+ It takes care of preprocessing the audio files, extracting features
25
+ using wav2vec models, and separating audio signals if needed.
26
+
27
+ :param sample_rate: Sampling rate of the audio file
28
+ :param fps: Frames per second for the extracted features
29
+ :param wav2vec_model_path: Path to the wav2vec model
30
+ :param only_last_features: Whether to only use the last features
31
+ :param audio_separator_model_path: Path to the audio separator model
32
+ :param audio_separator_model_name: Name of the audio separator model
33
+ :param cache_dir: Directory to cache the intermediate results
34
+ :param device: Device to run the processing on
35
+ """
36
+ def __init__(
37
+ self,
38
+ sample_rate,
39
+ fps,
40
+ wav2vec_model_path,
41
+ only_last_features,
42
+ audio_separator_model_path:str=None,
43
+ audio_separator_model_name:str=None,
44
+ cache_dir:str='',
45
+ device="cuda:0",
46
+ ) -> None:
47
+ self.sample_rate = sample_rate
48
+ self.fps = fps
49
+ self.device = device
50
+
51
+ self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
52
+ self.audio_encoder.feature_extractor._freeze_parameters()
53
+ self.only_last_features = only_last_features
54
+
55
+ if audio_separator_model_name is not None:
56
+ try:
57
+ os.makedirs(cache_dir, exist_ok=True)
58
+ except OSError as _:
59
+ print("Fail to create the output cache dir.")
60
+ self.audio_separator = Separator(
61
+ output_dir=cache_dir,
62
+ output_single_stem="vocals",
63
+ model_file_dir=audio_separator_model_path,
64
+ )
65
+ self.audio_separator.load_model(audio_separator_model_name)
66
+ assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
67
+ else:
68
+ self.audio_separator=None
69
+ print("Use audio directly without vocals seperator.")
70
+
71
+
72
+ self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
73
+
74
+
75
+ def preprocess(self, wav_file: str, clip_length: int=-1):
76
+ """
77
+ Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
78
+ The separated vocal track is then converted into wav2vec2 for further processing or analysis.
79
+
80
+ Args:
81
+ wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
82
+
83
+ Raises:
84
+ RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues
85
+ such as file not found, unsupported file format, or errors during the audio processing steps.
86
+
87
+ Returns:
88
+ torch.tensor: Returns an audio embedding as a torch.tensor
89
+ """
90
+ if self.audio_separator is not None:
91
+ # 1. separate vocals
92
+ # TODO: process in memory
93
+ outputs = self.audio_separator.separate(wav_file)
94
+ if len(outputs) <= 0:
95
+ raise RuntimeError("Audio separate failed.")
96
+
97
+ vocal_audio_file = outputs[0]
98
+ vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
99
+ vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
100
+ vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
101
+ else:
102
+ vocal_audio_file=wav_file
103
+
104
+ # 2. extract wav2vec features
105
+ speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
106
+ audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
107
+ seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
108
+ audio_length = seq_len
109
+
110
+ audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
111
+
112
+ if clip_length>0 and seq_len % clip_length != 0:
113
+ audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
114
+ seq_len += clip_length - seq_len % clip_length
115
+ audio_feature = audio_feature.unsqueeze(0)
116
+
117
+ with torch.no_grad():
118
+ embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
119
+ assert len(embeddings) > 0, "Fail to extract audio embedding"
120
+ if self.only_last_features:
121
+ audio_emb = embeddings.last_hidden_state.squeeze()
122
+ else:
123
+ audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
124
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
125
+
126
+ audio_emb = audio_emb.cpu().detach()
127
+
128
+ return audio_emb, audio_length
129
+
130
+ def get_embedding(self, wav_file: str):
131
+ """preprocess wav audio file convert to embeddings
132
+
133
+ Args:
134
+ wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
135
+
136
+ Returns:
137
+ torch.tensor: Returns an audio embedding as a torch.tensor
138
+ """
139
+ speech_array, sampling_rate = librosa.load(
140
+ wav_file, sr=self.sample_rate)
141
+ assert sampling_rate == 16000, "The audio sample rate must be 16000"
142
+ audio_feature = np.squeeze(self.wav2vec_feature_extractor(
143
+ speech_array, sampling_rate=sampling_rate).input_values)
144
+ seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
145
+
146
+ audio_feature = torch.from_numpy(
147
+ audio_feature).float().to(device=self.device)
148
+ audio_feature = audio_feature.unsqueeze(0)
149
+
150
+ with torch.no_grad():
151
+ embeddings = self.audio_encoder(
152
+ audio_feature, seq_len=seq_len, output_hidden_states=True)
153
+ assert len(embeddings) > 0, "Fail to extract audio embedding"
154
+
155
+ if self.only_last_features:
156
+ audio_emb = embeddings.last_hidden_state.squeeze()
157
+ else:
158
+ audio_emb = torch.stack(
159
+ embeddings.hidden_states[1:], dim=1).squeeze(0)
160
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
161
+
162
+ audio_emb = audio_emb.cpu().detach()
163
+
164
+ return audio_emb
165
+
166
+ def close(self):
167
+ """
168
+ TODO: to be implemented
169
+ """
170
+ return self
171
+
172
+ def __enter__(self):
173
+ return self
174
+
175
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
176
+ self.close()
joyhallo/datasets/image_processor.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module is responsible for processing images, particularly for face-related tasks.
3
+ It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
4
+ face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates
5
+ the functionality for these operations.
6
+ """
7
+ import os
8
+ from typing import List
9
+
10
+ import cv2
11
+ import mediapipe as mp
12
+ import numpy as np
13
+ import torch
14
+ from insightface.app import FaceAnalysis
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+
18
+ from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
19
+ get_union_face_mask, get_union_lip_mask)
20
+
21
+ MEAN = 0.5
22
+ STD = 0.5
23
+
24
+ class ImageProcessor:
25
+ """
26
+ ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
27
+ It takes in an image and performs various operations such as augmentation, face detection,
28
+ face embedding extraction, and rendering a face mask. The processed images are then used for
29
+ further analysis or recognition purposes.
30
+
31
+ Attributes:
32
+ img_size (int): The size of the image to be processed.
33
+ face_analysis_model_path (str): The path to the face analysis model.
34
+
35
+ Methods:
36
+ preprocess(source_image_path, cache_dir):
37
+ Preprocesses the input image by performing augmentation, face detection,
38
+ face embedding extraction, and rendering a face mask.
39
+
40
+ close():
41
+ Closes the ImageProcessor and releases any resources being used.
42
+
43
+ _augmentation(images, transform, state=None):
44
+ Applies image augmentation to the input images using the given transform and state.
45
+
46
+ __enter__():
47
+ Enters a runtime context and returns the ImageProcessor object.
48
+
49
+ __exit__(_exc_type, _exc_val, _exc_tb):
50
+ Exits a runtime context and handles any exceptions that occurred during the processing.
51
+ """
52
+ def __init__(self, img_size, face_analysis_model_path) -> None:
53
+ self.img_size = img_size
54
+
55
+ self.pixel_transform = transforms.Compose(
56
+ [
57
+ transforms.Resize(self.img_size),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize([MEAN], [STD]),
60
+ ]
61
+ )
62
+
63
+ self.cond_transform = transforms.Compose(
64
+ [
65
+ transforms.Resize(self.img_size),
66
+ transforms.ToTensor(),
67
+ ]
68
+ )
69
+
70
+ self.attn_transform_64 = transforms.Compose(
71
+ [
72
+ transforms.Resize(
73
+ (self.img_size[0] // 8, self.img_size[0] // 8)),
74
+ transforms.ToTensor(),
75
+ ]
76
+ )
77
+ self.attn_transform_32 = transforms.Compose(
78
+ [
79
+ transforms.Resize(
80
+ (self.img_size[0] // 16, self.img_size[0] // 16)),
81
+ transforms.ToTensor(),
82
+ ]
83
+ )
84
+ self.attn_transform_16 = transforms.Compose(
85
+ [
86
+ transforms.Resize(
87
+ (self.img_size[0] // 32, self.img_size[0] // 32)),
88
+ transforms.ToTensor(),
89
+ ]
90
+ )
91
+ self.attn_transform_8 = transforms.Compose(
92
+ [
93
+ transforms.Resize(
94
+ (self.img_size[0] // 64, self.img_size[0] // 64)),
95
+ transforms.ToTensor(),
96
+ ]
97
+ )
98
+
99
+ self.face_analysis = FaceAnalysis(
100
+ name="",
101
+ root=face_analysis_model_path,
102
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
103
+ )
104
+ self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
105
+
106
+ def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float):
107
+ """
108
+ Apply preprocessing to the source image to prepare for face analysis.
109
+
110
+ Parameters:
111
+ source_image_path (str): The path to the source image.
112
+ cache_dir (str): The directory to cache intermediate results.
113
+
114
+ Returns:
115
+ None
116
+ """
117
+ source_image = Image.open(source_image_path)
118
+ ref_image_pil = source_image.convert("RGB")
119
+ # 1. image augmentation
120
+ pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform)
121
+
122
+ # 2.1 detect face
123
+ faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
124
+ if not faces:
125
+ print("No faces detected in the image. Using the entire image as the face region.")
126
+ # Use the entire image as the face region
127
+ face = {
128
+ "bbox": [0, 0, ref_image_pil.width, ref_image_pil.height],
129
+ "embedding": np.zeros(512)
130
+ }
131
+ else:
132
+ # Sort faces by size and select the largest one
133
+ faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True)
134
+ face = faces_sorted[0] # Select the largest face
135
+
136
+ # 2.2 face embedding
137
+ face_emb = face["embedding"]
138
+
139
+ # 2.3 render face mask
140
+ get_mask(source_image_path, cache_dir, face_region_ratio)
141
+ file_name = os.path.basename(source_image_path).split(".")[0]
142
+ face_mask_pil = Image.open(
143
+ os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB")
144
+
145
+ face_mask = self._augmentation(face_mask_pil, self.cond_transform)
146
+
147
+ # 2.4 detect and expand lip, face mask
148
+ sep_background_mask = Image.open(
149
+ os.path.join(cache_dir, f"{file_name}_sep_background.png"))
150
+ sep_face_mask = Image.open(
151
+ os.path.join(cache_dir, f"{file_name}_sep_face.png"))
152
+ sep_lip_mask = Image.open(
153
+ os.path.join(cache_dir, f"{file_name}_sep_lip.png"))
154
+
155
+ pixel_values_face_mask = [
156
+ self._augmentation(sep_face_mask, self.attn_transform_64),
157
+ self._augmentation(sep_face_mask, self.attn_transform_32),
158
+ self._augmentation(sep_face_mask, self.attn_transform_16),
159
+ self._augmentation(sep_face_mask, self.attn_transform_8),
160
+ ]
161
+ pixel_values_lip_mask = [
162
+ self._augmentation(sep_lip_mask, self.attn_transform_64),
163
+ self._augmentation(sep_lip_mask, self.attn_transform_32),
164
+ self._augmentation(sep_lip_mask, self.attn_transform_16),
165
+ self._augmentation(sep_lip_mask, self.attn_transform_8),
166
+ ]
167
+ pixel_values_full_mask = [
168
+ self._augmentation(sep_background_mask, self.attn_transform_64),
169
+ self._augmentation(sep_background_mask, self.attn_transform_32),
170
+ self._augmentation(sep_background_mask, self.attn_transform_16),
171
+ self._augmentation(sep_background_mask, self.attn_transform_8),
172
+ ]
173
+
174
+ pixel_values_full_mask = [mask.view(1, -1)
175
+ for mask in pixel_values_full_mask]
176
+ pixel_values_face_mask = [mask.view(1, -1)
177
+ for mask in pixel_values_face_mask]
178
+ pixel_values_lip_mask = [mask.view(1, -1)
179
+ for mask in pixel_values_lip_mask]
180
+
181
+ return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask
182
+
183
+ def close(self):
184
+ """
185
+ Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
186
+
187
+ Args:
188
+ self: The ImageProcessor instance.
189
+
190
+ Returns:
191
+ None.
192
+ """
193
+ for _, model in self.face_analysis.models.items():
194
+ if hasattr(model, "Dispose"):
195
+ model.Dispose()
196
+
197
+ def _augmentation(self, images, transform, state=None):
198
+ if state is not None:
199
+ torch.set_rng_state(state)
200
+ if isinstance(images, List):
201
+ transformed_images = [transform(img) for img in images]
202
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
203
+ else:
204
+ ret_tensor = transform(images) # (c, h, w)
205
+ return ret_tensor
206
+
207
+ def __enter__(self):
208
+ return self
209
+
210
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
211
+ self.close()
212
+
213
+
214
+ class ImageProcessorForDataProcessing():
215
+ """
216
+ ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
217
+ It takes in an image and performs various operations such as augmentation, face detection,
218
+ face embedding extraction, and rendering a face mask. The processed images are then used for
219
+ further analysis or recognition purposes.
220
+
221
+ Attributes:
222
+ img_size (int): The size of the image to be processed.
223
+ face_analysis_model_path (str): The path to the face analysis model.
224
+
225
+ Methods:
226
+ preprocess(source_image_path, cache_dir):
227
+ Preprocesses the input image by performing augmentation, face detection,
228
+ face embedding extraction, and rendering a face mask.
229
+
230
+ close():
231
+ Closes the ImageProcessor and releases any resources being used.
232
+
233
+ _augmentation(images, transform, state=None):
234
+ Applies image augmentation to the input images using the given transform and state.
235
+
236
+ __enter__():
237
+ Enters a runtime context and returns the ImageProcessor object.
238
+
239
+ __exit__(_exc_type, _exc_val, _exc_tb):
240
+ Exits a runtime context and handles any exceptions that occurred during the processing.
241
+ """
242
+ def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
243
+ if step == 2:
244
+ self.face_analysis = FaceAnalysis(
245
+ name="",
246
+ root=face_analysis_model_path,
247
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
248
+ )
249
+ self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
250
+ self.landmarker = None
251
+ else:
252
+ BaseOptions = mp.tasks.BaseOptions
253
+ FaceLandmarker = mp.tasks.vision.FaceLandmarker
254
+ FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
255
+ VisionRunningMode = mp.tasks.vision.RunningMode
256
+ # Create a face landmarker instance with the video mode:
257
+ options = FaceLandmarkerOptions(
258
+ base_options=BaseOptions(model_asset_path=landmark_model_path),
259
+ running_mode=VisionRunningMode.IMAGE,
260
+ )
261
+ self.landmarker = FaceLandmarker.create_from_options(options)
262
+ self.face_analysis = None
263
+
264
+ def preprocess(self, source_image_path: str):
265
+ """
266
+ Apply preprocessing to the source image to prepare for face analysis.
267
+
268
+ Parameters:
269
+ source_image_path (str): The path to the source image.
270
+ cache_dir (str): The directory to cache intermediate results.
271
+
272
+ Returns:
273
+ None
274
+ """
275
+ # 1. get face embdeding
276
+ face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
277
+ if self.face_analysis:
278
+ for frame in sorted(os.listdir(source_image_path)):
279
+ try:
280
+ source_image = Image.open(
281
+ os.path.join(source_image_path, frame))
282
+ ref_image_pil = source_image.convert("RGB")
283
+ # 2.1 detect face
284
+ faces = self.face_analysis.get(cv2.cvtColor(
285
+ np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
286
+ # use max size face
287
+ face = sorted(faces, key=lambda x: (
288
+ x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
289
+ # 2.2 face embedding
290
+ face_emb = face["embedding"]
291
+ if face_emb is not None:
292
+ break
293
+ except Exception as _:
294
+ continue
295
+
296
+ if self.landmarker:
297
+ # 3.1 get landmark
298
+ landmarks, height, width = get_landmark_overframes(
299
+ self.landmarker, source_image_path)
300
+ assert len(landmarks) == len(os.listdir(source_image_path))
301
+
302
+ # 3 render face and lip mask
303
+ face_mask = get_union_face_mask(landmarks, height, width)
304
+ lip_mask = get_union_lip_mask(landmarks, height, width)
305
+
306
+ # 4 gaussian blur
307
+ blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
308
+ blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
309
+
310
+ # 5 seperate mask
311
+ sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
312
+ sep_pose_mask = 255.0 - blur_face_mask
313
+ sep_lip_mask = blur_lip_mask
314
+
315
+ return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
316
+
317
+ def close(self):
318
+ """
319
+ Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
320
+
321
+ Args:
322
+ self: The ImageProcessor instance.
323
+
324
+ Returns:
325
+ None.
326
+ """
327
+ for _, model in self.face_analysis.models.items():
328
+ if hasattr(model, "Dispose"):
329
+ model.Dispose()
330
+
331
+ def _augmentation(self, images, transform, state=None):
332
+ if state is not None:
333
+ torch.set_rng_state(state)
334
+ if isinstance(images, List):
335
+ transformed_images = [transform(img) for img in images]
336
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
337
+ else:
338
+ ret_tensor = transform(images) # (c, h, w)
339
+ return ret_tensor
340
+
341
+ def __enter__(self):
342
+ return self
343
+
344
+ def __exit__(self, _exc_type, _exc_val, _exc_tb):
345
+ self.close()
joyhallo/datasets/mask_image.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains the code for a dataset class called FaceMaskDataset, which is used to process and
3
+ load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and
4
+ provides methods for data augmentation, getting items from the dataset, and determining the length of the
5
+ dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch,
6
+ PIL, and transformers.
7
+ """
8
+
9
+ import json
10
+ import random
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ from PIL import Image
15
+ from torch.utils.data import Dataset
16
+ from torchvision import transforms
17
+ from transformers import CLIPImageProcessor
18
+
19
+
20
+ class FaceMaskDataset(Dataset):
21
+ """
22
+ FaceMaskDataset is a custom dataset for face mask images.
23
+
24
+ Args:
25
+ img_size (int): The size of the input images.
26
+ drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1.
27
+ data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"].
28
+ sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30.
29
+
30
+ Attributes:
31
+ img_size (int): The size of the input images.
32
+ drop_ratio (float): The ratio of dropped pixels during data augmentation.
33
+ data_meta_paths (list): The paths to the metadata files containing image paths and labels.
34
+ sample_margin (int): The margin for sampling regions in the image.
35
+ processor (CLIPImageProcessor): The image processor for preprocessing images.
36
+ transform (transforms.Compose): The image augmentation transform.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ img_size,
42
+ drop_ratio=0.1,
43
+ data_meta_paths=None,
44
+ sample_margin=30,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.img_size = img_size
49
+ self.sample_margin = sample_margin
50
+
51
+ vid_meta = []
52
+ for data_meta_path in data_meta_paths:
53
+ with open(data_meta_path, "r", encoding="utf-8") as f:
54
+ vid_meta.extend(json.load(f))
55
+ self.vid_meta = vid_meta
56
+ self.length = len(self.vid_meta)
57
+
58
+ self.clip_image_processor = CLIPImageProcessor()
59
+
60
+ self.transform = transforms.Compose(
61
+ [
62
+ transforms.Resize(self.img_size),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize([0.5], [0.5]),
65
+ ]
66
+ )
67
+
68
+ self.cond_transform = transforms.Compose(
69
+ [
70
+ transforms.Resize(self.img_size),
71
+ transforms.ToTensor(),
72
+ ]
73
+ )
74
+
75
+ self.drop_ratio = drop_ratio
76
+
77
+ def augmentation(self, image, transform, state=None):
78
+ """
79
+ Apply data augmentation to the input image.
80
+
81
+ Args:
82
+ image (PIL.Image): The input image.
83
+ transform (torchvision.transforms.Compose): The data augmentation transforms.
84
+ state (dict, optional): The random state for reproducibility. Defaults to None.
85
+
86
+ Returns:
87
+ PIL.Image: The augmented image.
88
+ """
89
+ if state is not None:
90
+ torch.set_rng_state(state)
91
+ return transform(image)
92
+
93
+ def __getitem__(self, index):
94
+ video_meta = self.vid_meta[index]
95
+ video_path = video_meta["image_path"]
96
+ mask_path = video_meta["mask_path"]
97
+ face_emb_path = video_meta["face_emb"]
98
+
99
+ video_frames = sorted(Path(video_path).iterdir())
100
+ video_length = len(video_frames)
101
+
102
+ margin = min(self.sample_margin, video_length)
103
+
104
+ ref_img_idx = random.randint(0, video_length - 1)
105
+ if ref_img_idx + margin < video_length:
106
+ tgt_img_idx = random.randint(
107
+ ref_img_idx + margin, video_length - 1)
108
+ elif ref_img_idx - margin > 0:
109
+ tgt_img_idx = random.randint(0, ref_img_idx - margin)
110
+ else:
111
+ tgt_img_idx = random.randint(0, video_length - 1)
112
+
113
+ ref_img_pil = Image.open(video_frames[ref_img_idx])
114
+ tgt_img_pil = Image.open(video_frames[tgt_img_idx])
115
+
116
+ tgt_mask_pil = Image.open(mask_path)
117
+
118
+ assert ref_img_pil is not None, "Fail to load reference image."
119
+ assert tgt_img_pil is not None, "Fail to load target image."
120
+ assert tgt_mask_pil is not None, "Fail to load target mask."
121
+
122
+ state = torch.get_rng_state()
123
+ tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
124
+ tgt_mask_img = self.augmentation(
125
+ tgt_mask_pil, self.cond_transform, state)
126
+ tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
127
+ ref_img_vae = self.augmentation(
128
+ ref_img_pil, self.transform, state)
129
+ face_emb = torch.load(face_emb_path)
130
+
131
+
132
+ sample = {
133
+ "video_dir": video_path,
134
+ "img": tgt_img,
135
+ "tgt_mask": tgt_mask_img,
136
+ "ref_img": ref_img_vae,
137
+ "face_emb": face_emb,
138
+ }
139
+
140
+ return sample
141
+
142
+ def __len__(self):
143
+ return len(self.vid_meta)
144
+
145
+
146
+ if __name__ == "__main__":
147
+ data = FaceMaskDataset(img_size=(512, 512))
148
+ train_dataloader = torch.utils.data.DataLoader(
149
+ data, batch_size=4, shuffle=True, num_workers=1
150
+ )
151
+ for step, batch in enumerate(train_dataloader):
152
+ print(batch["tgt_mask"].shape)
153
+ break
joyhallo/datasets/talk_video.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ talking_video_dataset.py
3
+
4
+ This module defines the TalkingVideoDataset class, a custom PyTorch dataset
5
+ for handling talking video data. The dataset uses video files, masks, and
6
+ embeddings to prepare data for tasks such as video generation and
7
+ speech-driven video animation.
8
+
9
+ Classes:
10
+ TalkingVideoDataset
11
+
12
+ Dependencies:
13
+ json
14
+ random
15
+ torch
16
+ decord.VideoReader, decord.cpu
17
+ PIL.Image
18
+ torch.utils.data.Dataset
19
+ torchvision.transforms
20
+
21
+ Example:
22
+ from talking_video_dataset import TalkingVideoDataset
23
+ from torch.utils.data import DataLoader
24
+
25
+ # Example configuration for the Wav2Vec model
26
+ class Wav2VecConfig:
27
+ def __init__(self, audio_type, model_scale, features):
28
+ self.audio_type = audio_type
29
+ self.model_scale = model_scale
30
+ self.features = features
31
+
32
+ wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature")
33
+
34
+ # Initialize dataset
35
+ dataset = TalkingVideoDataset(
36
+ img_size=(512, 512),
37
+ sample_rate=16000,
38
+ audio_margin=2,
39
+ n_motion_frames=0,
40
+ n_sample_frames=16,
41
+ data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"],
42
+ wav2vec_cfg=wav2vec_cfg,
43
+ )
44
+
45
+ # Initialize dataloader
46
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
47
+
48
+ # Fetch one batch of data
49
+ batch = next(iter(dataloader))
50
+ print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512)
51
+
52
+ The TalkingVideoDataset class provides methods for loading video frames, masks,
53
+ audio embeddings, and other relevant data, applying transformations, and preparing
54
+ the data for training and evaluation in a deep learning pipeline.
55
+
56
+ Attributes:
57
+ img_size (tuple): The dimensions to resize the video frames to.
58
+ sample_rate (int): The audio sample rate.
59
+ audio_margin (int): The margin for audio sampling.
60
+ n_motion_frames (int): The number of motion frames.
61
+ n_sample_frames (int): The number of sample frames.
62
+ data_meta_paths (list): List of paths to the JSON metadata files.
63
+ wav2vec_cfg (object): Configuration for the Wav2Vec model.
64
+
65
+ Methods:
66
+ augmentation(images, transform, state=None): Apply transformation to input images.
67
+ __getitem__(index): Get a sample from the dataset at the specified index.
68
+ __len__(): Return the length of the dataset.
69
+ """
70
+
71
+ import json
72
+ import random
73
+ from typing import List
74
+
75
+ import torch
76
+ from decord import VideoReader, cpu
77
+ from PIL import Image
78
+ from torch.utils.data import Dataset
79
+ from torchvision import transforms
80
+
81
+
82
+ class TalkingVideoDataset(Dataset):
83
+ """
84
+ A dataset class for processing talking video data.
85
+
86
+ Args:
87
+ img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
88
+ sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
89
+ audio_margin (int, optional): The margin for the audio data. Defaults to 2.
90
+ n_motion_frames (int, optional): The number of motion frames. Defaults to 0.
91
+ n_sample_frames (int, optional): The number of sample frames. Defaults to 16.
92
+ data_meta_paths (list, optional): The paths to the data metadata. Defaults to None.
93
+ wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None.
94
+
95
+ Attributes:
96
+ img_size (tuple): The size of the output images.
97
+ sample_rate (int): The sample rate of the audio data.
98
+ audio_margin (int): The margin for the audio data.
99
+ n_motion_frames (int): The number of motion frames.
100
+ n_sample_frames (int): The number of sample frames.
101
+ data_meta_paths (list): The paths to the data metadata.
102
+ wav2vec_cfg (dict): The configuration for the wav2vec model.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ img_size=(512, 512),
108
+ sample_rate=16000,
109
+ audio_margin=2,
110
+ n_motion_frames=0,
111
+ n_sample_frames=16,
112
+ data_meta_paths=None,
113
+ wav2vec_cfg=None,
114
+ ):
115
+ super().__init__()
116
+ self.sample_rate = sample_rate
117
+ self.img_size = img_size
118
+ self.audio_margin = audio_margin
119
+ self.n_motion_frames = n_motion_frames
120
+ self.n_sample_frames = n_sample_frames
121
+ self.audio_type = wav2vec_cfg.audio_type
122
+ self.audio_model = wav2vec_cfg.model_scale
123
+ self.audio_features = wav2vec_cfg.features
124
+
125
+ vid_meta = []
126
+ for data_meta_path in data_meta_paths:
127
+ with open(data_meta_path, "r", encoding="utf-8") as f:
128
+ vid_meta.extend(json.load(f))
129
+ self.vid_meta = vid_meta
130
+ self.length = len(self.vid_meta)
131
+ self.pixel_transform = transforms.Compose(
132
+ [
133
+ transforms.Resize(self.img_size),
134
+ transforms.ToTensor(),
135
+ transforms.Normalize([0.5], [0.5]),
136
+ ]
137
+ )
138
+
139
+ self.cond_transform = transforms.Compose(
140
+ [
141
+ transforms.Resize(self.img_size),
142
+ transforms.ToTensor(),
143
+ ]
144
+ )
145
+ self.attn_transform_64 = transforms.Compose(
146
+ [
147
+ transforms.Resize(
148
+ (self.img_size[0] // 8, self.img_size[0] // 8)),
149
+ transforms.ToTensor(),
150
+ ]
151
+ )
152
+ self.attn_transform_32 = transforms.Compose(
153
+ [
154
+ transforms.Resize(
155
+ (self.img_size[0] // 16, self.img_size[0] // 16)),
156
+ transforms.ToTensor(),
157
+ ]
158
+ )
159
+ self.attn_transform_16 = transforms.Compose(
160
+ [
161
+ transforms.Resize(
162
+ (self.img_size[0] // 32, self.img_size[0] // 32)),
163
+ transforms.ToTensor(),
164
+ ]
165
+ )
166
+ self.attn_transform_8 = transforms.Compose(
167
+ [
168
+ transforms.Resize(
169
+ (self.img_size[0] // 64, self.img_size[0] // 64)),
170
+ transforms.ToTensor(),
171
+ ]
172
+ )
173
+
174
+ def augmentation(self, images, transform, state=None):
175
+ """
176
+ Apply the given transformation to the input images.
177
+
178
+ Args:
179
+ images (List[PIL.Image] or PIL.Image): The input images to be transformed.
180
+ transform (torchvision.transforms.Compose): The transformation to be applied to the images.
181
+ state (torch.ByteTensor, optional): The state of the random number generator.
182
+ If provided, it will set the RNG state to this value before applying the transformation. Defaults to None.
183
+
184
+ Returns:
185
+ torch.Tensor: The transformed images as a tensor.
186
+ If the input was a list of images, the tensor will have shape (f, c, h, w),
187
+ where f is the number of images, c is the number of channels, h is the height, and w is the width.
188
+ If the input was a single image, the tensor will have shape (c, h, w),
189
+ where c is the number of channels, h is the height, and w is the width.
190
+ """
191
+ if state is not None:
192
+ torch.set_rng_state(state)
193
+ if isinstance(images, List):
194
+ transformed_images = [transform(img) for img in images]
195
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
196
+ else:
197
+ ret_tensor = transform(images) # (c, h, w)
198
+ return ret_tensor
199
+
200
+ def __getitem__(self, index):
201
+ video_meta = self.vid_meta[index]
202
+ video_path = video_meta["video_path"]
203
+ mask_path = video_meta["mask_path"]
204
+ lip_mask_union_path = video_meta.get("sep_mask_lip", None)
205
+ face_mask_union_path = video_meta.get("sep_mask_face", None)
206
+ full_mask_union_path = video_meta.get("sep_mask_border", None)
207
+ face_emb_path = video_meta["face_emb_path"]
208
+ audio_emb_path = video_meta[
209
+ f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
210
+ ]
211
+ tgt_mask_pil = Image.open(mask_path)
212
+ video_frames = VideoReader(video_path, ctx=cpu(0))
213
+ assert tgt_mask_pil is not None, "Fail to load target mask."
214
+ assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames."
215
+
216
+ # 提前加载的位置,确认长度
217
+ audio_emb = torch.load(audio_emb_path)
218
+
219
+ # print(len(video_frames), len(audio_emb))
220
+ # 避免长度不一致,超索引范围
221
+ video_length = min(len(video_frames), len(audio_emb))
222
+
223
+ assert (
224
+ video_length
225
+ > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
226
+ )
227
+ start_idx = random.randint(
228
+ self.n_motion_frames,
229
+ video_length - self.n_sample_frames - self.audio_margin - 1,
230
+ )
231
+
232
+ videos = video_frames[start_idx : start_idx + self.n_sample_frames]
233
+
234
+ frame_list = [
235
+ Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
236
+ ]
237
+
238
+ face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
239
+ lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
240
+ full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
241
+ assert face_masks_list[0] is not None, "Fail to load face mask."
242
+ assert lip_masks_list[0] is not None, "Fail to load lip mask."
243
+ assert full_masks_list[0] is not None, "Fail to load full mask."
244
+
245
+
246
+ face_emb = torch.load(face_emb_path)
247
+
248
+ indices = (
249
+ torch.arange(2 * self.audio_margin + 1) - self.audio_margin
250
+ ) # Generates [-2, -1, 0, 1, 2]
251
+ center_indices = torch.arange(
252
+ start_idx,
253
+ start_idx + self.n_sample_frames,
254
+ ).unsqueeze(1) + indices.unsqueeze(0)
255
+ audio_tensor = audio_emb[center_indices]
256
+
257
+ ref_img_idx = random.randint(
258
+ self.n_motion_frames,
259
+ video_length - self.n_sample_frames - self.audio_margin - 1,
260
+ )
261
+ ref_img = video_frames[ref_img_idx].asnumpy()
262
+ ref_img = Image.fromarray(ref_img)
263
+
264
+ if self.n_motion_frames > 0:
265
+ motions = video_frames[start_idx - self.n_motion_frames : start_idx]
266
+ motion_list = [
267
+ Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
268
+ ]
269
+
270
+ # transform
271
+ state = torch.get_rng_state()
272
+ pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
273
+
274
+ pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
275
+ pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
276
+
277
+ pixel_values_face_mask = [
278
+ self.augmentation(face_masks_list, self.attn_transform_64, state),
279
+ self.augmentation(face_masks_list, self.attn_transform_32, state),
280
+ self.augmentation(face_masks_list, self.attn_transform_16, state),
281
+ self.augmentation(face_masks_list, self.attn_transform_8, state),
282
+ ]
283
+ pixel_values_lip_mask = [
284
+ self.augmentation(lip_masks_list, self.attn_transform_64, state),
285
+ self.augmentation(lip_masks_list, self.attn_transform_32, state),
286
+ self.augmentation(lip_masks_list, self.attn_transform_16, state),
287
+ self.augmentation(lip_masks_list, self.attn_transform_8, state),
288
+ ]
289
+ pixel_values_full_mask = [
290
+ self.augmentation(full_masks_list, self.attn_transform_64, state),
291
+ self.augmentation(full_masks_list, self.attn_transform_32, state),
292
+ self.augmentation(full_masks_list, self.attn_transform_16, state),
293
+ self.augmentation(full_masks_list, self.attn_transform_8, state),
294
+ ]
295
+
296
+ pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
297
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
298
+ if self.n_motion_frames > 0:
299
+ pixel_values_motion = self.augmentation(
300
+ motion_list, self.pixel_transform, state
301
+ )
302
+ pixel_values_ref_img = torch.cat(
303
+ [pixel_values_ref_img, pixel_values_motion], dim=0
304
+ )
305
+
306
+ sample = {
307
+ "video_dir": video_path,
308
+ "pixel_values_vid": pixel_values_vid,
309
+ "pixel_values_mask": pixel_values_mask,
310
+ "pixel_values_face_mask": pixel_values_face_mask,
311
+ "pixel_values_lip_mask": pixel_values_lip_mask,
312
+ "pixel_values_full_mask": pixel_values_full_mask,
313
+ "audio_tensor": audio_tensor,
314
+ "pixel_values_ref_img": pixel_values_ref_img,
315
+ "face_emb": face_emb,
316
+ }
317
+
318
+ return sample
319
+
320
+ def __len__(self):
321
+ return len(self.vid_meta)
joyhallo/models/__init__.py ADDED
File without changes
joyhallo/models/attention.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains various transformer blocks for different applications, such as BasicTransformerBlock,
3
+ TemporalBasicTransformerBlock, and AudioTemporalBasicTransformerBlock. These blocks are used in various models,
4
+ such as GLIGEN, UNet, and others. The transformer blocks implement self-attention, cross-attention, feed-forward
5
+ networks, and other related functions.
6
+
7
+ Functions and classes included in this module are:
8
+ - BasicTransformerBlock: A basic transformer block with self-attention, cross-attention, and feed-forward layers.
9
+ - TemporalBasicTransformerBlock: A transformer block with additional temporal attention mechanisms for video data.
10
+ - AudioTemporalBasicTransformerBlock: A transformer block with additional audio-specific mechanisms for audio data.
11
+ - zero_module: A function to zero out the parameters of a given module.
12
+
13
+ For more information on each specific class and function, please refer to the respective docstrings.
14
+ """
15
+
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ import torch
19
+ from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero,
20
+ Attention, FeedForward)
21
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
22
+ from einops import rearrange
23
+ from torch import nn
24
+
25
+
26
+ class GatedSelfAttentionDense(nn.Module):
27
+ """
28
+ A gated self-attention dense layer that combines visual features and object features.
29
+
30
+ Parameters:
31
+ query_dim (`int`): The number of channels in the query.
32
+ context_dim (`int`): The number of channels in the context.
33
+ n_heads (`int`): The number of heads to use for attention.
34
+ d_head (`int`): The number of channels in each head.
35
+ """
36
+
37
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
38
+ super().__init__()
39
+
40
+ # we need a linear projection since we need cat visual feature and obj feature
41
+ self.linear = nn.Linear(context_dim, query_dim)
42
+
43
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
44
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
45
+
46
+ self.norm1 = nn.LayerNorm(query_dim)
47
+ self.norm2 = nn.LayerNorm(query_dim)
48
+
49
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
50
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
51
+
52
+ self.enabled = True
53
+
54
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Apply the Gated Self-Attention mechanism to the input tensor `x` and object tensor `objs`.
57
+
58
+ Args:
59
+ x (torch.Tensor): The input tensor.
60
+ objs (torch.Tensor): The object tensor.
61
+
62
+ Returns:
63
+ torch.Tensor: The output tensor after applying Gated Self-Attention.
64
+ """
65
+ if not self.enabled:
66
+ return x
67
+
68
+ n_visual = x.shape[1]
69
+ objs = self.linear(objs)
70
+
71
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
72
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
73
+
74
+ return x
75
+
76
+ class BasicTransformerBlock(nn.Module):
77
+ r"""
78
+ A basic Transformer block.
79
+
80
+ Parameters:
81
+ dim (`int`): The number of channels in the input and output.
82
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
83
+ attention_head_dim (`int`): The number of channels in each head.
84
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
85
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
86
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
87
+ num_embeds_ada_norm (:
88
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
89
+ attention_bias (:
90
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
91
+ only_cross_attention (`bool`, *optional*):
92
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
93
+ double_self_attention (`bool`, *optional*):
94
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
95
+ upcast_attention (`bool`, *optional*):
96
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
97
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
98
+ Whether to use learnable elementwise affine parameters for normalization.
99
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
100
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
101
+ final_dropout (`bool` *optional*, defaults to False):
102
+ Whether to apply a final dropout after the last feed-forward layer.
103
+ attention_type (`str`, *optional*, defaults to `"default"`):
104
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
105
+ positional_embeddings (`str`, *optional*, defaults to `None`):
106
+ The type of positional embeddings to apply to.
107
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
108
+ The maximum number of positional embeddings to apply.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ num_attention_heads: int,
115
+ attention_head_dim: int,
116
+ dropout=0.0,
117
+ cross_attention_dim: Optional[int] = None,
118
+ activation_fn: str = "geglu",
119
+ num_embeds_ada_norm: Optional[int] = None,
120
+ attention_bias: bool = False,
121
+ only_cross_attention: bool = False,
122
+ double_self_attention: bool = False,
123
+ upcast_attention: bool = False,
124
+ norm_elementwise_affine: bool = True,
125
+ # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
126
+ norm_type: str = "layer_norm",
127
+ norm_eps: float = 1e-5,
128
+ final_dropout: bool = False,
129
+ attention_type: str = "default",
130
+ positional_embeddings: Optional[str] = None,
131
+ num_positional_embeddings: Optional[int] = None,
132
+ ):
133
+ super().__init__()
134
+ self.only_cross_attention = only_cross_attention
135
+
136
+ self.use_ada_layer_norm_zero = (
137
+ num_embeds_ada_norm is not None
138
+ ) and norm_type == "ada_norm_zero"
139
+ self.use_ada_layer_norm = (
140
+ num_embeds_ada_norm is not None
141
+ ) and norm_type == "ada_norm"
142
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
143
+ self.use_layer_norm = norm_type == "layer_norm"
144
+
145
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
146
+ raise ValueError(
147
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
148
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
149
+ )
150
+
151
+ if positional_embeddings and (num_positional_embeddings is None):
152
+ raise ValueError(
153
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
154
+ )
155
+
156
+ if positional_embeddings == "sinusoidal":
157
+ self.pos_embed = SinusoidalPositionalEmbedding(
158
+ dim, max_seq_length=num_positional_embeddings
159
+ )
160
+ else:
161
+ self.pos_embed = None
162
+
163
+ # Define 3 blocks. Each block has its own normalization layer.
164
+ # 1. Self-Attn
165
+ if self.use_ada_layer_norm:
166
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
167
+ elif self.use_ada_layer_norm_zero:
168
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
169
+ else:
170
+ self.norm1 = nn.LayerNorm(
171
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
172
+ )
173
+
174
+ self.attn1 = Attention(
175
+ query_dim=dim,
176
+ heads=num_attention_heads,
177
+ dim_head=attention_head_dim,
178
+ dropout=dropout,
179
+ bias=attention_bias,
180
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
181
+ upcast_attention=upcast_attention,
182
+ )
183
+
184
+ # 2. Cross-Attn
185
+ if cross_attention_dim is not None or double_self_attention:
186
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
187
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
188
+ # the second cross attention block.
189
+ self.norm2 = (
190
+ AdaLayerNorm(dim, num_embeds_ada_norm)
191
+ if self.use_ada_layer_norm
192
+ else nn.LayerNorm(
193
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
194
+ )
195
+ )
196
+ self.attn2 = Attention(
197
+ query_dim=dim,
198
+ cross_attention_dim=(
199
+ cross_attention_dim if not double_self_attention else None
200
+ ),
201
+ heads=num_attention_heads,
202
+ dim_head=attention_head_dim,
203
+ dropout=dropout,
204
+ bias=attention_bias,
205
+ upcast_attention=upcast_attention,
206
+ ) # is self-attn if encoder_hidden_states is none
207
+ else:
208
+ self.norm2 = None
209
+ self.attn2 = None
210
+
211
+ # 3. Feed-forward
212
+ if not self.use_ada_layer_norm_single:
213
+ self.norm3 = nn.LayerNorm(
214
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
215
+ )
216
+
217
+ self.ff = FeedForward(
218
+ dim,
219
+ dropout=dropout,
220
+ activation_fn=activation_fn,
221
+ final_dropout=final_dropout,
222
+ )
223
+
224
+ # 4. Fuser
225
+ if attention_type in {"gated", "gated-text-image"}: # Updated line
226
+ self.fuser = GatedSelfAttentionDense(
227
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
228
+ )
229
+
230
+ # 5. Scale-shift for PixArt-Alpha.
231
+ if self.use_ada_layer_norm_single:
232
+ self.scale_shift_table = nn.Parameter(
233
+ torch.randn(6, dim) / dim**0.5)
234
+
235
+ # let chunk size default to None
236
+ self._chunk_size = None
237
+ self._chunk_dim = 0
238
+
239
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
240
+ """
241
+ Sets the chunk size for feed-forward processing in the transformer block.
242
+
243
+ Args:
244
+ chunk_size (Optional[int]): The size of the chunks to process in feed-forward layers.
245
+ If None, the chunk size is set to the maximum possible value.
246
+ dim (int, optional): The dimension along which to split the input tensor into chunks. Defaults to 0.
247
+
248
+ Returns:
249
+ None.
250
+ """
251
+ self._chunk_size = chunk_size
252
+ self._chunk_dim = dim
253
+
254
+ def forward(
255
+ self,
256
+ hidden_states: torch.FloatTensor,
257
+ attention_mask: Optional[torch.FloatTensor] = None,
258
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
259
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
260
+ timestep: Optional[torch.LongTensor] = None,
261
+ cross_attention_kwargs: Dict[str, Any] = None,
262
+ class_labels: Optional[torch.LongTensor] = None,
263
+ ) -> torch.FloatTensor:
264
+ """
265
+ This function defines the forward pass of the BasicTransformerBlock.
266
+
267
+ Args:
268
+ self (BasicTransformerBlock):
269
+ An instance of the BasicTransformerBlock class.
270
+ hidden_states (torch.FloatTensor):
271
+ A tensor containing the hidden states.
272
+ attention_mask (Optional[torch.FloatTensor], optional):
273
+ A tensor containing the attention mask. Defaults to None.
274
+ encoder_hidden_states (Optional[torch.FloatTensor], optional):
275
+ A tensor containing the encoder hidden states. Defaults to None.
276
+ encoder_attention_mask (Optional[torch.FloatTensor], optional):
277
+ A tensor containing the encoder attention mask. Defaults to None.
278
+ timestep (Optional[torch.LongTensor], optional):
279
+ A tensor containing the timesteps. Defaults to None.
280
+ cross_attention_kwargs (Dict[str, Any], optional):
281
+ Additional cross-attention arguments. Defaults to None.
282
+ class_labels (Optional[torch.LongTensor], optional):
283
+ A tensor containing the class labels. Defaults to None.
284
+
285
+ Returns:
286
+ torch.FloatTensor:
287
+ A tensor containing the transformed hidden states.
288
+ """
289
+ # Notice that normalization is always applied before the real computation in the following blocks.
290
+ # 0. Self-Attention
291
+ batch_size = hidden_states.shape[0]
292
+
293
+ gate_msa = None
294
+ scale_mlp = None
295
+ shift_mlp = None
296
+ gate_mlp = None
297
+ if self.use_ada_layer_norm:
298
+ norm_hidden_states = self.norm1(hidden_states, timestep)
299
+ elif self.use_ada_layer_norm_zero:
300
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
301
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
302
+ )
303
+ elif self.use_layer_norm:
304
+ norm_hidden_states = self.norm1(hidden_states)
305
+ elif self.use_ada_layer_norm_single:
306
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
307
+ self.scale_shift_table[None] +
308
+ timestep.reshape(batch_size, 6, -1)
309
+ ).chunk(6, dim=1)
310
+ norm_hidden_states = self.norm1(hidden_states)
311
+ norm_hidden_states = norm_hidden_states * \
312
+ (1 + scale_msa) + shift_msa
313
+ norm_hidden_states = norm_hidden_states.squeeze(1)
314
+ else:
315
+ raise ValueError("Incorrect norm used")
316
+
317
+ if self.pos_embed is not None:
318
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
319
+
320
+ # 1. Retrieve lora scale.
321
+ lora_scale = (
322
+ cross_attention_kwargs.get("scale", 1.0)
323
+ if cross_attention_kwargs is not None
324
+ else 1.0
325
+ )
326
+
327
+ # 2. Prepare GLIGEN inputs
328
+ cross_attention_kwargs = (
329
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
330
+ )
331
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
332
+
333
+ attn_output = self.attn1(
334
+ norm_hidden_states,
335
+ encoder_hidden_states=(
336
+ encoder_hidden_states if self.only_cross_attention else None
337
+ ),
338
+ attention_mask=attention_mask,
339
+ **cross_attention_kwargs,
340
+ )
341
+ if self.use_ada_layer_norm_zero:
342
+ attn_output = gate_msa.unsqueeze(1) * attn_output
343
+ elif self.use_ada_layer_norm_single:
344
+ attn_output = gate_msa * attn_output
345
+
346
+ hidden_states = attn_output + hidden_states
347
+ if hidden_states.ndim == 4:
348
+ hidden_states = hidden_states.squeeze(1)
349
+
350
+ # 2.5 GLIGEN Control
351
+ if gligen_kwargs is not None:
352
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
353
+
354
+ # 3. Cross-Attention
355
+ if self.attn2 is not None:
356
+ if self.use_ada_layer_norm:
357
+ norm_hidden_states = self.norm2(hidden_states, timestep)
358
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
359
+ norm_hidden_states = self.norm2(hidden_states)
360
+ elif self.use_ada_layer_norm_single:
361
+ # For PixArt norm2 isn't applied here:
362
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
363
+ norm_hidden_states = hidden_states
364
+ else:
365
+ raise ValueError("Incorrect norm")
366
+
367
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
368
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
369
+
370
+ attn_output = self.attn2(
371
+ norm_hidden_states,
372
+ encoder_hidden_states=encoder_hidden_states,
373
+ attention_mask=encoder_attention_mask,
374
+ **cross_attention_kwargs,
375
+ )
376
+ hidden_states = attn_output + hidden_states
377
+
378
+ # 4. Feed-forward
379
+ if not self.use_ada_layer_norm_single:
380
+ norm_hidden_states = self.norm3(hidden_states)
381
+
382
+ if self.use_ada_layer_norm_zero:
383
+ norm_hidden_states = (
384
+ norm_hidden_states *
385
+ (1 + scale_mlp[:, None]) + shift_mlp[:, None]
386
+ )
387
+
388
+ if self.use_ada_layer_norm_single:
389
+ norm_hidden_states = self.norm2(hidden_states)
390
+ norm_hidden_states = norm_hidden_states * \
391
+ (1 + scale_mlp) + shift_mlp
392
+
393
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
394
+
395
+ if self.use_ada_layer_norm_zero:
396
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
397
+ elif self.use_ada_layer_norm_single:
398
+ ff_output = gate_mlp * ff_output
399
+
400
+ hidden_states = ff_output + hidden_states
401
+ if hidden_states.ndim == 4:
402
+ hidden_states = hidden_states.squeeze(1)
403
+
404
+ return hidden_states
405
+
406
+
407
+ class TemporalBasicTransformerBlock(nn.Module):
408
+ """
409
+ A PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms.
410
+ This class is particularly useful for video-related tasks where capturing temporal information within the sequence of frames is necessary.
411
+
412
+ Attributes:
413
+ dim (int): The dimension of the input and output embeddings.
414
+ num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
415
+ attention_head_dim (int): The dimension of each attention head.
416
+ dropout (float): The dropout probability for the attention scores.
417
+ cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism.
418
+ activation_fn (str): The activation function used in the feed-forward layer.
419
+ num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization.
420
+ attention_bias (bool): If True, uses bias in the attention mechanism.
421
+ only_cross_attention (bool): If True, only uses cross-attention.
422
+ upcast_attention (bool): If True, upcasts the attention mechanism for better performance.
423
+ unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in the UNet model.
424
+ unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in the UNet model.
425
+ """
426
+ def __init__(
427
+ self,
428
+ dim: int,
429
+ num_attention_heads: int,
430
+ attention_head_dim: int,
431
+ dropout=0.0,
432
+ cross_attention_dim: Optional[int] = None,
433
+ activation_fn: str = "geglu",
434
+ num_embeds_ada_norm: Optional[int] = None,
435
+ attention_bias: bool = False,
436
+ only_cross_attention: bool = False,
437
+ upcast_attention: bool = False,
438
+ unet_use_cross_frame_attention=None,
439
+ unet_use_temporal_attention=None,
440
+ ):
441
+ """
442
+ The TemporalBasicTransformerBlock class is a PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms.
443
+ This is particularly useful for video-related tasks, where the model needs to capture the temporal information within the sequence of frames.
444
+ The block consists of self-attention, cross-attention, feed-forward, and temporal attention mechanisms.
445
+
446
+ dim (int): The dimension of the input and output embeddings.
447
+ num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
448
+ attention_head_dim (int): The dimension of each attention head.
449
+ dropout (float, optional): The dropout probability for the attention scores. Defaults to 0.0.
450
+ cross_attention_dim (int, optional): The dimension of the cross-attention mechanism. Defaults to None.
451
+ activation_fn (str, optional): The activation function used in the feed-forward layer. Defaults to "geglu".
452
+ num_embeds_ada_norm (int, optional): The number of embeddings for adaptive normalization. Defaults to None.
453
+ attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
454
+ only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False.
455
+ upcast_attention (bool, optional): If True, upcasts the attention mechanism for better performance. Defaults to False.
456
+ unet_use_cross_frame_attention (bool, optional): If True, uses cross-frame attention in the UNet model. Defaults to None.
457
+ unet_use_temporal_attention (bool, optional): If True, uses temporal attention in the UNet model. Defaults to None.
458
+
459
+ Forward method:
460
+ hidden_states (torch.FloatTensor): The input hidden states.
461
+ encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None.
462
+ timestep (torch.LongTensor, optional): The current timestep for the transformer model. Defaults to None.
463
+ attention_mask (torch.FloatTensor, optional): The attention mask for the self-attention mechanism. Defaults to None.
464
+ video_length (int, optional): The length of the video sequence. Defaults to None.
465
+
466
+ Returns:
467
+ torch.FloatTensor: The output hidden states after passing through the TemporalBasicTransformerBlock.
468
+ """
469
+ super().__init__()
470
+ self.only_cross_attention = only_cross_attention
471
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
472
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
473
+ self.unet_use_temporal_attention = unet_use_temporal_attention
474
+
475
+ # SC-Attn
476
+ self.attn1 = Attention(
477
+ query_dim=dim,
478
+ heads=num_attention_heads,
479
+ dim_head=attention_head_dim,
480
+ dropout=dropout,
481
+ bias=attention_bias,
482
+ upcast_attention=upcast_attention,
483
+ )
484
+ self.norm1 = (
485
+ AdaLayerNorm(dim, num_embeds_ada_norm)
486
+ if self.use_ada_layer_norm
487
+ else nn.LayerNorm(dim)
488
+ )
489
+
490
+ # Cross-Attn
491
+ if cross_attention_dim is not None:
492
+ self.attn2 = Attention(
493
+ query_dim=dim,
494
+ cross_attention_dim=cross_attention_dim,
495
+ heads=num_attention_heads,
496
+ dim_head=attention_head_dim,
497
+ dropout=dropout,
498
+ bias=attention_bias,
499
+ upcast_attention=upcast_attention,
500
+ )
501
+ else:
502
+ self.attn2 = None
503
+
504
+ if cross_attention_dim is not None:
505
+ self.norm2 = (
506
+ AdaLayerNorm(dim, num_embeds_ada_norm)
507
+ if self.use_ada_layer_norm
508
+ else nn.LayerNorm(dim)
509
+ )
510
+ else:
511
+ self.norm2 = None
512
+
513
+ # Feed-forward
514
+ self.ff = FeedForward(dim, dropout=dropout,
515
+ activation_fn=activation_fn)
516
+ self.norm3 = nn.LayerNorm(dim)
517
+ self.use_ada_layer_norm_zero = False
518
+
519
+ # Temp-Attn
520
+ # assert unet_use_temporal_attention is not None
521
+ if unet_use_temporal_attention is None:
522
+ unet_use_temporal_attention = False
523
+ if unet_use_temporal_attention:
524
+ self.attn_temp = Attention(
525
+ query_dim=dim,
526
+ heads=num_attention_heads,
527
+ dim_head=attention_head_dim,
528
+ dropout=dropout,
529
+ bias=attention_bias,
530
+ upcast_attention=upcast_attention,
531
+ )
532
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
533
+ self.norm_temp = (
534
+ AdaLayerNorm(dim, num_embeds_ada_norm)
535
+ if self.use_ada_layer_norm
536
+ else nn.LayerNorm(dim)
537
+ )
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ encoder_hidden_states=None,
543
+ timestep=None,
544
+ attention_mask=None,
545
+ video_length=None,
546
+ ):
547
+ """
548
+ Forward pass for the TemporalBasicTransformerBlock.
549
+
550
+ Args:
551
+ hidden_states (torch.FloatTensor): The input hidden states with shape (batch_size, seq_len, dim).
552
+ encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states with shape (batch_size, src_seq_len, dim).
553
+ timestep (torch.LongTensor, optional): The timestep for the transformer block.
554
+ attention_mask (torch.FloatTensor, optional): The attention mask with shape (batch_size, seq_len, seq_len).
555
+ video_length (int, optional): The length of the video sequence.
556
+
557
+ Returns:
558
+ torch.FloatTensor: The output tensor after passing through the transformer block with shape (batch_size, seq_len, dim).
559
+ """
560
+ norm_hidden_states = (
561
+ self.norm1(hidden_states, timestep)
562
+ if self.use_ada_layer_norm
563
+ else self.norm1(hidden_states)
564
+ )
565
+
566
+ if self.unet_use_cross_frame_attention:
567
+ hidden_states = (
568
+ self.attn1(
569
+ norm_hidden_states,
570
+ attention_mask=attention_mask,
571
+ video_length=video_length,
572
+ )
573
+ + hidden_states
574
+ )
575
+ else:
576
+ hidden_states = (
577
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
578
+ + hidden_states
579
+ )
580
+
581
+ if self.attn2 is not None:
582
+ # Cross-Attention
583
+ norm_hidden_states = (
584
+ self.norm2(hidden_states, timestep)
585
+ if self.use_ada_layer_norm
586
+ else self.norm2(hidden_states)
587
+ )
588
+ hidden_states = (
589
+ self.attn2(
590
+ norm_hidden_states,
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ attention_mask=attention_mask,
593
+ )
594
+ + hidden_states
595
+ )
596
+
597
+ # Feed-forward
598
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
599
+
600
+ # Temporal-Attention
601
+ if self.unet_use_temporal_attention:
602
+ d = hidden_states.shape[1]
603
+ hidden_states = rearrange(
604
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
605
+ )
606
+ norm_hidden_states = (
607
+ self.norm_temp(hidden_states, timestep)
608
+ if self.use_ada_layer_norm
609
+ else self.norm_temp(hidden_states)
610
+ )
611
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
612
+ hidden_states = rearrange(
613
+ hidden_states, "(b d) f c -> (b f) d c", d=d)
614
+
615
+ return hidden_states
616
+
617
+
618
+ class AudioTemporalBasicTransformerBlock(nn.Module):
619
+ """
620
+ A PyTorch module designed to handle audio data within a transformer framework, including temporal attention mechanisms.
621
+
622
+ Attributes:
623
+ dim (int): The dimension of the input and output embeddings.
624
+ num_attention_heads (int): The number of attention heads.
625
+ attention_head_dim (int): The dimension of each attention head.
626
+ dropout (float): The dropout probability.
627
+ cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism.
628
+ activation_fn (str): The activation function for the feed-forward network.
629
+ num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization.
630
+ attention_bias (bool): If True, uses bias in the attention mechanism.
631
+ only_cross_attention (bool): If True, only uses cross-attention.
632
+ upcast_attention (bool): If True, upcasts the attention mechanism to float32.
633
+ unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in UNet.
634
+ unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in UNet.
635
+ depth (int): The depth of the transformer block.
636
+ unet_block_name (Optional[str]): The name of the UNet block.
637
+ stack_enable_blocks_name (Optional[List[str]]): The list of enabled blocks in the stack.
638
+ stack_enable_blocks_depth (Optional[List[int]]): The list of depths for the enabled blocks in the stack.
639
+ """
640
+ def __init__(
641
+ self,
642
+ dim: int,
643
+ num_attention_heads: int,
644
+ attention_head_dim: int,
645
+ dropout=0.0,
646
+ cross_attention_dim: Optional[int] = None,
647
+ activation_fn: str = "geglu",
648
+ num_embeds_ada_norm: Optional[int] = None,
649
+ attention_bias: bool = False,
650
+ only_cross_attention: bool = False,
651
+ upcast_attention: bool = False,
652
+ unet_use_cross_frame_attention=None,
653
+ unet_use_temporal_attention=None,
654
+ depth=0,
655
+ unet_block_name=None,
656
+ stack_enable_blocks_name: Optional[List[str]] = None,
657
+ stack_enable_blocks_depth: Optional[List[int]] = None,
658
+ ):
659
+ """
660
+ Initializes the AudioTemporalBasicTransformerBlock module.
661
+
662
+ Args:
663
+ dim (int): The dimension of the input and output embeddings.
664
+ num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism.
665
+ attention_head_dim (int): The dimension of each attention head.
666
+ dropout (float, optional): The dropout probability for the attention mechanism. Defaults to 0.0.
667
+ cross_attention_dim (Optional[int], optional): The dimension of the cross-attention mechanism. Defaults to None.
668
+ activation_fn (str, optional): The activation function to be used in the feed-forward network. Defaults to "geglu".
669
+ num_embeds_ada_norm (Optional[int], optional): The number of embeddings for adaptive normalization. Defaults to None.
670
+ attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False.
671
+ only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False.
672
+ upcast_attention (bool, optional): If True, upcasts the attention mechanism to float32. Defaults to False.
673
+ unet_use_cross_frame_attention (Optional[bool], optional): If True, uses cross-frame attention in UNet. Defaults to None.
674
+ unet_use_temporal_attention (Optional[bool], optional): If True, uses temporal attention in UNet. Defaults to None.
675
+ depth (int, optional): The depth of the transformer block. Defaults to 0.
676
+ unet_block_name (Optional[str], optional): The name of the UNet block. Defaults to None.
677
+ stack_enable_blocks_name (Optional[List[str]], optional): The list of enabled blocks in the stack. Defaults to None.
678
+ stack_enable_blocks_depth (Optional[List[int]], optional): The list of depths for the enabled blocks in the stack. Defaults to None.
679
+ """
680
+ super().__init__()
681
+ self.only_cross_attention = only_cross_attention
682
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
683
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
684
+ self.unet_use_temporal_attention = unet_use_temporal_attention
685
+ self.unet_block_name = unet_block_name
686
+ self.depth = depth
687
+
688
+ zero_conv_full = nn.Conv2d(
689
+ dim, dim, kernel_size=1)
690
+ self.zero_conv_full = zero_module(zero_conv_full)
691
+
692
+ zero_conv_face = nn.Conv2d(
693
+ dim, dim, kernel_size=1)
694
+ self.zero_conv_face = zero_module(zero_conv_face)
695
+
696
+ zero_conv_lip = nn.Conv2d(
697
+ dim, dim, kernel_size=1)
698
+ self.zero_conv_lip = zero_module(zero_conv_lip)
699
+ # SC-Attn
700
+ self.attn1 = Attention(
701
+ query_dim=dim,
702
+ heads=num_attention_heads,
703
+ dim_head=attention_head_dim,
704
+ dropout=dropout,
705
+ bias=attention_bias,
706
+ upcast_attention=upcast_attention,
707
+ )
708
+ self.norm1 = (
709
+ AdaLayerNorm(dim, num_embeds_ada_norm)
710
+ if self.use_ada_layer_norm
711
+ else nn.LayerNorm(dim)
712
+ )
713
+
714
+ # Cross-Attn
715
+ if cross_attention_dim is not None:
716
+ if (stack_enable_blocks_name is not None and
717
+ stack_enable_blocks_depth is not None and
718
+ self.unet_block_name in stack_enable_blocks_name and
719
+ self.depth in stack_enable_blocks_depth):
720
+ self.attn2_0 = Attention(
721
+ query_dim=dim,
722
+ cross_attention_dim=cross_attention_dim,
723
+ heads=num_attention_heads,
724
+ dim_head=attention_head_dim,
725
+ dropout=dropout,
726
+ bias=attention_bias,
727
+ upcast_attention=upcast_attention,
728
+ )
729
+ self.attn2 = None
730
+
731
+ else:
732
+ self.attn2 = Attention(
733
+ query_dim=dim,
734
+ cross_attention_dim=cross_attention_dim,
735
+ heads=num_attention_heads,
736
+ dim_head=attention_head_dim,
737
+ dropout=dropout,
738
+ bias=attention_bias,
739
+ upcast_attention=upcast_attention,
740
+ )
741
+ self.attn2_0=None
742
+ else:
743
+ self.attn2 = None
744
+ self.attn2_0 = None
745
+
746
+ if cross_attention_dim is not None:
747
+ self.norm2 = (
748
+ AdaLayerNorm(dim, num_embeds_ada_norm)
749
+ if self.use_ada_layer_norm
750
+ else nn.LayerNorm(dim)
751
+ )
752
+ else:
753
+ self.norm2 = None
754
+
755
+ # Feed-forward
756
+ self.ff = FeedForward(dim, dropout=dropout,
757
+ activation_fn=activation_fn)
758
+ self.norm3 = nn.LayerNorm(dim)
759
+ self.use_ada_layer_norm_zero = False
760
+
761
+
762
+
763
+ def forward(
764
+ self,
765
+ hidden_states,
766
+ encoder_hidden_states=None,
767
+ timestep=None,
768
+ attention_mask=None,
769
+ full_mask=None,
770
+ face_mask=None,
771
+ lip_mask=None,
772
+ motion_scale=None,
773
+ video_length=None,
774
+ ):
775
+ """
776
+ Forward pass for the AudioTemporalBasicTransformerBlock.
777
+
778
+ Args:
779
+ hidden_states (torch.FloatTensor): The input hidden states.
780
+ encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None.
781
+ timestep (torch.LongTensor, optional): The timestep for the transformer block. Defaults to None.
782
+ attention_mask (torch.FloatTensor, optional): The attention mask. Defaults to None.
783
+ full_mask (torch.FloatTensor, optional): The full mask. Defaults to None.
784
+ face_mask (torch.FloatTensor, optional): The face mask. Defaults to None.
785
+ lip_mask (torch.FloatTensor, optional): The lip mask. Defaults to None.
786
+ video_length (int, optional): The length of the video. Defaults to None.
787
+
788
+ Returns:
789
+ torch.FloatTensor: The output tensor after passing through the AudioTemporalBasicTransformerBlock.
790
+ """
791
+ norm_hidden_states = (
792
+ self.norm1(hidden_states, timestep)
793
+ if self.use_ada_layer_norm
794
+ else self.norm1(hidden_states)
795
+ )
796
+
797
+ if self.unet_use_cross_frame_attention:
798
+ hidden_states = (
799
+ self.attn1(
800
+ norm_hidden_states,
801
+ attention_mask=attention_mask,
802
+ video_length=video_length,
803
+ )
804
+ + hidden_states
805
+ )
806
+ else:
807
+ hidden_states = (
808
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
809
+ + hidden_states
810
+ )
811
+
812
+ if self.attn2 is not None:
813
+ # Cross-Attention
814
+ norm_hidden_states = (
815
+ self.norm2(hidden_states, timestep)
816
+ if self.use_ada_layer_norm
817
+ else self.norm2(hidden_states)
818
+ )
819
+ hidden_states = self.attn2(
820
+ norm_hidden_states,
821
+ encoder_hidden_states=encoder_hidden_states,
822
+ attention_mask=attention_mask,
823
+ ) + hidden_states
824
+
825
+ elif self.attn2_0 is not None:
826
+ norm_hidden_states = (
827
+ self.norm2(hidden_states, timestep)
828
+ if self.use_ada_layer_norm
829
+ else self.norm2(hidden_states)
830
+ )
831
+
832
+ level = self.depth
833
+ all_hidden_states = self.attn2_0(
834
+ norm_hidden_states,
835
+ encoder_hidden_states=encoder_hidden_states,
836
+ attention_mask=attention_mask,
837
+ )
838
+
839
+ full_hidden_states = (
840
+ all_hidden_states * full_mask[level][:, :, None]
841
+ )
842
+ bz, sz, c = full_hidden_states.shape
843
+ sz_sqrt = int(sz ** 0.5)
844
+ full_hidden_states = full_hidden_states.reshape(
845
+ bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
846
+ full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c)
847
+
848
+ face_hidden_state = (
849
+ all_hidden_states * face_mask[level][:, :, None]
850
+ )
851
+ face_hidden_state = face_hidden_state.reshape(
852
+ bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
853
+ face_hidden_state = self.zero_conv_face(
854
+ face_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
855
+
856
+ lip_hidden_state = (
857
+ all_hidden_states * lip_mask[level][:, :, None]
858
+ ) # [32, 4096, 320]
859
+ lip_hidden_state = lip_hidden_state.reshape(
860
+ bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2)
861
+ lip_hidden_state = self.zero_conv_lip(
862
+ lip_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c)
863
+
864
+ if motion_scale is not None:
865
+ hidden_states = (
866
+ motion_scale[0] * full_hidden_states +
867
+ motion_scale[1] * face_hidden_state +
868
+ motion_scale[2] * lip_hidden_state + hidden_states
869
+ )
870
+ else:
871
+ hidden_states = (
872
+ full_hidden_states +
873
+ face_hidden_state +
874
+ lip_hidden_state + hidden_states
875
+ )
876
+ # Feed-forward
877
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
878
+
879
+ return hidden_states
880
+
881
+ def zero_module(module):
882
+ """
883
+ Zeroes out the parameters of a given module.
884
+
885
+ Args:
886
+ module (nn.Module): The module whose parameters need to be zeroed out.
887
+
888
+ Returns:
889
+ None.
890
+ """
891
+ for p in module.parameters():
892
+ nn.init.zeros_(p)
893
+ return module
joyhallo/models/audio_proj.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides the implementation of an Audio Projection Model, which is designed for
3
+ audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4
+ that can be used for various downstream applications, such as audio analysis or synthesis.
5
+
6
+ The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7
+ provides a foundation for building custom models. This implementation includes multiple linear
8
+ layers with ReLU activation functions and a LayerNorm for normalization.
9
+
10
+ Key Features:
11
+ - Audio embedding input with flexible sequence length and block structure.
12
+ - Multiple linear layers for feature transformation.
13
+ - ReLU activation for non-linear transformation.
14
+ - LayerNorm for stabilizing and speeding up training.
15
+ - Rearrangement of input embeddings to match the model's expected input shape.
16
+ - Customizable number of blocks, channels, and context tokens for adaptability.
17
+
18
+ The module is structured to be easily integrated into larger systems or used as a standalone
19
+ component for audio feature extraction and processing.
20
+
21
+ Classes:
22
+ - AudioProjModel: A class representing the audio projection model with configurable parameters.
23
+
24
+ Functions:
25
+ - (none)
26
+
27
+ Dependencies:
28
+ - torch: For tensor operations and neural network components.
29
+ - diffusers: For the ModelMixin base class.
30
+ - einops: For tensor rearrangement operations.
31
+
32
+ """
33
+
34
+ import torch
35
+ from diffusers import ModelMixin
36
+ from einops import rearrange
37
+ from torch import nn
38
+
39
+
40
+ class AudioProjModel(ModelMixin):
41
+ """Audio Projection Model
42
+
43
+ This class defines an audio projection model that takes audio embeddings as input
44
+ and produces context tokens as output. The model is based on the ModelMixin class
45
+ and consists of multiple linear layers and activation functions. It can be used
46
+ for various audio processing tasks.
47
+
48
+ Attributes:
49
+ seq_len (int): The length of the audio sequence.
50
+ blocks (int): The number of blocks in the audio projection model.
51
+ channels (int): The number of channels in the audio projection model.
52
+ intermediate_dim (int): The intermediate dimension of the model.
53
+ context_tokens (int): The number of context tokens in the output.
54
+ output_dim (int): The output dimension of the context tokens.
55
+
56
+ Methods:
57
+ __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
58
+ Initializes the AudioProjModel with the given parameters.
59
+ forward(self, audio_embeds):
60
+ Defines the forward pass for the AudioProjModel.
61
+ Parameters:
62
+ audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
63
+ Returns:
64
+ context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
65
+
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ seq_len=5,
71
+ blocks=12, # add a new parameter blocks
72
+ channels=768, # add a new parameter channels
73
+ intermediate_dim=512,
74
+ output_dim=768,
75
+ context_tokens=32,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.seq_len = seq_len
80
+ self.blocks = blocks
81
+ self.channels = channels
82
+ self.input_dim = (
83
+ seq_len * blocks * channels
84
+ ) # update input_dim to be the product of blocks and channels.
85
+ self.intermediate_dim = intermediate_dim
86
+ self.context_tokens = context_tokens
87
+ self.output_dim = output_dim
88
+
89
+ # define multiple linear layers
90
+ self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
91
+ self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
92
+ self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
93
+
94
+ self.norm = nn.LayerNorm(output_dim)
95
+
96
+ def forward(self, audio_embeds):
97
+ """
98
+ Defines the forward pass for the AudioProjModel.
99
+
100
+ Parameters:
101
+ audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
102
+
103
+ Returns:
104
+ context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
105
+ """
106
+ # merge
107
+ video_length = audio_embeds.shape[1]
108
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
109
+ batch_size, window_size, blocks, channels = audio_embeds.shape
110
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
111
+
112
+ audio_embeds = torch.relu(self.proj1(audio_embeds))
113
+ audio_embeds = torch.relu(self.proj2(audio_embeds))
114
+
115
+ context_tokens = self.proj3(audio_embeds).reshape(
116
+ batch_size, self.context_tokens, self.output_dim
117
+ )
118
+
119
+ context_tokens = self.norm(context_tokens)
120
+ context_tokens = rearrange(
121
+ context_tokens, "(bz f) m c -> bz f m c", f=video_length
122
+ )
123
+
124
+ return context_tokens
joyhallo/models/face_locator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the FaceLocator class, which is a neural network model designed to
3
+ locate and extract facial features from input images or tensors. It uses a series of
4
+ convolutional layers to progressively downsample and refine the facial feature map.
5
+
6
+ The FaceLocator class is part of a larger system that may involve facial recognition or
7
+ similar tasks where precise location and extraction of facial features are required.
8
+
9
+ Attributes:
10
+ conditioning_embedding_channels (int): The number of channels in the output embedding.
11
+ conditioning_channels (int): The number of input channels for the conditioning tensor.
12
+ block_out_channels (Tuple[int]): A tuple of integers representing the output channels
13
+ for each block in the model.
14
+
15
+ The model uses the following components:
16
+ - InflatedConv3d: A convolutional layer that inflates the input to increase the depth.
17
+ - zero_module: A utility function that may set certain parameters to zero for regularization
18
+ or other purposes.
19
+
20
+ The forward method of the FaceLocator class takes a conditioning tensor as input and
21
+ produces an embedding tensor as output, which can be used for further processing or analysis.
22
+ """
23
+
24
+ from typing import Tuple
25
+
26
+ import torch.nn.functional as F
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from torch import nn
29
+
30
+ from .motion_module import zero_module
31
+ from .resnet import InflatedConv3d
32
+
33
+
34
+ class FaceLocator(ModelMixin):
35
+ """
36
+ The FaceLocator class is a neural network model designed to process and extract facial
37
+ features from an input tensor. It consists of a series of convolutional layers that
38
+ progressively downsample the input while increasing the depth of the feature map.
39
+
40
+ The model is built using InflatedConv3d layers, which are designed to inflate the
41
+ feature channels, allowing for more complex feature extraction. The final output is a
42
+ conditioning embedding that can be used for various tasks such as facial recognition or
43
+ feature-based image manipulation.
44
+
45
+ Parameters:
46
+ conditioning_embedding_channels (int): The number of channels in the output embedding.
47
+ conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3.
48
+ block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels
49
+ for each block in the model. The default is (16, 32, 64, 128), which defines the
50
+ progression of the network's depth.
51
+
52
+ Attributes:
53
+ conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process.
54
+ blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model.
55
+ conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding.
56
+
57
+ The forward method applies the convolutional layers to the input conditioning tensor and
58
+ returns the resulting embedding tensor.
59
+ """
60
+ def __init__(
61
+ self,
62
+ conditioning_embedding_channels: int,
63
+ conditioning_channels: int = 3,
64
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
65
+ ):
66
+ super().__init__()
67
+ self.conv_in = InflatedConv3d(
68
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
69
+ )
70
+
71
+ self.blocks = nn.ModuleList([])
72
+
73
+ for i in range(len(block_out_channels) - 1):
74
+ channel_in = block_out_channels[i]
75
+ channel_out = block_out_channels[i + 1]
76
+ self.blocks.append(
77
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
78
+ )
79
+ self.blocks.append(
80
+ InflatedConv3d(
81
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
82
+ )
83
+ )
84
+
85
+ self.conv_out = zero_module(
86
+ InflatedConv3d(
87
+ block_out_channels[-1],
88
+ conditioning_embedding_channels,
89
+ kernel_size=3,
90
+ padding=1,
91
+ )
92
+ )
93
+
94
+ def forward(self, conditioning):
95
+ """
96
+ Forward pass of the FaceLocator model.
97
+
98
+ Args:
99
+ conditioning (Tensor): The input conditioning tensor.
100
+
101
+ Returns:
102
+ Tensor: The output embedding tensor.
103
+ """
104
+ embedding = self.conv_in(conditioning)
105
+ embedding = F.silu(embedding)
106
+
107
+ for block in self.blocks:
108
+ embedding = block(embedding)
109
+ embedding = F.silu(embedding)
110
+
111
+ embedding = self.conv_out(embedding)
112
+
113
+ return embedding
joyhallo/models/image_proj.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ image_proj_model.py
3
+
4
+ This module defines the ImageProjModel class, which is responsible for
5
+ projecting image embeddings into a different dimensional space. The model
6
+ leverages a linear transformation followed by a layer normalization to
7
+ reshape and normalize the input image embeddings for further processing in
8
+ cross-attention mechanisms or other downstream tasks.
9
+
10
+ Classes:
11
+ ImageProjModel
12
+
13
+ Dependencies:
14
+ torch
15
+ diffusers.ModelMixin
16
+
17
+ """
18
+
19
+ import torch
20
+ from diffusers import ModelMixin
21
+
22
+
23
+ class ImageProjModel(ModelMixin):
24
+ """
25
+ ImageProjModel is a class that projects image embeddings into a different
26
+ dimensional space. It inherits from ModelMixin, providing additional functionalities
27
+ specific to image projection.
28
+
29
+ Attributes:
30
+ cross_attention_dim (int): The dimension of the cross attention.
31
+ clip_embeddings_dim (int): The dimension of the CLIP embeddings.
32
+ clip_extra_context_tokens (int): The number of extra context tokens in CLIP.
33
+
34
+ Methods:
35
+ forward(image_embeds): Forward pass of the ImageProjModel, which takes in image
36
+ embeddings and returns the projected tokens.
37
+
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ cross_attention_dim=1024,
43
+ clip_embeddings_dim=1024,
44
+ clip_extra_context_tokens=4,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.generator = None
49
+ self.cross_attention_dim = cross_attention_dim
50
+ self.clip_extra_context_tokens = clip_extra_context_tokens
51
+ self.proj = torch.nn.Linear(
52
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
53
+ )
54
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
55
+
56
+ def forward(self, image_embeds):
57
+ """
58
+ Forward pass of the ImageProjModel, which takes in image embeddings and returns the
59
+ projected tokens after reshaping and normalization.
60
+
61
+ Args:
62
+ image_embeds (torch.Tensor): The input image embeddings, with shape
63
+ batch_size x num_image_tokens x clip_embeddings_dim.
64
+
65
+ Returns:
66
+ clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping
67
+ and normalization, with shape batch_size x (clip_extra_context_tokens *
68
+ cross_attention_dim).
69
+
70
+ """
71
+ embeds = image_embeds
72
+ clip_extra_context_tokens = self.proj(embeds).reshape(
73
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
74
+ )
75
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
76
+ return clip_extra_context_tokens
joyhallo/models/motion_module.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ temporal_transformers.py
3
+
4
+ This module provides classes and functions for implementing Temporal Transformers
5
+ in PyTorch, designed for handling video data and temporal sequences within transformer-based models.
6
+
7
+ Functions:
8
+ zero_module(module)
9
+ Zero out the parameters of a module and return it.
10
+
11
+ Classes:
12
+ TemporalTransformer3DModelOutput(BaseOutput)
13
+ Dataclass for storing the output of TemporalTransformer3DModel.
14
+
15
+ VanillaTemporalModule(nn.Module)
16
+ A Vanilla Temporal Module class for handling temporal data.
17
+
18
+ TemporalTransformer3DModel(nn.Module)
19
+ A Temporal Transformer 3D Model class for transforming temporal data.
20
+
21
+ TemporalTransformerBlock(nn.Module)
22
+ A Temporal Transformer Block class for building the transformer architecture.
23
+
24
+ PositionalEncoding(nn.Module)
25
+ A Positional Encoding module for transformers to encode positional information.
26
+
27
+ Dependencies:
28
+ math
29
+ dataclasses.dataclass
30
+ typing (Callable, Optional)
31
+ torch
32
+ diffusers (FeedForward, Attention, AttnProcessor)
33
+ diffusers.utils (BaseOutput)
34
+ diffusers.utils.import_utils (is_xformers_available)
35
+ einops (rearrange, repeat)
36
+ torch.nn
37
+ xformers
38
+ xformers.ops
39
+
40
+ Example Usage:
41
+ >>> motion_module = get_motion_module(in_channels=512, motion_module_type="Vanilla", motion_module_kwargs={})
42
+ >>> output = motion_module(input_tensor, temb, encoder_hidden_states)
43
+
44
+ This module is designed to facilitate the creation, training, and inference of transformer models
45
+ that operate on temporal data, such as videos or time-series. It includes mechanisms for applying temporal attention,
46
+ managing positional encoding, and integrating with external libraries for efficient attention operations.
47
+ """
48
+
49
+ # This code is copied from https://github.com/guoyww/AnimateDiff.
50
+
51
+ import math
52
+
53
+ import torch
54
+ import xformers
55
+ import xformers.ops
56
+ from diffusers.models.attention import FeedForward
57
+ from diffusers.models.attention_processor import Attention, AttnProcessor
58
+ from diffusers.utils import BaseOutput
59
+ from diffusers.utils.import_utils import is_xformers_available
60
+ from einops import rearrange, repeat
61
+ from torch import nn
62
+
63
+
64
+ def zero_module(module):
65
+ """
66
+ Zero out the parameters of a module and return it.
67
+
68
+ Args:
69
+ - module: A PyTorch module to zero out its parameters.
70
+
71
+ Returns:
72
+ A zeroed out PyTorch module.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ class TemporalTransformer3DModelOutput(BaseOutput):
80
+ """
81
+ Output class for the TemporalTransformer3DModel.
82
+
83
+ Attributes:
84
+ sample (torch.FloatTensor): The output sample tensor from the model.
85
+ """
86
+ sample: torch.FloatTensor
87
+
88
+ def get_sample_shape(self):
89
+ """
90
+ Returns the shape of the sample tensor.
91
+
92
+ Returns:
93
+ Tuple: The shape of the sample tensor.
94
+ """
95
+ return self.sample.shape
96
+
97
+
98
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
99
+ """
100
+ This function returns a motion module based on the given type and parameters.
101
+
102
+ Args:
103
+ - in_channels (int): The number of input channels for the motion module.
104
+ - motion_module_type (str): The type of motion module to create. Currently, only "Vanilla" is supported.
105
+ - motion_module_kwargs (dict): Additional keyword arguments to pass to the motion module constructor.
106
+
107
+ Returns:
108
+ VanillaTemporalModule: The created motion module.
109
+
110
+ Raises:
111
+ ValueError: If an unsupported motion_module_type is provided.
112
+ """
113
+ if motion_module_type == "Vanilla":
114
+ return VanillaTemporalModule(
115
+ in_channels=in_channels,
116
+ **motion_module_kwargs,
117
+ )
118
+
119
+ raise ValueError
120
+
121
+
122
+ class VanillaTemporalModule(nn.Module):
123
+ """
124
+ A Vanilla Temporal Module class.
125
+
126
+ Args:
127
+ - in_channels (int): The number of input channels for the motion module.
128
+ - num_attention_heads (int): Number of attention heads.
129
+ - num_transformer_block (int): Number of transformer blocks.
130
+ - attention_block_types (tuple): Types of attention blocks.
131
+ - cross_frame_attention_mode: Mode for cross-frame attention.
132
+ - temporal_position_encoding (bool): Flag for temporal position encoding.
133
+ - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
134
+ - temporal_attention_dim_div (int): Divisor for temporal attention dimension.
135
+ - zero_initialize (bool): Flag for zero initialization.
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ in_channels,
141
+ num_attention_heads=8,
142
+ num_transformer_block=2,
143
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
144
+ cross_frame_attention_mode=None,
145
+ temporal_position_encoding=False,
146
+ temporal_position_encoding_max_len=24,
147
+ temporal_attention_dim_div=1,
148
+ zero_initialize=True,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.temporal_transformer = TemporalTransformer3DModel(
153
+ in_channels=in_channels,
154
+ num_attention_heads=num_attention_heads,
155
+ attention_head_dim=in_channels
156
+ // num_attention_heads
157
+ // temporal_attention_dim_div,
158
+ num_layers=num_transformer_block,
159
+ attention_block_types=attention_block_types,
160
+ cross_frame_attention_mode=cross_frame_attention_mode,
161
+ temporal_position_encoding=temporal_position_encoding,
162
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
163
+ )
164
+
165
+ if zero_initialize:
166
+ self.temporal_transformer.proj_out = zero_module(
167
+ self.temporal_transformer.proj_out
168
+ )
169
+
170
+ def forward(
171
+ self,
172
+ input_tensor,
173
+ encoder_hidden_states,
174
+ attention_mask=None,
175
+ ):
176
+ """
177
+ Forward pass of the TemporalTransformer3DModel.
178
+
179
+ Args:
180
+ hidden_states (torch.Tensor): The hidden states of the model.
181
+ encoder_hidden_states (torch.Tensor, optional): The hidden states of the encoder.
182
+ attention_mask (torch.Tensor, optional): The attention mask.
183
+
184
+ Returns:
185
+ torch.Tensor: The output tensor after the forward pass.
186
+ """
187
+ hidden_states = input_tensor
188
+ hidden_states = self.temporal_transformer(
189
+ hidden_states, encoder_hidden_states
190
+ )
191
+
192
+ output = hidden_states
193
+ return output
194
+
195
+
196
+ class TemporalTransformer3DModel(nn.Module):
197
+ """
198
+ A Temporal Transformer 3D Model class.
199
+
200
+ Args:
201
+ - in_channels (int): The number of input channels.
202
+ - num_attention_heads (int): Number of attention heads.
203
+ - attention_head_dim (int): Dimension of attention heads.
204
+ - num_layers (int): Number of transformer layers.
205
+ - attention_block_types (tuple): Types of attention blocks.
206
+ - dropout (float): Dropout rate.
207
+ - norm_num_groups (int): Number of groups for normalization.
208
+ - cross_attention_dim (int): Dimension for cross-attention.
209
+ - activation_fn (str): Activation function.
210
+ - attention_bias (bool): Flag for attention bias.
211
+ - upcast_attention (bool): Flag for upcast attention.
212
+ - cross_frame_attention_mode: Mode for cross-frame attention.
213
+ - temporal_position_encoding (bool): Flag for temporal position encoding.
214
+ - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
215
+ """
216
+ def __init__(
217
+ self,
218
+ in_channels,
219
+ num_attention_heads,
220
+ attention_head_dim,
221
+ num_layers,
222
+ attention_block_types=(
223
+ "Temporal_Self",
224
+ "Temporal_Self",
225
+ ),
226
+ dropout=0.0,
227
+ norm_num_groups=32,
228
+ cross_attention_dim=768,
229
+ activation_fn="geglu",
230
+ attention_bias=False,
231
+ upcast_attention=False,
232
+ cross_frame_attention_mode=None,
233
+ temporal_position_encoding=False,
234
+ temporal_position_encoding_max_len=24,
235
+ ):
236
+ super().__init__()
237
+
238
+ inner_dim = num_attention_heads * attention_head_dim
239
+
240
+ self.norm = torch.nn.GroupNorm(
241
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
242
+ )
243
+ self.proj_in = nn.Linear(in_channels, inner_dim)
244
+
245
+ self.transformer_blocks = nn.ModuleList(
246
+ [
247
+ TemporalTransformerBlock(
248
+ dim=inner_dim,
249
+ num_attention_heads=num_attention_heads,
250
+ attention_head_dim=attention_head_dim,
251
+ attention_block_types=attention_block_types,
252
+ dropout=dropout,
253
+ cross_attention_dim=cross_attention_dim,
254
+ activation_fn=activation_fn,
255
+ attention_bias=attention_bias,
256
+ upcast_attention=upcast_attention,
257
+ cross_frame_attention_mode=cross_frame_attention_mode,
258
+ temporal_position_encoding=temporal_position_encoding,
259
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
260
+ )
261
+ for d in range(num_layers)
262
+ ]
263
+ )
264
+ self.proj_out = nn.Linear(inner_dim, in_channels)
265
+
266
+ def forward(self, hidden_states, encoder_hidden_states=None):
267
+ """
268
+ Forward pass for the TemporalTransformer3DModel.
269
+
270
+ Args:
271
+ hidden_states (torch.Tensor): The input hidden states with shape (batch_size, sequence_length, in_channels).
272
+ encoder_hidden_states (torch.Tensor, optional): The encoder hidden states with shape (batch_size, encoder_sequence_length, in_channels).
273
+
274
+ Returns:
275
+ torch.Tensor: The output hidden states with shape (batch_size, sequence_length, in_channels).
276
+ """
277
+ assert (
278
+ hidden_states.dim() == 5
279
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
280
+ video_length = hidden_states.shape[2]
281
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
282
+
283
+ batch, _, height, weight = hidden_states.shape
284
+ residual = hidden_states
285
+
286
+ hidden_states = self.norm(hidden_states)
287
+ inner_dim = hidden_states.shape[1]
288
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
289
+ batch, height * weight, inner_dim
290
+ )
291
+ hidden_states = self.proj_in(hidden_states)
292
+
293
+ # Transformer Blocks
294
+ for block in self.transformer_blocks:
295
+ hidden_states = block(
296
+ hidden_states,
297
+ encoder_hidden_states=encoder_hidden_states,
298
+ video_length=video_length,
299
+ )
300
+
301
+ # output
302
+ hidden_states = self.proj_out(hidden_states)
303
+ hidden_states = (
304
+ hidden_states.reshape(batch, height, weight, inner_dim)
305
+ .permute(0, 3, 1, 2)
306
+ .contiguous()
307
+ )
308
+
309
+ output = hidden_states + residual
310
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
311
+
312
+ return output
313
+
314
+
315
+ class TemporalTransformerBlock(nn.Module):
316
+ """
317
+ A Temporal Transformer Block class.
318
+
319
+ Args:
320
+ - dim (int): Dimension of the block.
321
+ - num_attention_heads (int): Number of attention heads.
322
+ - attention_head_dim (int): Dimension of attention heads.
323
+ - attention_block_types (tuple): Types of attention blocks.
324
+ - dropout (float): Dropout rate.
325
+ - cross_attention_dim (int): Dimension for cross-attention.
326
+ - activation_fn (str): Activation function.
327
+ - attention_bias (bool): Flag for attention bias.
328
+ - upcast_attention (bool): Flag for upcast attention.
329
+ - cross_frame_attention_mode: Mode for cross-frame attention.
330
+ - temporal_position_encoding (bool): Flag for temporal position encoding.
331
+ - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
332
+ """
333
+ def __init__(
334
+ self,
335
+ dim,
336
+ num_attention_heads,
337
+ attention_head_dim,
338
+ attention_block_types=(
339
+ "Temporal_Self",
340
+ "Temporal_Self",
341
+ ),
342
+ dropout=0.0,
343
+ cross_attention_dim=768,
344
+ activation_fn="geglu",
345
+ attention_bias=False,
346
+ upcast_attention=False,
347
+ cross_frame_attention_mode=None,
348
+ temporal_position_encoding=False,
349
+ temporal_position_encoding_max_len=24,
350
+ ):
351
+ super().__init__()
352
+
353
+ attention_blocks = []
354
+ norms = []
355
+
356
+ for block_name in attention_block_types:
357
+ attention_blocks.append(
358
+ VersatileAttention(
359
+ attention_mode=block_name.split("_", maxsplit=1)[0],
360
+ cross_attention_dim=cross_attention_dim
361
+ if block_name.endswith("_Cross")
362
+ else None,
363
+ query_dim=dim,
364
+ heads=num_attention_heads,
365
+ dim_head=attention_head_dim,
366
+ dropout=dropout,
367
+ bias=attention_bias,
368
+ upcast_attention=upcast_attention,
369
+ cross_frame_attention_mode=cross_frame_attention_mode,
370
+ temporal_position_encoding=temporal_position_encoding,
371
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
372
+ )
373
+ )
374
+ norms.append(nn.LayerNorm(dim))
375
+
376
+ self.attention_blocks = nn.ModuleList(attention_blocks)
377
+ self.norms = nn.ModuleList(norms)
378
+
379
+ self.ff = FeedForward(dim, dropout=dropout,
380
+ activation_fn=activation_fn)
381
+ self.ff_norm = nn.LayerNorm(dim)
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states,
386
+ encoder_hidden_states=None,
387
+ video_length=None,
388
+ ):
389
+ """
390
+ Forward pass for the TemporalTransformerBlock.
391
+
392
+ Args:
393
+ hidden_states (torch.Tensor): The input hidden states with shape
394
+ (batch_size, video_length, in_channels).
395
+ encoder_hidden_states (torch.Tensor, optional): The encoder hidden states
396
+ with shape (batch_size, encoder_length, in_channels).
397
+ video_length (int, optional): The length of the video.
398
+
399
+ Returns:
400
+ torch.Tensor: The output hidden states with shape
401
+ (batch_size, video_length, in_channels).
402
+ """
403
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
404
+ norm_hidden_states = norm(hidden_states)
405
+ hidden_states = (
406
+ attention_block(
407
+ norm_hidden_states,
408
+ encoder_hidden_states=encoder_hidden_states
409
+ if attention_block.is_cross_attention
410
+ else None,
411
+ video_length=video_length,
412
+ )
413
+ + hidden_states
414
+ )
415
+
416
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
417
+
418
+ output = hidden_states
419
+ return output
420
+
421
+
422
+ class PositionalEncoding(nn.Module):
423
+ """
424
+ Positional Encoding module for transformers.
425
+
426
+ Args:
427
+ - d_model (int): Model dimension.
428
+ - dropout (float): Dropout rate.
429
+ - max_len (int): Maximum length for positional encoding.
430
+ """
431
+ def __init__(self, d_model, dropout=0.0, max_len=24):
432
+ super().__init__()
433
+ self.dropout = nn.Dropout(p=dropout)
434
+ position = torch.arange(max_len).unsqueeze(1)
435
+ div_term = torch.exp(
436
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
437
+ )
438
+ pe = torch.zeros(1, max_len, d_model)
439
+ pe[0, :, 0::2] = torch.sin(position * div_term)
440
+ pe[0, :, 1::2] = torch.cos(position * div_term)
441
+ self.register_buffer("pe", pe)
442
+
443
+ def forward(self, x):
444
+ """
445
+ Forward pass of the PositionalEncoding module.
446
+
447
+ This method takes an input tensor `x` and adds the positional encoding to it. The positional encoding is
448
+ generated based on the input tensor's shape and is added to the input tensor element-wise.
449
+
450
+ Args:
451
+ x (torch.Tensor): The input tensor to be positionally encoded.
452
+
453
+ Returns:
454
+ torch.Tensor: The positionally encoded tensor.
455
+ """
456
+ x = x + self.pe[:, : x.size(1)]
457
+ return self.dropout(x)
458
+
459
+
460
+ class VersatileAttention(Attention):
461
+ """
462
+ Versatile Attention class.
463
+
464
+ Args:
465
+ - attention_mode: Attention mode.
466
+ - temporal_position_encoding (bool): Flag for temporal position encoding.
467
+ - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding.
468
+ """
469
+ def __init__(
470
+ self,
471
+ *args,
472
+ attention_mode=None,
473
+ cross_frame_attention_mode=None,
474
+ temporal_position_encoding=False,
475
+ temporal_position_encoding_max_len=24,
476
+ **kwargs,
477
+ ):
478
+ super().__init__(*args, **kwargs)
479
+ assert attention_mode == "Temporal"
480
+
481
+ self.attention_mode = attention_mode
482
+ self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
483
+
484
+ self.pos_encoder = (
485
+ PositionalEncoding(
486
+ kwargs["query_dim"],
487
+ dropout=0.0,
488
+ max_len=temporal_position_encoding_max_len,
489
+ )
490
+ if (temporal_position_encoding and attention_mode == "Temporal")
491
+ else None
492
+ )
493
+
494
+ def extra_repr(self):
495
+ """
496
+ Returns a string representation of the module with information about the attention mode and whether it is cross-attention.
497
+
498
+ Returns:
499
+ str: A string representation of the module.
500
+ """
501
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
502
+
503
+ def set_use_memory_efficient_attention_xformers(
504
+ self,
505
+ use_memory_efficient_attention_xformers: bool,
506
+ attention_op = None,
507
+ ):
508
+ """
509
+ Sets the use of memory-efficient attention xformers for the VersatileAttention class.
510
+
511
+ Args:
512
+ use_memory_efficient_attention_xformers (bool): A boolean flag indicating whether to use memory-efficient attention xformers or not.
513
+
514
+ Returns:
515
+ None
516
+
517
+ """
518
+ if use_memory_efficient_attention_xformers:
519
+ if not is_xformers_available():
520
+ raise ModuleNotFoundError(
521
+ (
522
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
523
+ " xformers"
524
+ ),
525
+ name="xformers",
526
+ )
527
+
528
+ if not torch.cuda.is_available():
529
+ raise ValueError(
530
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
531
+ " only available for GPU "
532
+ )
533
+
534
+ try:
535
+ # Make sure we can run the memory efficient attention
536
+ _ = xformers.ops.memory_efficient_attention(
537
+ torch.randn((1, 2, 40), device="cuda"),
538
+ torch.randn((1, 2, 40), device="cuda"),
539
+ torch.randn((1, 2, 40), device="cuda"),
540
+ )
541
+ except Exception as e:
542
+ raise e
543
+ processor = AttnProcessor()
544
+ else:
545
+ processor = AttnProcessor()
546
+
547
+ self.set_processor(processor)
548
+
549
+ def forward(
550
+ self,
551
+ hidden_states,
552
+ encoder_hidden_states=None,
553
+ attention_mask=None,
554
+ video_length=None,
555
+ **cross_attention_kwargs,
556
+ ):
557
+ """
558
+ Args:
559
+ hidden_states (`torch.Tensor`):
560
+ The hidden states to be passed through the model.
561
+ encoder_hidden_states (`torch.Tensor`, optional):
562
+ The encoder hidden states to be passed through the model.
563
+ attention_mask (`torch.Tensor`, optional):
564
+ The attention mask to be used in the model.
565
+ video_length (`int`, optional):
566
+ The length of the video.
567
+ cross_attention_kwargs (`dict`, optional):
568
+ Additional keyword arguments to be used for cross-attention.
569
+
570
+ Returns:
571
+ `torch.Tensor`:
572
+ The output tensor after passing through the model.
573
+
574
+ """
575
+ if self.attention_mode == "Temporal":
576
+ d = hidden_states.shape[1] # d means HxW
577
+ hidden_states = rearrange(
578
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
579
+ )
580
+
581
+ if self.pos_encoder is not None:
582
+ hidden_states = self.pos_encoder(hidden_states)
583
+
584
+ encoder_hidden_states = (
585
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
586
+ if encoder_hidden_states is not None
587
+ else encoder_hidden_states
588
+ )
589
+
590
+ else:
591
+ raise NotImplementedError
592
+
593
+ hidden_states = self.processor(
594
+ self,
595
+ hidden_states,
596
+ encoder_hidden_states=encoder_hidden_states,
597
+ attention_mask=attention_mask,
598
+ **cross_attention_kwargs,
599
+ )
600
+
601
+ if self.attention_mode == "Temporal":
602
+ hidden_states = rearrange(
603
+ hidden_states, "(b d) f c -> (b f) d c", d=d)
604
+
605
+ return hidden_states
joyhallo/models/mutual_self_attention.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains the implementation of mutual self-attention,
3
+ which is a type of attention mechanism used in deep learning models.
4
+ The module includes several classes and functions related to attention mechanisms,
5
+ such as BasicTransformerBlock and TemporalBasicTransformerBlock.
6
+ The main purpose of this module is to provide a comprehensive attention mechanism for various tasks in deep learning,
7
+ such as image and video processing, natural language processing, and so on.
8
+ """
9
+
10
+ from typing import Any, Dict, Optional
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
16
+
17
+
18
+ def torch_dfs(model: torch.nn.Module):
19
+ """
20
+ Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture.
21
+
22
+ This function recursively traverses all the children modules of a given PyTorch model and returns a list
23
+ containing all the modules in the model's architecture. The DFS approach starts with the input model and
24
+ explores its children modules depth-wise before backtracking and exploring other branches.
25
+
26
+ Args:
27
+ model (torch.nn.Module): The root module of the neural network to traverse.
28
+
29
+ Returns:
30
+ list: A list of all the modules in the model's architecture.
31
+ """
32
+ result = [model]
33
+ for child in model.children():
34
+ result += torch_dfs(child)
35
+ return result
36
+
37
+
38
+ class ReferenceAttentionControl:
39
+ """
40
+ This class is used to control the reference attention mechanism in a neural network model.
41
+ It is responsible for managing the guidance and fusion blocks, and modifying the self-attention
42
+ and group normalization mechanisms. The class also provides methods for registering reference hooks
43
+ and updating/clearing the internal state of the attention control object.
44
+
45
+ Attributes:
46
+ unet: The UNet model associated with this attention control object.
47
+ mode: The operating mode of the attention control object, either 'write' or 'read'.
48
+ do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism.
49
+ attention_auto_machine_weight: The weight assigned to the attention auto-machine.
50
+ gn_auto_machine_weight: The weight assigned to the group normalization auto-machine.
51
+ style_fidelity: The style fidelity parameter for the attention mechanism.
52
+ reference_attn: Whether to use reference attention in the model.
53
+ reference_adain: Whether to use reference AdaIN in the model.
54
+ fusion_blocks: The type of fusion blocks to use in the model ('midup', 'late', or 'nofusion').
55
+ batch_size: The batch size used for processing video frames.
56
+
57
+ Methods:
58
+ register_reference_hooks: Registers the reference hooks for the attention control object.
59
+ hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block.
60
+ update: Updates the internal state of the attention control object using the provided writer and dtype.
61
+ clear: Clears the internal state of the attention control object.
62
+ """
63
+ def __init__(
64
+ self,
65
+ unet,
66
+ mode="write",
67
+ do_classifier_free_guidance=False,
68
+ attention_auto_machine_weight=float("inf"),
69
+ gn_auto_machine_weight=1.0,
70
+ style_fidelity=1.0,
71
+ reference_attn=True,
72
+ reference_adain=False,
73
+ fusion_blocks="midup",
74
+ batch_size=1,
75
+ ) -> None:
76
+ """
77
+ Initializes the ReferenceAttentionControl class.
78
+
79
+ Args:
80
+ unet (torch.nn.Module): The UNet model.
81
+ mode (str, optional): The mode of operation. Defaults to "write".
82
+ do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False.
83
+ attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity.
84
+ gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0.
85
+ style_fidelity (float, optional): The style fidelity. Defaults to 1.0.
86
+ reference_attn (bool, optional): Whether to use reference attention. Defaults to True.
87
+ reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False.
88
+ fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup".
89
+ batch_size (int, optional): The batch size. Defaults to 1.
90
+
91
+ Raises:
92
+ ValueError: If the mode is not recognized.
93
+ ValueError: If the fusion blocks are not recognized.
94
+ """
95
+ # 10. Modify self attention and group norm
96
+ self.unet = unet
97
+ assert mode in ["read", "write"]
98
+ assert fusion_blocks in ["midup", "full"]
99
+ self.reference_attn = reference_attn
100
+ self.reference_adain = reference_adain
101
+ self.fusion_blocks = fusion_blocks
102
+ self.register_reference_hooks(
103
+ mode,
104
+ do_classifier_free_guidance,
105
+ attention_auto_machine_weight,
106
+ gn_auto_machine_weight,
107
+ style_fidelity,
108
+ reference_attn,
109
+ reference_adain,
110
+ fusion_blocks,
111
+ batch_size=batch_size,
112
+ )
113
+
114
+ def register_reference_hooks(
115
+ self,
116
+ mode,
117
+ do_classifier_free_guidance,
118
+ _attention_auto_machine_weight,
119
+ _gn_auto_machine_weight,
120
+ _style_fidelity,
121
+ _reference_attn,
122
+ _reference_adain,
123
+ _dtype=torch.float16,
124
+ batch_size=1,
125
+ num_images_per_prompt=1,
126
+ device=torch.device("cpu"),
127
+ _fusion_blocks="midup",
128
+ ):
129
+ """
130
+ Registers reference hooks for the model.
131
+
132
+ This function is responsible for registering reference hooks in the model,
133
+ which are used to modify the attention mechanism and group normalization layers.
134
+ It takes various parameters as input, such as mode,
135
+ do_classifier_free_guidance, _attention_auto_machine_weight, _gn_auto_machine_weight, _style_fidelity,
136
+ _reference_attn, _reference_adain, _dtype, batch_size, num_images_per_prompt, device, and _fusion_blocks.
137
+
138
+ Args:
139
+ self: Reference to the instance of the class.
140
+ mode: The mode of operation for the reference hooks.
141
+ do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance.
142
+ _attention_auto_machine_weight: The weight for the attention auto-machine.
143
+ _gn_auto_machine_weight: The weight for the group normalization auto-machine.
144
+ _style_fidelity: The style fidelity for the reference hooks.
145
+ _reference_attn: A boolean flag indicating whether to use reference attention.
146
+ _reference_adain: A boolean flag indicating whether to use reference AdaIN.
147
+ _dtype: The data type for the reference hooks.
148
+ batch_size: The batch size for the reference hooks.
149
+ num_images_per_prompt: The number of images per prompt for the reference hooks.
150
+ device: The device for the reference hooks.
151
+ _fusion_blocks: The fusion blocks for the reference hooks.
152
+
153
+ Returns:
154
+ None
155
+ """
156
+ MODE = mode
157
+ if do_classifier_free_guidance:
158
+ uc_mask = (
159
+ torch.Tensor(
160
+ [1] * batch_size * num_images_per_prompt * 16
161
+ + [0] * batch_size * num_images_per_prompt * 16
162
+ )
163
+ .to(device)
164
+ .bool()
165
+ )
166
+ else:
167
+ uc_mask = (
168
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
169
+ .to(device)
170
+ .bool()
171
+ )
172
+
173
+ def hacked_basic_transformer_inner_forward(
174
+ self,
175
+ hidden_states: torch.FloatTensor,
176
+ attention_mask: Optional[torch.FloatTensor] = None,
177
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
178
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
179
+ timestep: Optional[torch.LongTensor] = None,
180
+ cross_attention_kwargs: Dict[str, Any] = None,
181
+ class_labels: Optional[torch.LongTensor] = None,
182
+ video_length=None,
183
+ ):
184
+ gate_msa = None
185
+ shift_mlp = None
186
+ scale_mlp = None
187
+ gate_mlp = None
188
+
189
+ if self.use_ada_layer_norm: # False
190
+ norm_hidden_states = self.norm1(hidden_states, timestep)
191
+ elif self.use_ada_layer_norm_zero:
192
+ (
193
+ norm_hidden_states,
194
+ gate_msa,
195
+ shift_mlp,
196
+ scale_mlp,
197
+ gate_mlp,
198
+ ) = self.norm1(
199
+ hidden_states,
200
+ timestep,
201
+ class_labels,
202
+ hidden_dtype=hidden_states.dtype,
203
+ )
204
+ else:
205
+ norm_hidden_states = self.norm1(hidden_states)
206
+
207
+ # 1. Self-Attention
208
+ # self.only_cross_attention = False
209
+ cross_attention_kwargs = (
210
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
211
+ )
212
+ if self.only_cross_attention:
213
+ attn_output = self.attn1(
214
+ norm_hidden_states,
215
+ encoder_hidden_states=(
216
+ encoder_hidden_states if self.only_cross_attention else None
217
+ ),
218
+ attention_mask=attention_mask,
219
+ **cross_attention_kwargs,
220
+ )
221
+ else:
222
+ if MODE == "write":
223
+ self.bank.append(norm_hidden_states.clone())
224
+ attn_output = self.attn1(
225
+ norm_hidden_states,
226
+ encoder_hidden_states=(
227
+ encoder_hidden_states if self.only_cross_attention else None
228
+ ),
229
+ attention_mask=attention_mask,
230
+ **cross_attention_kwargs,
231
+ )
232
+ if MODE == "read":
233
+
234
+ bank_fea = [
235
+ rearrange(
236
+ rearrange(
237
+ d,
238
+ "(b s) l c -> b s l c",
239
+ b=norm_hidden_states.shape[0] // video_length,
240
+ )[:, 0, :, :]
241
+ # .unsqueeze(1)
242
+ .repeat(1, video_length, 1, 1),
243
+ "b t l c -> (b t) l c",
244
+ )
245
+ for d in self.bank
246
+ ]
247
+ motion_frames_fea = [rearrange(
248
+ d,
249
+ "(b s) l c -> b s l c",
250
+ b=norm_hidden_states.shape[0] // video_length,
251
+ )[:, 1:, :, :] for d in self.bank]
252
+ modify_norm_hidden_states = torch.cat(
253
+ [norm_hidden_states] + bank_fea, dim=1
254
+ )
255
+ hidden_states_uc = (
256
+ self.attn1(
257
+ norm_hidden_states,
258
+ encoder_hidden_states=modify_norm_hidden_states,
259
+ attention_mask=attention_mask,
260
+ )
261
+ + hidden_states
262
+ )
263
+ if do_classifier_free_guidance:
264
+ hidden_states_c = hidden_states_uc.clone()
265
+ _uc_mask = uc_mask.clone()
266
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
267
+ _uc_mask = (
268
+ torch.Tensor(
269
+ [1] * (hidden_states.shape[0] // 2)
270
+ + [0] * (hidden_states.shape[0] // 2)
271
+ )
272
+ .to(device)
273
+ .bool()
274
+ )
275
+ hidden_states_c[_uc_mask] = (
276
+ self.attn1(
277
+ norm_hidden_states[_uc_mask],
278
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
279
+ attention_mask=attention_mask,
280
+ )
281
+ + hidden_states[_uc_mask]
282
+ )
283
+ hidden_states = hidden_states_c.clone()
284
+ else:
285
+ hidden_states = hidden_states_uc
286
+
287
+ # self.bank.clear()
288
+ if self.attn2 is not None:
289
+ # Cross-Attention
290
+ norm_hidden_states = (
291
+ self.norm2(hidden_states, timestep)
292
+ if self.use_ada_layer_norm
293
+ else self.norm2(hidden_states)
294
+ )
295
+ hidden_states = (
296
+ self.attn2(
297
+ norm_hidden_states,
298
+ encoder_hidden_states=encoder_hidden_states,
299
+ attention_mask=attention_mask,
300
+ )
301
+ + hidden_states
302
+ )
303
+
304
+ # Feed-forward
305
+ hidden_states = self.ff(self.norm3(
306
+ hidden_states)) + hidden_states
307
+
308
+ # Temporal-Attention
309
+ if self.unet_use_temporal_attention:
310
+ d = hidden_states.shape[1]
311
+ hidden_states = rearrange(
312
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
313
+ )
314
+ norm_hidden_states = (
315
+ self.norm_temp(hidden_states, timestep)
316
+ if self.use_ada_layer_norm
317
+ else self.norm_temp(hidden_states)
318
+ )
319
+ hidden_states = (
320
+ self.attn_temp(norm_hidden_states) + hidden_states
321
+ )
322
+ hidden_states = rearrange(
323
+ hidden_states, "(b d) f c -> (b f) d c", d=d
324
+ )
325
+
326
+ return hidden_states, motion_frames_fea
327
+
328
+ if self.use_ada_layer_norm_zero:
329
+ attn_output = gate_msa.unsqueeze(1) * attn_output
330
+ hidden_states = attn_output + hidden_states
331
+
332
+ if self.attn2 is not None:
333
+ norm_hidden_states = (
334
+ self.norm2(hidden_states, timestep)
335
+ if self.use_ada_layer_norm
336
+ else self.norm2(hidden_states)
337
+ )
338
+
339
+ # 2. Cross-Attention
340
+ tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0]
341
+ attn_output = self.attn2(
342
+ norm_hidden_states,
343
+ # TODO: repeat这个地方需要斟酌一下
344
+ encoder_hidden_states=encoder_hidden_states.repeat(
345
+ tmp, 1, 1),
346
+ attention_mask=encoder_attention_mask,
347
+ **cross_attention_kwargs,
348
+ )
349
+ hidden_states = attn_output + hidden_states
350
+
351
+ # 3. Feed-forward
352
+ norm_hidden_states = self.norm3(hidden_states)
353
+
354
+ if self.use_ada_layer_norm_zero:
355
+ norm_hidden_states = (
356
+ norm_hidden_states *
357
+ (1 + scale_mlp[:, None]) + shift_mlp[:, None]
358
+ )
359
+
360
+ ff_output = self.ff(norm_hidden_states)
361
+
362
+ if self.use_ada_layer_norm_zero:
363
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
364
+
365
+ hidden_states = ff_output + hidden_states
366
+
367
+ return hidden_states
368
+
369
+ if self.reference_attn:
370
+ if self.fusion_blocks == "midup":
371
+ attn_modules = [
372
+ module
373
+ for module in (
374
+ torch_dfs(self.unet.mid_block) +
375
+ torch_dfs(self.unet.up_blocks)
376
+ )
377
+ if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
378
+ ]
379
+ elif self.fusion_blocks == "full":
380
+ attn_modules = [
381
+ module
382
+ for module in torch_dfs(self.unet)
383
+ if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
384
+ ]
385
+ attn_modules = sorted(
386
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
387
+ )
388
+
389
+ for i, module in enumerate(attn_modules):
390
+ module._original_inner_forward = module.forward
391
+ if isinstance(module, BasicTransformerBlock):
392
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
393
+ module,
394
+ BasicTransformerBlock)
395
+ if isinstance(module, TemporalBasicTransformerBlock):
396
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
397
+ module,
398
+ TemporalBasicTransformerBlock)
399
+
400
+ module.bank = []
401
+ module.attn_weight = float(i) / float(len(attn_modules))
402
+
403
+ def update(self, writer, dtype=torch.float16):
404
+ """
405
+ Update the model's parameters.
406
+
407
+ Args:
408
+ writer (torch.nn.Module): The model's writer object.
409
+ dtype (torch.dtype, optional): The data type to be used for the update. Defaults to torch.float16.
410
+
411
+ Returns:
412
+ None.
413
+ """
414
+ if self.reference_attn:
415
+ if self.fusion_blocks == "midup":
416
+ reader_attn_modules = [
417
+ module
418
+ for module in (
419
+ torch_dfs(self.unet.mid_block) +
420
+ torch_dfs(self.unet.up_blocks)
421
+ )
422
+ if isinstance(module, TemporalBasicTransformerBlock)
423
+ ]
424
+ writer_attn_modules = [
425
+ module
426
+ for module in (
427
+ torch_dfs(writer.unet.mid_block)
428
+ + torch_dfs(writer.unet.up_blocks)
429
+ )
430
+ if isinstance(module, BasicTransformerBlock)
431
+ ]
432
+ elif self.fusion_blocks == "full":
433
+ reader_attn_modules = [
434
+ module
435
+ for module in torch_dfs(self.unet)
436
+ if isinstance(module, TemporalBasicTransformerBlock)
437
+ ]
438
+ writer_attn_modules = [
439
+ module
440
+ for module in torch_dfs(writer.unet)
441
+ if isinstance(module, BasicTransformerBlock)
442
+ ]
443
+
444
+ assert len(reader_attn_modules) == len(writer_attn_modules)
445
+ reader_attn_modules = sorted(
446
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
447
+ )
448
+ writer_attn_modules = sorted(
449
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
450
+ )
451
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
452
+ r.bank = [v.clone().to(dtype) for v in w.bank]
453
+
454
+
455
+ def clear(self):
456
+ """
457
+ Clears the attention bank of all reader attention modules.
458
+
459
+ This method is used when the `reference_attn` attribute is set to `True`.
460
+ It clears the attention bank of all reader attention modules inside the UNet
461
+ model based on the selected `fusion_blocks` mode.
462
+
463
+ If `fusion_blocks` is set to "midup", it searches for reader attention modules
464
+ in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set
465
+ to "full", it searches for reader attention modules in the entire UNet model.
466
+
467
+ It sorts the reader attention modules by the number of neurons in their
468
+ `norm1.normalized_shape[0]` attribute in descending order. This sorting ensures
469
+ that the modules with more neurons are cleared first.
470
+
471
+ Finally, it iterates through the sorted list of reader attention modules and
472
+ calls the `clear()` method on each module's `bank` attribute to clear the
473
+ attention bank.
474
+ """
475
+ if self.reference_attn:
476
+ if self.fusion_blocks == "midup":
477
+ reader_attn_modules = [
478
+ module
479
+ for module in (
480
+ torch_dfs(self.unet.mid_block) +
481
+ torch_dfs(self.unet.up_blocks)
482
+ )
483
+ if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
484
+ ]
485
+ elif self.fusion_blocks == "full":
486
+ reader_attn_modules = [
487
+ module
488
+ for module in torch_dfs(self.unet)
489
+ if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock))
490
+ ]
491
+ reader_attn_modules = sorted(
492
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
493
+ )
494
+ for r in reader_attn_modules:
495
+ r.bank.clear()
joyhallo/models/resnet.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm,
3
+ Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. These components are used to construct
4
+ a deep neural network model for image classification or other computer vision tasks.
5
+
6
+ Classes:
7
+ - InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d.
8
+ - InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm.
9
+ - Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor.
10
+ - Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor.
11
+ - ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures.
12
+ - Mish: A Mish activation function, which is a smooth, non-monotonic activation function.
13
+
14
+ To use this module, simply import the classes and functions you need and follow the instructions provided in
15
+ the respective class and function docstrings.
16
+ """
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from torch import nn
22
+
23
+
24
+ class InflatedConv3d(nn.Conv2d):
25
+ """
26
+ InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method.
27
+
28
+ This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer
29
+ commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and
30
+ InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of
31
+ inflating 2D convolutional layers to 3D for use in 3D deep learning tasks.
32
+
33
+ Attributes:
34
+ Same as torch.nn.Conv2d.
35
+
36
+ Methods:
37
+ forward(self, x):
38
+ Performs 3D convolution on the input tensor x using the InflatedConv3d layer.
39
+
40
+ Example:
41
+ conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
42
+ output = conv_layer(input_tensor)
43
+ """
44
+ def forward(self, x):
45
+ """
46
+ Forward pass of the InflatedConv3d layer.
47
+
48
+ Args:
49
+ x (torch.Tensor): Input tensor to the layer.
50
+
51
+ Returns:
52
+ torch.Tensor: Output tensor after applying the InflatedConv3d layer.
53
+ """
54
+ video_length = x.shape[2]
55
+
56
+ x = rearrange(x, "b c f h w -> (b f) c h w")
57
+ x = super().forward(x)
58
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
59
+
60
+ return x
61
+
62
+
63
+ class InflatedGroupNorm(nn.GroupNorm):
64
+ """
65
+ InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm.
66
+ It is used to apply group normalization to 3D tensors.
67
+
68
+ Args:
69
+ num_groups (int): The number of groups to divide the channels into.
70
+ num_channels (int): The number of channels in the input tensor.
71
+ eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5.
72
+ affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True.
73
+
74
+ Attributes:
75
+ weight (torch.Tensor): The learnable weight tensor for scale.
76
+ bias (torch.Tensor): The learnable bias tensor for shift.
77
+
78
+ Forward method:
79
+ x (torch.Tensor): Input tensor to be normalized.
80
+ return (torch.Tensor): Normalized tensor.
81
+ """
82
+ def forward(self, x):
83
+ """
84
+ Performs a forward pass through the CustomClassName.
85
+
86
+ :param x: Input tensor of shape (batch_size, channels, video_length, height, width).
87
+ :return: Output tensor of shape (batch_size, channels, video_length, height, width).
88
+ """
89
+ video_length = x.shape[2]
90
+
91
+ x = rearrange(x, "b c f h w -> (b f) c h w")
92
+ x = super().forward(x)
93
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
94
+
95
+ return x
96
+
97
+
98
+ class Upsample3D(nn.Module):
99
+ """
100
+ Upsample3D is a PyTorch module that upsamples a 3D tensor.
101
+
102
+ Args:
103
+ channels (int): The number of channels in the input tensor.
104
+ use_conv (bool): Whether to use a convolutional layer for upsampling.
105
+ use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling.
106
+ out_channels (int): The number of channels in the output tensor.
107
+ name (str): The name of the convolutional layer.
108
+ """
109
+ def __init__(
110
+ self,
111
+ channels,
112
+ use_conv=False,
113
+ use_conv_transpose=False,
114
+ out_channels=None,
115
+ name="conv",
116
+ ):
117
+ super().__init__()
118
+ self.channels = channels
119
+ self.out_channels = out_channels or channels
120
+ self.use_conv = use_conv
121
+ self.use_conv_transpose = use_conv_transpose
122
+ self.name = name
123
+
124
+ if use_conv_transpose:
125
+ raise NotImplementedError
126
+ if use_conv:
127
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
128
+
129
+ def forward(self, hidden_states, output_size=None):
130
+ """
131
+ Forward pass of the Upsample3D class.
132
+
133
+ Args:
134
+ hidden_states (torch.Tensor): Input tensor to be upsampled.
135
+ output_size (tuple, optional): Desired output size of the upsampled tensor.
136
+
137
+ Returns:
138
+ torch.Tensor: Upsampled tensor.
139
+
140
+ Raises:
141
+ AssertionError: If the number of channels in the input tensor does not match the expected channels.
142
+ """
143
+ assert hidden_states.shape[1] == self.channels
144
+
145
+ if self.use_conv_transpose:
146
+ raise NotImplementedError
147
+
148
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
149
+ dtype = hidden_states.dtype
150
+ if dtype == torch.bfloat16:
151
+ hidden_states = hidden_states.to(torch.float32)
152
+
153
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
154
+ if hidden_states.shape[0] >= 64:
155
+ hidden_states = hidden_states.contiguous()
156
+
157
+ # if `output_size` is passed we force the interpolation output
158
+ # size and do not make use of `scale_factor=2`
159
+ if output_size is None:
160
+ hidden_states = F.interpolate(
161
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
162
+ )
163
+ else:
164
+ hidden_states = F.interpolate(
165
+ hidden_states, size=output_size, mode="nearest"
166
+ )
167
+
168
+ # If the input is bfloat16, we cast back to bfloat16
169
+ if dtype == torch.bfloat16:
170
+ hidden_states = hidden_states.to(dtype)
171
+
172
+ # if self.use_conv:
173
+ # if self.name == "conv":
174
+ # hidden_states = self.conv(hidden_states)
175
+ # else:
176
+ # hidden_states = self.Conv2d_0(hidden_states)
177
+ hidden_states = self.conv(hidden_states)
178
+
179
+ return hidden_states
180
+
181
+
182
+ class Downsample3D(nn.Module):
183
+ """
184
+ The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to
185
+ reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network.
186
+
187
+ Attributes:
188
+ channels (int): Number of input channels.
189
+ use_conv (bool): Flag to use a convolutional layer for downsampling.
190
+ out_channels (int, optional): Number of output channels. Defaults to input channels if None.
191
+ padding (int): Padding added to the input.
192
+ name (str): Name of the convolutional layer used for downsampling.
193
+
194
+ Methods:
195
+ forward(self, hidden_states):
196
+ Downsamples the input tensor hidden_states and returns the downsampled tensor.
197
+ """
198
+ def __init__(
199
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
200
+ ):
201
+ """
202
+ Downsamples the given input in the 3D space.
203
+
204
+ Args:
205
+ channels: The number of input channels.
206
+ use_conv: Whether to use a convolutional layer for downsampling.
207
+ out_channels: The number of output channels. If None, the input channels are used.
208
+ padding: The amount of padding to be added to the input.
209
+ name: The name of the convolutional layer.
210
+ """
211
+ super().__init__()
212
+ self.channels = channels
213
+ self.out_channels = out_channels or channels
214
+ self.use_conv = use_conv
215
+ self.padding = padding
216
+ stride = 2
217
+ self.name = name
218
+
219
+ if use_conv:
220
+ self.conv = InflatedConv3d(
221
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
222
+ )
223
+ else:
224
+ raise NotImplementedError
225
+
226
+ def forward(self, hidden_states):
227
+ """
228
+ Forward pass for the Downsample3D class.
229
+
230
+ Args:
231
+ hidden_states (torch.Tensor): Input tensor to be downsampled.
232
+
233
+ Returns:
234
+ torch.Tensor: Downsampled tensor.
235
+
236
+ Raises:
237
+ AssertionError: If the number of channels in the input tensor does not match the expected channels.
238
+ """
239
+ assert hidden_states.shape[1] == self.channels
240
+ if self.use_conv and self.padding == 0:
241
+ raise NotImplementedError
242
+
243
+ assert hidden_states.shape[1] == self.channels
244
+ hidden_states = self.conv(hidden_states)
245
+
246
+ return hidden_states
247
+
248
+
249
+ class ResnetBlock3D(nn.Module):
250
+ """
251
+ The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet
252
+ architectures for both image and video modeling tasks.
253
+
254
+ Attributes:
255
+ in_channels (int): Number of input channels.
256
+ out_channels (int, optional): Number of output channels, defaults to in_channels if None.
257
+ conv_shortcut (bool): Flag to use a convolutional shortcut.
258
+ dropout (float): Dropout rate.
259
+ temb_channels (int): Number of channels in the time embedding tensor.
260
+ groups (int): Number of groups for the group normalization layers.
261
+ eps (float): Epsilon value for group normalization.
262
+ non_linearity (str): Type of nonlinearity to apply after convolutions.
263
+ time_embedding_norm (str): Type of normalization for the time embedding.
264
+ output_scale_factor (float): Scaling factor for the output tensor.
265
+ use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection.
266
+ use_inflated_groupnorm (bool): Flag to use inflated group normalization layers.
267
+
268
+ Methods:
269
+ forward(self, input_tensor, temb):
270
+ Passes the input tensor and time embedding through the residual block and
271
+ returns the output tensor.
272
+ """
273
+ def __init__(
274
+ self,
275
+ *,
276
+ in_channels,
277
+ out_channels=None,
278
+ conv_shortcut=False,
279
+ dropout=0.0,
280
+ temb_channels=512,
281
+ groups=32,
282
+ groups_out=None,
283
+ pre_norm=True,
284
+ eps=1e-6,
285
+ non_linearity="swish",
286
+ time_embedding_norm="default",
287
+ output_scale_factor=1.0,
288
+ use_in_shortcut=None,
289
+ use_inflated_groupnorm=None,
290
+ ):
291
+ super().__init__()
292
+ self.pre_norm = pre_norm
293
+ self.pre_norm = True
294
+ self.in_channels = in_channels
295
+ out_channels = in_channels if out_channels is None else out_channels
296
+ self.out_channels = out_channels
297
+ self.use_conv_shortcut = conv_shortcut
298
+ self.time_embedding_norm = time_embedding_norm
299
+ self.output_scale_factor = output_scale_factor
300
+
301
+ if groups_out is None:
302
+ groups_out = groups
303
+
304
+ assert use_inflated_groupnorm is not None
305
+ if use_inflated_groupnorm:
306
+ self.norm1 = InflatedGroupNorm(
307
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
308
+ )
309
+ else:
310
+ self.norm1 = torch.nn.GroupNorm(
311
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
312
+ )
313
+
314
+ self.conv1 = InflatedConv3d(
315
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
316
+ )
317
+
318
+ if temb_channels is not None:
319
+ if self.time_embedding_norm == "default":
320
+ time_emb_proj_out_channels = out_channels
321
+ elif self.time_embedding_norm == "scale_shift":
322
+ time_emb_proj_out_channels = out_channels * 2
323
+ else:
324
+ raise ValueError(
325
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
326
+ )
327
+
328
+ self.time_emb_proj = torch.nn.Linear(
329
+ temb_channels, time_emb_proj_out_channels
330
+ )
331
+ else:
332
+ self.time_emb_proj = None
333
+
334
+ if use_inflated_groupnorm:
335
+ self.norm2 = InflatedGroupNorm(
336
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
337
+ )
338
+ else:
339
+ self.norm2 = torch.nn.GroupNorm(
340
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
341
+ )
342
+ self.dropout = torch.nn.Dropout(dropout)
343
+ self.conv2 = InflatedConv3d(
344
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
345
+ )
346
+
347
+ if non_linearity == "swish":
348
+ self.nonlinearity = F.silu()
349
+ elif non_linearity == "mish":
350
+ self.nonlinearity = Mish()
351
+ elif non_linearity == "silu":
352
+ self.nonlinearity = nn.SiLU()
353
+
354
+ self.use_in_shortcut = (
355
+ self.in_channels != self.out_channels
356
+ if use_in_shortcut is None
357
+ else use_in_shortcut
358
+ )
359
+
360
+ self.conv_shortcut = None
361
+ if self.use_in_shortcut:
362
+ self.conv_shortcut = InflatedConv3d(
363
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
364
+ )
365
+
366
+ def forward(self, input_tensor, temb):
367
+ """
368
+ Forward pass for the ResnetBlock3D class.
369
+
370
+ Args:
371
+ input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer.
372
+ temb (torch.Tensor): Token embedding tensor.
373
+
374
+ Returns:
375
+ torch.Tensor: Output tensor after passing through the ResnetBlock3D layer.
376
+ """
377
+ hidden_states = input_tensor
378
+
379
+ hidden_states = self.norm1(hidden_states)
380
+ hidden_states = self.nonlinearity(hidden_states)
381
+
382
+ hidden_states = self.conv1(hidden_states)
383
+
384
+ if temb is not None:
385
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
386
+
387
+ if temb is not None and self.time_embedding_norm == "default":
388
+ hidden_states = hidden_states + temb
389
+
390
+ hidden_states = self.norm2(hidden_states)
391
+
392
+ if temb is not None and self.time_embedding_norm == "scale_shift":
393
+ scale, shift = torch.chunk(temb, 2, dim=1)
394
+ hidden_states = hidden_states * (1 + scale) + shift
395
+
396
+ hidden_states = self.nonlinearity(hidden_states)
397
+
398
+ hidden_states = self.dropout(hidden_states)
399
+ hidden_states = self.conv2(hidden_states)
400
+
401
+ if self.conv_shortcut is not None:
402
+ input_tensor = self.conv_shortcut(input_tensor)
403
+
404
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
405
+
406
+ return output_tensor
407
+
408
+
409
+ class Mish(torch.nn.Module):
410
+ """
411
+ The Mish class implements the Mish activation function, a smooth, non-monotonic function
412
+ that can be used in neural networks as an alternative to traditional activation functions like ReLU.
413
+
414
+ Methods:
415
+ forward(self, hidden_states):
416
+ Applies the Mish activation function to the input tensor hidden_states and
417
+ returns the resulting tensor.
418
+ """
419
+ def forward(self, hidden_states):
420
+ """
421
+ Mish activation function.
422
+
423
+ Args:
424
+ hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to.
425
+
426
+ Returns:
427
+ hidden_states (torch.Tensor): The output tensor after applying the Mish activation function.
428
+ """
429
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
joyhallo/models/transformer_2d.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines the Transformer2DModel, a PyTorch model that extends ModelMixin and ConfigMixin. It includes
3
+ methods for gradient checkpointing, forward propagation, and various utility functions. The model is designed for
4
+ 2D image-related tasks and uses LoRa (Low-Rank All-Attention) compatible layers for efficient attention computation.
5
+
6
+ The file includes the following import statements:
7
+
8
+ - From dataclasses import dataclass
9
+ - From typing import Any, Dict, Optional
10
+ - Import torch
11
+ - From diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ - From diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
13
+ - From diffusers.models.modeling_utils import ModelMixin
14
+ - From diffusers.models.normalization import AdaLayerNormSingle
15
+ - From diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
16
+ is_torch_version)
17
+ - From torch import nn
18
+ - From .attention import BasicTransformerBlock
19
+
20
+ The file also includes the following classes and functions:
21
+
22
+ - Transformer2DModel: A model class that extends ModelMixin and ConfigMixin. It includes methods for gradient
23
+ checkpointing, forward propagation, and various utility functions.
24
+ - _set_gradient_checkpointing: A utility function to set gradient checkpointing for a given module.
25
+ - forward: The forward propagation method for the Transformer2DModel.
26
+
27
+ To use this module, you can import the Transformer2DModel class and create an instance of the model with the desired
28
+ configuration. Then, you can use the forward method to pass input tensors through the model and get the output tensors.
29
+ """
30
+
31
+ from dataclasses import dataclass
32
+ from typing import Any, Dict, Optional
33
+
34
+ import torch
35
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
36
+ # from diffusers.models.embeddings import CaptionProjection
37
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
38
+ from diffusers.models.modeling_utils import ModelMixin
39
+ from diffusers.models.normalization import AdaLayerNormSingle
40
+ from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
41
+ is_torch_version)
42
+ from torch import nn
43
+
44
+ from .attention import BasicTransformerBlock
45
+
46
+
47
+ @dataclass
48
+ class Transformer2DModelOutput(BaseOutput):
49
+ """
50
+ The output of [`Transformer2DModel`].
51
+
52
+ Args:
53
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`
54
+ or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
55
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
56
+ distributions for the unnoised latent pixels.
57
+ """
58
+
59
+ sample: torch.FloatTensor
60
+ ref_feature: torch.FloatTensor
61
+
62
+
63
+ class Transformer2DModel(ModelMixin, ConfigMixin):
64
+ """
65
+ A 2D Transformer model for image-like data.
66
+
67
+ Parameters:
68
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
69
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
70
+ in_channels (`int`, *optional*):
71
+ The number of channels in the input and output (specify if the input is **continuous**).
72
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
73
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
74
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
75
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
76
+ This is fixed during training since it is used to learn a number of position embeddings.
77
+ num_vector_embeds (`int`, *optional*):
78
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
79
+ Includes the class for the masked latent pixel.
80
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
81
+ num_embeds_ada_norm ( `int`, *optional*):
82
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
83
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
84
+ added to the hidden states.
85
+
86
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
87
+ attention_bias (`bool`, *optional*):
88
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
89
+ """
90
+
91
+ _supports_gradient_checkpointing = True
92
+
93
+ @register_to_config
94
+ def __init__(
95
+ self,
96
+ num_attention_heads: int = 16,
97
+ attention_head_dim: int = 88,
98
+ in_channels: Optional[int] = None,
99
+ out_channels: Optional[int] = None,
100
+ num_layers: int = 1,
101
+ dropout: float = 0.0,
102
+ norm_num_groups: int = 32,
103
+ cross_attention_dim: Optional[int] = None,
104
+ attention_bias: bool = False,
105
+ num_vector_embeds: Optional[int] = None,
106
+ patch_size: Optional[int] = None,
107
+ activation_fn: str = "geglu",
108
+ num_embeds_ada_norm: Optional[int] = None,
109
+ use_linear_projection: bool = False,
110
+ only_cross_attention: bool = False,
111
+ double_self_attention: bool = False,
112
+ upcast_attention: bool = False,
113
+ norm_type: str = "layer_norm",
114
+ norm_elementwise_affine: bool = True,
115
+ norm_eps: float = 1e-5,
116
+ attention_type: str = "default",
117
+ ):
118
+ super().__init__()
119
+ self.use_linear_projection = use_linear_projection
120
+ self.num_attention_heads = num_attention_heads
121
+ self.attention_head_dim = attention_head_dim
122
+ inner_dim = num_attention_heads * attention_head_dim
123
+
124
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
125
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
126
+
127
+ # 1. Transformer2DModel can process both standard continuous images of
128
+ # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of
129
+ # shape `(batch_size, num_image_vectors)`
130
+ # Define whether input is continuous or discrete depending on configuration
131
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
132
+ self.is_input_vectorized = num_vector_embeds is not None
133
+ self.is_input_patches = in_channels is not None and patch_size is not None
134
+
135
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
+ deprecation_message = (
137
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
139
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
140
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
+ )
143
+ deprecate(
144
+ "norm_type!=num_embeds_ada_norm",
145
+ "1.0.0",
146
+ deprecation_message,
147
+ standard_warn=False,
148
+ )
149
+ norm_type = "ada_norm"
150
+
151
+ if self.is_input_continuous and self.is_input_vectorized:
152
+ raise ValueError(
153
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
154
+ " sure that either `in_channels` or `num_vector_embeds` is None."
155
+ )
156
+
157
+ if self.is_input_vectorized and self.is_input_patches:
158
+ raise ValueError(
159
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
160
+ " sure that either `num_vector_embeds` or `num_patches` is None."
161
+ )
162
+
163
+ if (
164
+ not self.is_input_continuous
165
+ and not self.is_input_vectorized
166
+ and not self.is_input_patches
167
+ ):
168
+ raise ValueError(
169
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
170
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
171
+ )
172
+
173
+ # 2. Define input layers
174
+ self.in_channels = in_channels
175
+
176
+ self.norm = torch.nn.GroupNorm(
177
+ num_groups=norm_num_groups,
178
+ num_channels=in_channels,
179
+ eps=1e-6,
180
+ affine=True,
181
+ )
182
+ if use_linear_projection:
183
+ self.proj_in = linear_cls(in_channels, inner_dim)
184
+ else:
185
+ self.proj_in = conv_cls(
186
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
187
+ )
188
+
189
+ # 3. Define transformers blocks
190
+ self.transformer_blocks = nn.ModuleList(
191
+ [
192
+ BasicTransformerBlock(
193
+ inner_dim,
194
+ num_attention_heads,
195
+ attention_head_dim,
196
+ dropout=dropout,
197
+ cross_attention_dim=cross_attention_dim,
198
+ activation_fn=activation_fn,
199
+ num_embeds_ada_norm=num_embeds_ada_norm,
200
+ attention_bias=attention_bias,
201
+ only_cross_attention=only_cross_attention,
202
+ double_self_attention=double_self_attention,
203
+ upcast_attention=upcast_attention,
204
+ norm_type=norm_type,
205
+ norm_elementwise_affine=norm_elementwise_affine,
206
+ norm_eps=norm_eps,
207
+ attention_type=attention_type,
208
+ )
209
+ for d in range(num_layers)
210
+ ]
211
+ )
212
+
213
+ # 4. Define output layers
214
+ self.out_channels = in_channels if out_channels is None else out_channels
215
+ # TODO: should use out_channels for continuous projections
216
+ if use_linear_projection:
217
+ self.proj_out = linear_cls(inner_dim, in_channels)
218
+ else:
219
+ self.proj_out = conv_cls(
220
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
221
+ )
222
+
223
+ # 5. PixArt-Alpha blocks.
224
+ self.adaln_single = None
225
+ self.use_additional_conditions = False
226
+ if norm_type == "ada_norm_single":
227
+ self.use_additional_conditions = self.config.sample_size == 128
228
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
229
+ # additional conditions until we find better name
230
+ self.adaln_single = AdaLayerNormSingle(
231
+ inner_dim, use_additional_conditions=self.use_additional_conditions
232
+ )
233
+
234
+ self.caption_projection = None
235
+
236
+ self.gradient_checkpointing = False
237
+
238
+ def _set_gradient_checkpointing(self, module, value=False):
239
+ if hasattr(module, "gradient_checkpointing"):
240
+ module.gradient_checkpointing = value
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ encoder_hidden_states: Optional[torch.Tensor] = None,
246
+ timestep: Optional[torch.LongTensor] = None,
247
+ _added_cond_kwargs: Dict[str, torch.Tensor] = None,
248
+ class_labels: Optional[torch.LongTensor] = None,
249
+ cross_attention_kwargs: Dict[str, Any] = None,
250
+ attention_mask: Optional[torch.Tensor] = None,
251
+ encoder_attention_mask: Optional[torch.Tensor] = None,
252
+ return_dict: bool = True,
253
+ ):
254
+ """
255
+ The [`Transformer2DModel`] forward method.
256
+
257
+ Args:
258
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
259
+ `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
260
+ Input `hidden_states`.
261
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
262
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
263
+ self-attention.
264
+ timestep ( `torch.LongTensor`, *optional*):
265
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
266
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
267
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
268
+ `AdaLayerZeroNorm`.
269
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
270
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
271
+ `self.processor` in
272
+ [diffusers.models.attention_processor]
273
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
274
+ attention_mask ( `torch.Tensor`, *optional*):
275
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
276
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
277
+ negative values to the attention scores corresponding to "discard" tokens.
278
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
279
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
280
+
281
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
282
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
283
+
284
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
285
+ above. This bias will be added to the cross-attention scores.
286
+ return_dict (`bool`, *optional*, defaults to `True`):
287
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
288
+ tuple.
289
+
290
+ Returns:
291
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
292
+ `tuple` where the first element is the sample tensor.
293
+ """
294
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
295
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
296
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
297
+ # expects mask of shape:
298
+ # [batch, key_tokens]
299
+ # adds singleton query_tokens dimension:
300
+ # [batch, 1, key_tokens]
301
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
302
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
303
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
304
+ if attention_mask is not None and attention_mask.ndim == 2:
305
+ # assume that mask is expressed as:
306
+ # (1 = keep, 0 = discard)
307
+ # convert mask into a bias that can be added to attention scores:
308
+ # (keep = +0, discard = -10000.0)
309
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
310
+ attention_mask = attention_mask.unsqueeze(1)
311
+
312
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
313
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
314
+ encoder_attention_mask = (
315
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
316
+ ) * -10000.0
317
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
318
+
319
+ # Retrieve lora scale.
320
+ lora_scale = (
321
+ cross_attention_kwargs.get("scale", 1.0)
322
+ if cross_attention_kwargs is not None
323
+ else 1.0
324
+ )
325
+
326
+ # 1. Input
327
+ batch, _, height, width = hidden_states.shape
328
+ residual = hidden_states
329
+
330
+ hidden_states = self.norm(hidden_states)
331
+ if not self.use_linear_projection:
332
+ hidden_states = (
333
+ self.proj_in(hidden_states, scale=lora_scale)
334
+ if not USE_PEFT_BACKEND
335
+ else self.proj_in(hidden_states)
336
+ )
337
+ inner_dim = hidden_states.shape[1]
338
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
339
+ batch, height * width, inner_dim
340
+ )
341
+ else:
342
+ inner_dim = hidden_states.shape[1]
343
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
344
+ batch, height * width, inner_dim
345
+ )
346
+ hidden_states = (
347
+ self.proj_in(hidden_states, scale=lora_scale)
348
+ if not USE_PEFT_BACKEND
349
+ else self.proj_in(hidden_states)
350
+ )
351
+
352
+ # 2. Blocks
353
+ if self.caption_projection is not None:
354
+ batch_size = hidden_states.shape[0]
355
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
356
+ encoder_hidden_states = encoder_hidden_states.view(
357
+ batch_size, -1, hidden_states.shape[-1]
358
+ )
359
+
360
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
361
+ for block in self.transformer_blocks:
362
+ if self.training and self.gradient_checkpointing:
363
+
364
+ def create_custom_forward(module, return_dict=None):
365
+ def custom_forward(*inputs):
366
+ if return_dict is not None:
367
+ return module(*inputs, return_dict=return_dict)
368
+
369
+ return module(*inputs)
370
+
371
+ return custom_forward
372
+
373
+ ckpt_kwargs: Dict[str, Any] = (
374
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
375
+ )
376
+ hidden_states = torch.utils.checkpoint.checkpoint(
377
+ create_custom_forward(block),
378
+ hidden_states,
379
+ attention_mask,
380
+ encoder_hidden_states,
381
+ encoder_attention_mask,
382
+ timestep,
383
+ cross_attention_kwargs,
384
+ class_labels,
385
+ **ckpt_kwargs,
386
+ )
387
+ else:
388
+ hidden_states = block(
389
+ hidden_states, # shape [5, 4096, 320]
390
+ attention_mask=attention_mask,
391
+ encoder_hidden_states=encoder_hidden_states, # shape [1,4,768]
392
+ encoder_attention_mask=encoder_attention_mask,
393
+ timestep=timestep,
394
+ cross_attention_kwargs=cross_attention_kwargs,
395
+ class_labels=class_labels,
396
+ )
397
+
398
+ # 3. Output
399
+ output = None
400
+ if self.is_input_continuous:
401
+ if not self.use_linear_projection:
402
+ hidden_states = (
403
+ hidden_states.reshape(batch, height, width, inner_dim)
404
+ .permute(0, 3, 1, 2)
405
+ .contiguous()
406
+ )
407
+ hidden_states = (
408
+ self.proj_out(hidden_states, scale=lora_scale)
409
+ if not USE_PEFT_BACKEND
410
+ else self.proj_out(hidden_states)
411
+ )
412
+ else:
413
+ hidden_states = (
414
+ self.proj_out(hidden_states, scale=lora_scale)
415
+ if not USE_PEFT_BACKEND
416
+ else self.proj_out(hidden_states)
417
+ )
418
+ hidden_states = (
419
+ hidden_states.reshape(batch, height, width, inner_dim)
420
+ .permute(0, 3, 1, 2)
421
+ .contiguous()
422
+ )
423
+
424
+ output = hidden_states + residual
425
+ if not return_dict:
426
+ return (output, ref_feature)
427
+
428
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
joyhallo/models/transformer_3d.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the Transformer3DModel, a PyTorch model designed for processing
3
+ 3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer
4
+ model with support for gradient checkpointing and various types of attention mechanisms.
5
+ The model can be configured with different parameters such as the number of attention heads,
6
+ attention head dimension, and the number of layers. It also supports the use of audio modules
7
+ for enhanced feature extraction from video data.
8
+ """
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Optional
12
+
13
+ import torch
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+ from einops import rearrange, repeat
18
+ from torch import nn
19
+
20
+ from .attention import (AudioTemporalBasicTransformerBlock,
21
+ TemporalBasicTransformerBlock)
22
+
23
+
24
+ @dataclass
25
+ class Transformer3DModelOutput(BaseOutput):
26
+ """
27
+ The output of the [`Transformer3DModel`].
28
+
29
+ Attributes:
30
+ sample (`torch.FloatTensor`):
31
+ The output tensor from the transformer model, which is the result of processing the input
32
+ hidden states through the transformer blocks and any subsequent layers.
33
+ """
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class Transformer3DModel(ModelMixin, ConfigMixin):
38
+ """
39
+ Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model.
40
+ It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks.
41
+ The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method.
42
+ """
43
+ _supports_gradient_checkpointing = True
44
+
45
+ @register_to_config
46
+ def __init__(
47
+ self,
48
+ num_attention_heads: int = 16,
49
+ attention_head_dim: int = 88,
50
+ in_channels: Optional[int] = None,
51
+ num_layers: int = 1,
52
+ dropout: float = 0.0,
53
+ norm_num_groups: int = 32,
54
+ cross_attention_dim: Optional[int] = None,
55
+ attention_bias: bool = False,
56
+ activation_fn: str = "geglu",
57
+ num_embeds_ada_norm: Optional[int] = None,
58
+ use_linear_projection: bool = False,
59
+ only_cross_attention: bool = False,
60
+ upcast_attention: bool = False,
61
+ unet_use_cross_frame_attention=None,
62
+ unet_use_temporal_attention=None,
63
+ use_audio_module=False,
64
+ depth=0,
65
+ unet_block_name=None,
66
+ stack_enable_blocks_name = None,
67
+ stack_enable_blocks_depth = None,
68
+ ):
69
+ super().__init__()
70
+ self.use_linear_projection = use_linear_projection
71
+ self.num_attention_heads = num_attention_heads
72
+ self.attention_head_dim = attention_head_dim
73
+ inner_dim = num_attention_heads * attention_head_dim
74
+ self.use_audio_module = use_audio_module
75
+ # Define input layers
76
+ self.in_channels = in_channels
77
+
78
+ self.norm = torch.nn.GroupNorm(
79
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
80
+ )
81
+ if use_linear_projection:
82
+ self.proj_in = nn.Linear(in_channels, inner_dim)
83
+ else:
84
+ self.proj_in = nn.Conv2d(
85
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
86
+ )
87
+
88
+ if use_audio_module:
89
+ self.transformer_blocks = nn.ModuleList(
90
+ [
91
+ AudioTemporalBasicTransformerBlock(
92
+ inner_dim,
93
+ num_attention_heads,
94
+ attention_head_dim,
95
+ dropout=dropout,
96
+ cross_attention_dim=cross_attention_dim,
97
+ activation_fn=activation_fn,
98
+ num_embeds_ada_norm=num_embeds_ada_norm,
99
+ attention_bias=attention_bias,
100
+ only_cross_attention=only_cross_attention,
101
+ upcast_attention=upcast_attention,
102
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
103
+ unet_use_temporal_attention=unet_use_temporal_attention,
104
+ depth=depth,
105
+ unet_block_name=unet_block_name,
106
+ stack_enable_blocks_name=stack_enable_blocks_name,
107
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
108
+ )
109
+ for d in range(num_layers)
110
+ ]
111
+ )
112
+ else:
113
+ # Define transformers blocks
114
+ self.transformer_blocks = nn.ModuleList(
115
+ [
116
+ TemporalBasicTransformerBlock(
117
+ inner_dim,
118
+ num_attention_heads,
119
+ attention_head_dim,
120
+ dropout=dropout,
121
+ cross_attention_dim=cross_attention_dim,
122
+ activation_fn=activation_fn,
123
+ num_embeds_ada_norm=num_embeds_ada_norm,
124
+ attention_bias=attention_bias,
125
+ only_cross_attention=only_cross_attention,
126
+ upcast_attention=upcast_attention,
127
+ )
128
+ for d in range(num_layers)
129
+ ]
130
+ )
131
+
132
+ # 4. Define output layers
133
+ if use_linear_projection:
134
+ self.proj_out = nn.Linear(in_channels, inner_dim)
135
+ else:
136
+ self.proj_out = nn.Conv2d(
137
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
138
+ )
139
+
140
+ self.gradient_checkpointing = False
141
+
142
+ def _set_gradient_checkpointing(self, module, value=False):
143
+ if hasattr(module, "gradient_checkpointing"):
144
+ module.gradient_checkpointing = value
145
+
146
+ def forward(
147
+ self,
148
+ hidden_states,
149
+ encoder_hidden_states=None,
150
+ attention_mask=None,
151
+ full_mask=None,
152
+ face_mask=None,
153
+ lip_mask=None,
154
+ motion_scale=None,
155
+ timestep=None,
156
+ return_dict: bool = True,
157
+ ):
158
+ """
159
+ Forward pass for the Transformer3DModel.
160
+
161
+ Args:
162
+ hidden_states (torch.Tensor): The input hidden states.
163
+ encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states.
164
+ attention_mask (torch.Tensor, optional): The attention mask.
165
+ full_mask (torch.Tensor, optional): The full mask.
166
+ face_mask (torch.Tensor, optional): The face mask.
167
+ lip_mask (torch.Tensor, optional): The lip mask.
168
+ timestep (int, optional): The current timestep.
169
+ return_dict (bool, optional): Whether to return a dictionary or a tuple.
170
+
171
+ Returns:
172
+ output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel.
173
+ """
174
+ # Input
175
+ assert (
176
+ hidden_states.dim() == 5
177
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
178
+ video_length = hidden_states.shape[2]
179
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
180
+
181
+ # TODO
182
+ if self.use_audio_module:
183
+ encoder_hidden_states = rearrange(
184
+ encoder_hidden_states,
185
+ "bs f margin dim -> (bs f) margin dim",
186
+ )
187
+ else:
188
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
189
+ encoder_hidden_states = repeat(
190
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
191
+ )
192
+
193
+ batch, _, height, weight = hidden_states.shape
194
+ residual = hidden_states
195
+
196
+ hidden_states = self.norm(hidden_states)
197
+ if not self.use_linear_projection:
198
+ hidden_states = self.proj_in(hidden_states)
199
+ inner_dim = hidden_states.shape[1]
200
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
201
+ batch, height * weight, inner_dim
202
+ )
203
+ else:
204
+ inner_dim = hidden_states.shape[1]
205
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
206
+ batch, height * weight, inner_dim
207
+ )
208
+ hidden_states = self.proj_in(hidden_states)
209
+
210
+ # Blocks
211
+ motion_frames = []
212
+ for _, block in enumerate(self.transformer_blocks):
213
+ if isinstance(block, TemporalBasicTransformerBlock):
214
+ hidden_states, motion_frame_fea = block(
215
+ hidden_states,
216
+ encoder_hidden_states=encoder_hidden_states,
217
+ timestep=timestep,
218
+ video_length=video_length,
219
+ )
220
+ motion_frames.append(motion_frame_fea)
221
+ else:
222
+ hidden_states = block(
223
+ hidden_states, # shape [2, 4096, 320]
224
+ encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
225
+ attention_mask=attention_mask,
226
+ full_mask=full_mask,
227
+ face_mask=face_mask,
228
+ lip_mask=lip_mask,
229
+ timestep=timestep,
230
+ video_length=video_length,
231
+ motion_scale=motion_scale,
232
+ )
233
+
234
+ # Output
235
+ if not self.use_linear_projection:
236
+ hidden_states = (
237
+ hidden_states.reshape(batch, height, weight, inner_dim)
238
+ .permute(0, 3, 1, 2)
239
+ .contiguous()
240
+ )
241
+ hidden_states = self.proj_out(hidden_states)
242
+ else:
243
+ hidden_states = self.proj_out(hidden_states)
244
+ hidden_states = (
245
+ hidden_states.reshape(batch, height, weight, inner_dim)
246
+ .permute(0, 3, 1, 2)
247
+ .contiguous()
248
+ )
249
+
250
+ output = hidden_states + residual
251
+
252
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
253
+ if not return_dict:
254
+ return (output, motion_frames)
255
+
256
+ return Transformer3DModelOutput(sample=output)
joyhallo/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file defines the 2D blocks for the UNet model in a PyTorch implementation.
3
+ The UNet model is a popular architecture for image segmentation tasks,
4
+ which consists of an encoder, a decoder, and a skip connection mechanism.
5
+ The 2D blocks in this file include various types of layers, such as ResNet blocks,
6
+ Transformer blocks, and cross-attention blocks,
7
+ which are used to build the encoder and decoder parts of the UNet model.
8
+ The AutoencoderTinyBlock class is a simple autoencoder block for tiny models,
9
+ and the UNetMidBlock2D and CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D,
10
+ and UpBlock2D classes are used for the middle and decoder parts of the UNet model.
11
+ The classes and functions in this file provide a flexible and modular way
12
+ to construct the UNet model for different image segmentation tasks.
13
+ """
14
+
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from diffusers.models.activations import get_activation
19
+ from diffusers.models.attention_processor import Attention
20
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
21
+ from diffusers.models.transformers.dual_transformer_2d import \
22
+ DualTransformer2DModel
23
+ from diffusers.utils import is_torch_version, logging
24
+ from diffusers.utils.torch_utils import apply_freeu
25
+ from torch import nn
26
+
27
+ from .transformer_2d import Transformer2DModel
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ def get_down_block(
33
+ down_block_type: str,
34
+ num_layers: int,
35
+ in_channels: int,
36
+ out_channels: int,
37
+ temb_channels: int,
38
+ add_downsample: bool,
39
+ resnet_eps: float,
40
+ resnet_act_fn: str,
41
+ transformer_layers_per_block: int = 1,
42
+ num_attention_heads: Optional[int] = None,
43
+ resnet_groups: Optional[int] = None,
44
+ cross_attention_dim: Optional[int] = None,
45
+ downsample_padding: Optional[int] = None,
46
+ dual_cross_attention: bool = False,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ resnet_time_scale_shift: str = "default",
51
+ attention_type: str = "default",
52
+ attention_head_dim: Optional[int] = None,
53
+ dropout: float = 0.0,
54
+ ):
55
+ """ This function creates and returns a UpBlock2D or CrossAttnUpBlock2D object based on the given up_block_type.
56
+
57
+ Args:
58
+ up_block_type (str): The type of up block to create. Must be either "UpBlock2D" or "CrossAttnUpBlock2D".
59
+ num_layers (int): The number of layers in the ResNet block.
60
+ in_channels (int): The number of input channels.
61
+ out_channels (int): The number of output channels.
62
+ prev_output_channel (int): The number of channels in the previous output.
63
+ temb_channels (int): The number of channels in the token embedding.
64
+ add_upsample (bool): Whether to add an upsample layer after the ResNet block. Defaults to True.
65
+ resnet_eps (float): The epsilon value for the ResNet block. Defaults to 1e-6.
66
+ resnet_act_fn (str): The activation function to use in the ResNet block. Defaults to "swish".
67
+ resnet_groups (int): The number of groups in the ResNet block. Defaults to 32.
68
+ resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet block. Defaults to True.
69
+ output_scale_factor (float): The scale factor to apply to the output. Defaults to 1.0.
70
+
71
+ Returns:
72
+ nn.Module: The created UpBlock2D or CrossAttnUpBlock2D object.
73
+ """
74
+ # If attn head dim is not defined, we default it to the number of heads
75
+ if attention_head_dim is None:
76
+ logger.warning("It is recommended to provide `attention_head_dim` when calling `get_down_block`.")
77
+ logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
78
+ attention_head_dim = num_attention_heads
79
+
80
+ down_block_type = (
81
+ down_block_type[7:]
82
+ if down_block_type.startswith("UNetRes")
83
+ else down_block_type
84
+ )
85
+ if down_block_type == "DownBlock2D":
86
+ return DownBlock2D(
87
+ num_layers=num_layers,
88
+ in_channels=in_channels,
89
+ out_channels=out_channels,
90
+ temb_channels=temb_channels,
91
+ dropout=dropout,
92
+ add_downsample=add_downsample,
93
+ resnet_eps=resnet_eps,
94
+ resnet_act_fn=resnet_act_fn,
95
+ resnet_groups=resnet_groups,
96
+ downsample_padding=downsample_padding,
97
+ resnet_time_scale_shift=resnet_time_scale_shift,
98
+ )
99
+
100
+ if down_block_type == "CrossAttnDownBlock2D":
101
+ if cross_attention_dim is None:
102
+ raise ValueError(
103
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
104
+ )
105
+ return CrossAttnDownBlock2D(
106
+ num_layers=num_layers,
107
+ transformer_layers_per_block=transformer_layers_per_block,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ temb_channels=temb_channels,
111
+ dropout=dropout,
112
+ add_downsample=add_downsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ downsample_padding=downsample_padding,
117
+ cross_attention_dim=cross_attention_dim,
118
+ num_attention_heads=num_attention_heads,
119
+ dual_cross_attention=dual_cross_attention,
120
+ use_linear_projection=use_linear_projection,
121
+ only_cross_attention=only_cross_attention,
122
+ upcast_attention=upcast_attention,
123
+ resnet_time_scale_shift=resnet_time_scale_shift,
124
+ attention_type=attention_type,
125
+ )
126
+ raise ValueError(f"{down_block_type} does not exist.")
127
+
128
+
129
+ def get_up_block(
130
+ up_block_type: str,
131
+ num_layers: int,
132
+ in_channels: int,
133
+ out_channels: int,
134
+ prev_output_channel: int,
135
+ temb_channels: int,
136
+ add_upsample: bool,
137
+ resnet_eps: float,
138
+ resnet_act_fn: str,
139
+ resolution_idx: Optional[int] = None,
140
+ transformer_layers_per_block: int = 1,
141
+ num_attention_heads: Optional[int] = None,
142
+ resnet_groups: Optional[int] = None,
143
+ cross_attention_dim: Optional[int] = None,
144
+ dual_cross_attention: bool = False,
145
+ use_linear_projection: bool = False,
146
+ only_cross_attention: bool = False,
147
+ upcast_attention: bool = False,
148
+ resnet_time_scale_shift: str = "default",
149
+ attention_type: str = "default",
150
+ attention_head_dim: Optional[int] = None,
151
+ dropout: float = 0.0,
152
+ ) -> nn.Module:
153
+ """ This function ...
154
+ Args:
155
+ Returns:
156
+ """
157
+ # If attn head dim is not defined, we default it to the number of heads
158
+ if attention_head_dim is None:
159
+ logger.warning("It is recommended to provide `attention_head_dim` when calling `get_up_block`.")
160
+ logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.")
161
+ attention_head_dim = num_attention_heads
162
+
163
+ up_block_type = (
164
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
165
+ )
166
+ if up_block_type == "UpBlock2D":
167
+ return UpBlock2D(
168
+ num_layers=num_layers,
169
+ in_channels=in_channels,
170
+ out_channels=out_channels,
171
+ prev_output_channel=prev_output_channel,
172
+ temb_channels=temb_channels,
173
+ resolution_idx=resolution_idx,
174
+ dropout=dropout,
175
+ add_upsample=add_upsample,
176
+ resnet_eps=resnet_eps,
177
+ resnet_act_fn=resnet_act_fn,
178
+ resnet_groups=resnet_groups,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ )
181
+ if up_block_type == "CrossAttnUpBlock2D":
182
+ if cross_attention_dim is None:
183
+ raise ValueError(
184
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
185
+ )
186
+ return CrossAttnUpBlock2D(
187
+ num_layers=num_layers,
188
+ transformer_layers_per_block=transformer_layers_per_block,
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ prev_output_channel=prev_output_channel,
192
+ temb_channels=temb_channels,
193
+ resolution_idx=resolution_idx,
194
+ dropout=dropout,
195
+ add_upsample=add_upsample,
196
+ resnet_eps=resnet_eps,
197
+ resnet_act_fn=resnet_act_fn,
198
+ resnet_groups=resnet_groups,
199
+ cross_attention_dim=cross_attention_dim,
200
+ num_attention_heads=num_attention_heads,
201
+ dual_cross_attention=dual_cross_attention,
202
+ use_linear_projection=use_linear_projection,
203
+ only_cross_attention=only_cross_attention,
204
+ upcast_attention=upcast_attention,
205
+ resnet_time_scale_shift=resnet_time_scale_shift,
206
+ attention_type=attention_type,
207
+ )
208
+
209
+ raise ValueError(f"{up_block_type} does not exist.")
210
+
211
+
212
+ class AutoencoderTinyBlock(nn.Module):
213
+ """
214
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
215
+ blocks.
216
+
217
+ Args:
218
+ in_channels (`int`): The number of input channels.
219
+ out_channels (`int`): The number of output channels.
220
+ act_fn (`str`):
221
+ ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
222
+
223
+ Returns:
224
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
225
+ `out_channels`.
226
+ """
227
+
228
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
229
+ super().__init__()
230
+ act_fn = get_activation(act_fn)
231
+ self.conv = nn.Sequential(
232
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
233
+ act_fn,
234
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
235
+ act_fn,
236
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
237
+ )
238
+ self.skip = (
239
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
240
+ if in_channels != out_channels
241
+ else nn.Identity()
242
+ )
243
+ self.fuse = nn.ReLU()
244
+
245
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
246
+ """
247
+ Forward pass of the AutoencoderTinyBlock class.
248
+
249
+ Parameters:
250
+ x (torch.FloatTensor): The input tensor to the AutoencoderTinyBlock.
251
+
252
+ Returns:
253
+ torch.FloatTensor: The output tensor after passing through the AutoencoderTinyBlock.
254
+ """
255
+ return self.fuse(self.conv(x) + self.skip(x))
256
+
257
+
258
+ class UNetMidBlock2D(nn.Module):
259
+ """
260
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
261
+
262
+ Args:
263
+ in_channels (`int`): The number of input channels.
264
+ temb_channels (`int`): The number of temporal embedding channels.
265
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
266
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
267
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
268
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
269
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
270
+ model on tasks with long-range temporal dependencies.
271
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
272
+ resnet_groups (`int`, *optional*, defaults to 32):
273
+ The number of groups to use in the group normalization layers of the resnet blocks.
274
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
275
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
276
+ Whether to use pre-normalization for the resnet blocks.
277
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
278
+ attention_head_dim (`int`, *optional*, defaults to 1):
279
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
280
+ the number of input channels.
281
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
282
+
283
+ Returns:
284
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
285
+ in_channels, height, width)`.
286
+
287
+ """
288
+
289
+ def __init__(
290
+ self,
291
+ in_channels: int,
292
+ temb_channels: int,
293
+ dropout: float = 0.0,
294
+ num_layers: int = 1,
295
+ resnet_eps: float = 1e-6,
296
+ resnet_time_scale_shift: str = "default", # default, spatial
297
+ resnet_act_fn: str = "swish",
298
+ resnet_groups: int = 32,
299
+ attn_groups: Optional[int] = None,
300
+ resnet_pre_norm: bool = True,
301
+ add_attention: bool = True,
302
+ attention_head_dim: int = 1,
303
+ output_scale_factor: float = 1.0,
304
+ ):
305
+ super().__init__()
306
+ resnet_groups = (
307
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
308
+ )
309
+ self.add_attention = add_attention
310
+
311
+ if attn_groups is None:
312
+ attn_groups = (
313
+ resnet_groups if resnet_time_scale_shift == "default" else None
314
+ )
315
+
316
+ # there is always at least one resnet
317
+ resnets = [
318
+ ResnetBlock2D(
319
+ in_channels=in_channels,
320
+ out_channels=in_channels,
321
+ temb_channels=temb_channels,
322
+ eps=resnet_eps,
323
+ groups=resnet_groups,
324
+ dropout=dropout,
325
+ time_embedding_norm=resnet_time_scale_shift,
326
+ non_linearity=resnet_act_fn,
327
+ output_scale_factor=output_scale_factor,
328
+ pre_norm=resnet_pre_norm,
329
+ )
330
+ ]
331
+ attentions = []
332
+
333
+ if attention_head_dim is None:
334
+ logger.warning(
335
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
336
+ )
337
+ attention_head_dim = in_channels
338
+
339
+ for _ in range(num_layers):
340
+ if self.add_attention:
341
+ attentions.append(
342
+ Attention(
343
+ in_channels,
344
+ heads=in_channels // attention_head_dim,
345
+ dim_head=attention_head_dim,
346
+ rescale_output_factor=output_scale_factor,
347
+ eps=resnet_eps,
348
+ norm_num_groups=attn_groups,
349
+ spatial_norm_dim=(
350
+ temb_channels
351
+ if resnet_time_scale_shift == "spatial"
352
+ else None
353
+ ),
354
+ residual_connection=True,
355
+ bias=True,
356
+ upcast_softmax=True,
357
+ _from_deprecated_attn_block=True,
358
+ )
359
+ )
360
+ else:
361
+ attentions.append(None)
362
+
363
+ resnets.append(
364
+ ResnetBlock2D(
365
+ in_channels=in_channels,
366
+ out_channels=in_channels,
367
+ temb_channels=temb_channels,
368
+ eps=resnet_eps,
369
+ groups=resnet_groups,
370
+ dropout=dropout,
371
+ time_embedding_norm=resnet_time_scale_shift,
372
+ non_linearity=resnet_act_fn,
373
+ output_scale_factor=output_scale_factor,
374
+ pre_norm=resnet_pre_norm,
375
+ )
376
+ )
377
+
378
+ self.attentions = nn.ModuleList(attentions)
379
+ self.resnets = nn.ModuleList(resnets)
380
+
381
+ def forward(
382
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
383
+ ) -> torch.FloatTensor:
384
+ """
385
+ Forward pass of the UNetMidBlock2D class.
386
+
387
+ Args:
388
+ hidden_states (torch.FloatTensor): The input tensor to the UNetMidBlock2D.
389
+ temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None.
390
+
391
+ Returns:
392
+ torch.FloatTensor: The output tensor after passing through the UNetMidBlock2D.
393
+ """
394
+ # Your implementation here
395
+ hidden_states = self.resnets[0](hidden_states, temb)
396
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
397
+ if attn is not None:
398
+ hidden_states = attn(hidden_states, temb=temb)
399
+ hidden_states = resnet(hidden_states, temb)
400
+
401
+ return hidden_states
402
+
403
+
404
+ class UNetMidBlock2DCrossAttn(nn.Module):
405
+ """
406
+ UNetMidBlock2DCrossAttn is a class that represents a mid-block 2D UNet with cross-attention.
407
+
408
+ This block is responsible for processing the input tensor with a series of residual blocks,
409
+ and applying cross-attention mechanism to attend to the global information in the encoder.
410
+
411
+ Args:
412
+ in_channels (int): The number of input channels.
413
+ temb_channels (int): The number of channels for the token embedding.
414
+ dropout (float, optional): The dropout rate. Defaults to 0.0.
415
+ num_layers (int, optional): The number of layers in the residual blocks. Defaults to 1.
416
+ resnet_eps (float, optional): The epsilon value for the residual blocks. Defaults to 1e-6.
417
+ resnet_time_scale_shift (str, optional): The time scale shift type for the residual blocks. Defaults to "default".
418
+ resnet_act_fn (str, optional): The activation function for the residual blocks. Defaults to "swish".
419
+ resnet_groups (int, optional): The number of groups for the residual blocks. Defaults to 32.
420
+ resnet_pre_norm (bool, optional): Whether to apply pre-normalization for the residual blocks. Defaults to True.
421
+ num_attention_heads (int, optional): The number of attention heads for cross-attention. Defaults to 1.
422
+ cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 1280.
423
+ output_scale_factor (float, optional): The scale factor for the output tensor. Defaults to 1.0.
424
+ """
425
+ def __init__(
426
+ self,
427
+ in_channels: int,
428
+ temb_channels: int,
429
+ dropout: float = 0.0,
430
+ num_layers: int = 1,
431
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
432
+ resnet_eps: float = 1e-6,
433
+ resnet_time_scale_shift: str = "default",
434
+ resnet_act_fn: str = "swish",
435
+ resnet_groups: int = 32,
436
+ resnet_pre_norm: bool = True,
437
+ num_attention_heads: int = 1,
438
+ output_scale_factor: float = 1.0,
439
+ cross_attention_dim: int = 1280,
440
+ dual_cross_attention: bool = False,
441
+ use_linear_projection: bool = False,
442
+ upcast_attention: bool = False,
443
+ attention_type: str = "default",
444
+ ):
445
+ super().__init__()
446
+
447
+ self.has_cross_attention = True
448
+ self.num_attention_heads = num_attention_heads
449
+ resnet_groups = (
450
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
451
+ )
452
+
453
+ # support for variable transformer layers per block
454
+ if isinstance(transformer_layers_per_block, int):
455
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
456
+
457
+ # there is always at least one resnet
458
+ resnets = [
459
+ ResnetBlock2D(
460
+ in_channels=in_channels,
461
+ out_channels=in_channels,
462
+ temb_channels=temb_channels,
463
+ eps=resnet_eps,
464
+ groups=resnet_groups,
465
+ dropout=dropout,
466
+ time_embedding_norm=resnet_time_scale_shift,
467
+ non_linearity=resnet_act_fn,
468
+ output_scale_factor=output_scale_factor,
469
+ pre_norm=resnet_pre_norm,
470
+ )
471
+ ]
472
+ attentions = []
473
+
474
+ for i in range(num_layers):
475
+ if not dual_cross_attention:
476
+ attentions.append(
477
+ Transformer2DModel(
478
+ num_attention_heads,
479
+ in_channels // num_attention_heads,
480
+ in_channels=in_channels,
481
+ num_layers=transformer_layers_per_block[i],
482
+ cross_attention_dim=cross_attention_dim,
483
+ norm_num_groups=resnet_groups,
484
+ use_linear_projection=use_linear_projection,
485
+ upcast_attention=upcast_attention,
486
+ attention_type=attention_type,
487
+ )
488
+ )
489
+ else:
490
+ attentions.append(
491
+ DualTransformer2DModel(
492
+ num_attention_heads,
493
+ in_channels // num_attention_heads,
494
+ in_channels=in_channels,
495
+ num_layers=1,
496
+ cross_attention_dim=cross_attention_dim,
497
+ norm_num_groups=resnet_groups,
498
+ )
499
+ )
500
+ resnets.append(
501
+ ResnetBlock2D(
502
+ in_channels=in_channels,
503
+ out_channels=in_channels,
504
+ temb_channels=temb_channels,
505
+ eps=resnet_eps,
506
+ groups=resnet_groups,
507
+ dropout=dropout,
508
+ time_embedding_norm=resnet_time_scale_shift,
509
+ non_linearity=resnet_act_fn,
510
+ output_scale_factor=output_scale_factor,
511
+ pre_norm=resnet_pre_norm,
512
+ )
513
+ )
514
+
515
+ self.attentions = nn.ModuleList(attentions)
516
+ self.resnets = nn.ModuleList(resnets)
517
+
518
+ self.gradient_checkpointing = False
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.FloatTensor,
523
+ temb: Optional[torch.FloatTensor] = None,
524
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
525
+ attention_mask: Optional[torch.FloatTensor] = None,
526
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
527
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
528
+ ) -> torch.FloatTensor:
529
+ """
530
+ Forward pass for the UNetMidBlock2DCrossAttn class.
531
+
532
+ Args:
533
+ hidden_states (torch.FloatTensor): The input hidden states tensor.
534
+ temb (Optional[torch.FloatTensor], optional): The optional tensor for time embeddings.
535
+ encoder_hidden_states (Optional[torch.FloatTensor], optional): The optional encoder hidden states tensor.
536
+ attention_mask (Optional[torch.FloatTensor], optional): The optional attention mask tensor.
537
+ cross_attention_kwargs (Optional[Dict[str, Any]], optional): The optional cross-attention kwargs tensor.
538
+ encoder_attention_mask (Optional[torch.FloatTensor], optional): The optional encoder attention mask tensor.
539
+
540
+ Returns:
541
+ torch.FloatTensor: The output tensor after passing through the UNetMidBlock2DCrossAttn layers.
542
+ """
543
+ lora_scale = (
544
+ cross_attention_kwargs.get("scale", 1.0)
545
+ if cross_attention_kwargs is not None
546
+ else 1.0
547
+ )
548
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
549
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
550
+ if self.training and self.gradient_checkpointing:
551
+
552
+ def create_custom_forward(module, return_dict=None):
553
+ def custom_forward(*inputs):
554
+ if return_dict is not None:
555
+ return module(*inputs, return_dict=return_dict)
556
+
557
+ return module(*inputs)
558
+
559
+ return custom_forward
560
+
561
+ ckpt_kwargs: Dict[str, Any] = (
562
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
563
+ )
564
+ hidden_states, _ref_feature = attn(
565
+ hidden_states,
566
+ encoder_hidden_states=encoder_hidden_states,
567
+ cross_attention_kwargs=cross_attention_kwargs,
568
+ attention_mask=attention_mask,
569
+ encoder_attention_mask=encoder_attention_mask,
570
+ return_dict=False,
571
+ )
572
+ hidden_states = torch.utils.checkpoint.checkpoint(
573
+ create_custom_forward(resnet),
574
+ hidden_states,
575
+ temb,
576
+ **ckpt_kwargs,
577
+ )
578
+ else:
579
+ hidden_states, _ref_feature = attn(
580
+ hidden_states,
581
+ encoder_hidden_states=encoder_hidden_states,
582
+ cross_attention_kwargs=cross_attention_kwargs,
583
+ attention_mask=attention_mask,
584
+ encoder_attention_mask=encoder_attention_mask,
585
+ return_dict=False,
586
+ )
587
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
588
+
589
+ return hidden_states
590
+
591
+
592
+ class CrossAttnDownBlock2D(nn.Module):
593
+ """
594
+ CrossAttnDownBlock2D is a class that represents a 2D cross-attention downsampling block.
595
+
596
+ This block is used in the UNet model and consists of a series of ResNet blocks and Transformer layers.
597
+ It takes input hidden states, a tensor embedding, and optional encoder hidden states, attention mask,
598
+ and cross-attention kwargs. The block performs a series of operations including downsampling, cross-attention,
599
+ and residual connections.
600
+
601
+ Attributes:
602
+ in_channels (int): The number of input channels.
603
+ out_channels (int): The number of output channels.
604
+ temb_channels (int): The number of tensor embedding channels.
605
+ dropout (float): The dropout rate.
606
+ num_layers (int): The number of ResNet layers.
607
+ transformer_layers_per_block (Union[int, Tuple[int]]): The number of Transformer layers per block.
608
+ resnet_eps (float): The ResNet epsilon value.
609
+ resnet_time_scale_shift (str): The ResNet time scale shift type.
610
+ resnet_act_fn (str): The ResNet activation function.
611
+ resnet_groups (int): The ResNet group size.
612
+ resnet_pre_norm (bool): Whether to use ResNet pre-normalization.
613
+ num_attention_heads (int): The number of attention heads.
614
+ cross_attention_dim (int): The cross-attention dimension.
615
+ output_scale_factor (float): The output scale factor.
616
+ downsample_padding (int): The downsampling padding.
617
+ add_downsample (bool): Whether to add downsampling.
618
+ dual_cross_attention (bool): Whether to use dual cross-attention.
619
+ use_linear_projection (bool): Whether to use linear projection.
620
+ only_cross_attention (bool): Whether to use only cross-attention.
621
+ upcast_attention (bool): Whether to upcast attention.
622
+ attention_type (str): The attention type.
623
+ """
624
+ def __init__(
625
+ self,
626
+ in_channels: int,
627
+ out_channels: int,
628
+ temb_channels: int,
629
+ dropout: float = 0.0,
630
+ num_layers: int = 1,
631
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
632
+ resnet_eps: float = 1e-6,
633
+ resnet_time_scale_shift: str = "default",
634
+ resnet_act_fn: str = "swish",
635
+ resnet_groups: int = 32,
636
+ resnet_pre_norm: bool = True,
637
+ num_attention_heads: int = 1,
638
+ cross_attention_dim: int = 1280,
639
+ output_scale_factor: float = 1.0,
640
+ downsample_padding: int = 1,
641
+ add_downsample: bool = True,
642
+ dual_cross_attention: bool = False,
643
+ use_linear_projection: bool = False,
644
+ only_cross_attention: bool = False,
645
+ upcast_attention: bool = False,
646
+ attention_type: str = "default",
647
+ ):
648
+ super().__init__()
649
+ resnets = []
650
+ attentions = []
651
+
652
+ self.has_cross_attention = True
653
+ self.num_attention_heads = num_attention_heads
654
+ if isinstance(transformer_layers_per_block, int):
655
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
656
+
657
+ for i in range(num_layers):
658
+ in_channels = in_channels if i == 0 else out_channels
659
+ resnets.append(
660
+ ResnetBlock2D(
661
+ in_channels=in_channels,
662
+ out_channels=out_channels,
663
+ temb_channels=temb_channels,
664
+ eps=resnet_eps,
665
+ groups=resnet_groups,
666
+ dropout=dropout,
667
+ time_embedding_norm=resnet_time_scale_shift,
668
+ non_linearity=resnet_act_fn,
669
+ output_scale_factor=output_scale_factor,
670
+ pre_norm=resnet_pre_norm,
671
+ )
672
+ )
673
+ if not dual_cross_attention:
674
+ attentions.append(
675
+ Transformer2DModel(
676
+ num_attention_heads,
677
+ out_channels // num_attention_heads,
678
+ in_channels=out_channels,
679
+ num_layers=transformer_layers_per_block[i],
680
+ cross_attention_dim=cross_attention_dim,
681
+ norm_num_groups=resnet_groups,
682
+ use_linear_projection=use_linear_projection,
683
+ only_cross_attention=only_cross_attention,
684
+ upcast_attention=upcast_attention,
685
+ attention_type=attention_type,
686
+ )
687
+ )
688
+ else:
689
+ attentions.append(
690
+ DualTransformer2DModel(
691
+ num_attention_heads,
692
+ out_channels // num_attention_heads,
693
+ in_channels=out_channels,
694
+ num_layers=1,
695
+ cross_attention_dim=cross_attention_dim,
696
+ norm_num_groups=resnet_groups,
697
+ )
698
+ )
699
+ self.attentions = nn.ModuleList(attentions)
700
+ self.resnets = nn.ModuleList(resnets)
701
+
702
+ if add_downsample:
703
+ self.downsamplers = nn.ModuleList(
704
+ [
705
+ Downsample2D(
706
+ out_channels,
707
+ use_conv=True,
708
+ out_channels=out_channels,
709
+ padding=downsample_padding,
710
+ name="op",
711
+ )
712
+ ]
713
+ )
714
+ else:
715
+ self.downsamplers = None
716
+
717
+ self.gradient_checkpointing = False
718
+
719
+ def forward(
720
+ self,
721
+ hidden_states: torch.FloatTensor,
722
+ temb: Optional[torch.FloatTensor] = None,
723
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
724
+ attention_mask: Optional[torch.FloatTensor] = None,
725
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
726
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
727
+ additional_residuals: Optional[torch.FloatTensor] = None,
728
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
729
+ """
730
+ Forward pass for the CrossAttnDownBlock2D class.
731
+
732
+ Args:
733
+ hidden_states (torch.FloatTensor): The input hidden states.
734
+ temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None.
735
+ encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states. Defaults to None.
736
+ attention_mask (Optional[torch.FloatTensor], optional): The attention mask. Defaults to None.
737
+ cross_attention_kwargs (Optional[Dict[str, Any]], optional): The cross-attention kwargs. Defaults to None.
738
+ encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask. Defaults to None.
739
+ additional_residuals (Optional[torch.FloatTensor], optional): The additional residuals. Defaults to None.
740
+
741
+ Returns:
742
+ Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output hidden states and residuals.
743
+ """
744
+ output_states = ()
745
+
746
+ lora_scale = (
747
+ cross_attention_kwargs.get("scale", 1.0)
748
+ if cross_attention_kwargs is not None
749
+ else 1.0
750
+ )
751
+
752
+ blocks = list(zip(self.resnets, self.attentions))
753
+
754
+ for i, (resnet, attn) in enumerate(blocks):
755
+ if self.training and self.gradient_checkpointing:
756
+
757
+ def create_custom_forward(module, return_dict=None):
758
+ def custom_forward(*inputs):
759
+ if return_dict is not None:
760
+ return module(*inputs, return_dict=return_dict)
761
+
762
+ return module(*inputs)
763
+
764
+ return custom_forward
765
+
766
+ ckpt_kwargs: Dict[str, Any] = (
767
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
768
+ )
769
+ hidden_states = torch.utils.checkpoint.checkpoint(
770
+ create_custom_forward(resnet),
771
+ hidden_states,
772
+ temb,
773
+ **ckpt_kwargs,
774
+ )
775
+ hidden_states, _ref_feature = attn(
776
+ hidden_states,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ cross_attention_kwargs=cross_attention_kwargs,
779
+ attention_mask=attention_mask,
780
+ encoder_attention_mask=encoder_attention_mask,
781
+ return_dict=False,
782
+ )
783
+ else:
784
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
785
+ hidden_states, _ref_feature = attn(
786
+ hidden_states,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ attention_mask=attention_mask,
790
+ encoder_attention_mask=encoder_attention_mask,
791
+ return_dict=False,
792
+ )
793
+
794
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
795
+ if i == len(blocks) - 1 and additional_residuals is not None:
796
+ hidden_states = hidden_states + additional_residuals
797
+
798
+ output_states = output_states + (hidden_states,)
799
+
800
+ if self.downsamplers is not None:
801
+ for downsampler in self.downsamplers:
802
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
803
+
804
+ output_states = output_states + (hidden_states,)
805
+
806
+ return hidden_states, output_states
807
+
808
+
809
+ class DownBlock2D(nn.Module):
810
+ """
811
+ DownBlock2D is a class that represents a 2D downsampling block in a neural network.
812
+
813
+ It takes the following parameters:
814
+ - in_channels (int): The number of input channels in the block.
815
+ - out_channels (int): The number of output channels in the block.
816
+ - temb_channels (int): The number of channels in the token embedding.
817
+ - dropout (float): The dropout rate for the block.
818
+ - num_layers (int): The number of layers in the block.
819
+ - resnet_eps (float): The epsilon value for the ResNet layer.
820
+ - resnet_time_scale_shift (str): The type of activation function for the ResNet layer.
821
+ - resnet_act_fn (str): The activation function for the ResNet layer.
822
+ - resnet_groups (int): The number of groups in the ResNet layer.
823
+ - resnet_pre_norm (bool): Whether to apply layer normalization before the ResNet layer.
824
+ - output_scale_factor (float): The scale factor for the output.
825
+ - add_downsample (bool): Whether to add a downsampling layer.
826
+ - downsample_padding (int): The padding value for the downsampling layer.
827
+
828
+ The DownBlock2D class inherits from the nn.Module class and defines the following methods:
829
+ - __init__: Initializes the DownBlock2D class with the given parameters.
830
+ - forward: Forward pass of the DownBlock2D class.
831
+
832
+ The forward method takes the following parameters:
833
+ - hidden_states (torch.FloatTensor): The input tensor to the block.
834
+ - temb (Optional[torch.FloatTensor]): The token embedding tensor.
835
+ - scale (float): The scale factor for the input tensor.
836
+
837
+ The forward method returns a tuple containing the output tensor and a tuple of hidden states.
838
+ """
839
+ def __init__(
840
+ self,
841
+ in_channels: int,
842
+ out_channels: int,
843
+ temb_channels: int,
844
+ dropout: float = 0.0,
845
+ num_layers: int = 1,
846
+ resnet_eps: float = 1e-6,
847
+ resnet_time_scale_shift: str = "default",
848
+ resnet_act_fn: str = "swish",
849
+ resnet_groups: int = 32,
850
+ resnet_pre_norm: bool = True,
851
+ output_scale_factor: float = 1.0,
852
+ add_downsample: bool = True,
853
+ downsample_padding: int = 1,
854
+ ):
855
+ super().__init__()
856
+ resnets = []
857
+
858
+ for i in range(num_layers):
859
+ in_channels = in_channels if i == 0 else out_channels
860
+ resnets.append(
861
+ ResnetBlock2D(
862
+ in_channels=in_channels,
863
+ out_channels=out_channels,
864
+ temb_channels=temb_channels,
865
+ eps=resnet_eps,
866
+ groups=resnet_groups,
867
+ dropout=dropout,
868
+ time_embedding_norm=resnet_time_scale_shift,
869
+ non_linearity=resnet_act_fn,
870
+ output_scale_factor=output_scale_factor,
871
+ pre_norm=resnet_pre_norm,
872
+ )
873
+ )
874
+
875
+ self.resnets = nn.ModuleList(resnets)
876
+
877
+ if add_downsample:
878
+ self.downsamplers = nn.ModuleList(
879
+ [
880
+ Downsample2D(
881
+ out_channels,
882
+ use_conv=True,
883
+ out_channels=out_channels,
884
+ padding=downsample_padding,
885
+ name="op",
886
+ )
887
+ ]
888
+ )
889
+ else:
890
+ self.downsamplers = None
891
+
892
+ self.gradient_checkpointing = False
893
+
894
+ def forward(
895
+ self,
896
+ hidden_states: torch.FloatTensor,
897
+ temb: Optional[torch.FloatTensor] = None,
898
+ scale: float = 1.0,
899
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
900
+ """
901
+ Forward pass of the DownBlock2D class.
902
+
903
+ Args:
904
+ hidden_states (torch.FloatTensor): The input tensor to the DownBlock2D layer.
905
+ temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None.
906
+ scale (float, optional): The scale factor for the input tensor. Defaults to 1.0.
907
+
908
+ Returns:
909
+ Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output tensor and any additional hidden states.
910
+ """
911
+ output_states = ()
912
+
913
+ for resnet in self.resnets:
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module):
917
+ def custom_forward(*inputs):
918
+ return module(*inputs)
919
+
920
+ return custom_forward
921
+
922
+ if is_torch_version(">=", "1.11.0"):
923
+ hidden_states = torch.utils.checkpoint.checkpoint(
924
+ create_custom_forward(resnet),
925
+ hidden_states,
926
+ temb,
927
+ use_reentrant=False,
928
+ )
929
+ else:
930
+ hidden_states = torch.utils.checkpoint.checkpoint(
931
+ create_custom_forward(resnet), hidden_states, temb
932
+ )
933
+ else:
934
+ hidden_states = resnet(hidden_states, temb, scale=scale)
935
+
936
+ output_states = output_states + (hidden_states,)
937
+
938
+ if self.downsamplers is not None:
939
+ for downsampler in self.downsamplers:
940
+ hidden_states = downsampler(hidden_states, scale=scale)
941
+
942
+ output_states = output_states + (hidden_states,)
943
+
944
+ return hidden_states, output_states
945
+
946
+
947
+ class CrossAttnUpBlock2D(nn.Module):
948
+ """
949
+ CrossAttnUpBlock2D is a class that represents a cross-attention UpBlock in a 2D UNet architecture.
950
+
951
+ This block is responsible for upsampling the input tensor and performing cross-attention with the encoder's hidden states.
952
+
953
+ Args:
954
+ in_channels (int): The number of input channels in the tensor.
955
+ out_channels (int): The number of output channels in the tensor.
956
+ prev_output_channel (int): The number of channels in the previous output tensor.
957
+ temb_channels (int): The number of channels in the token embedding tensor.
958
+ resolution_idx (Optional[int]): The index of the resolution in the model.
959
+ dropout (float): The dropout rate for the layer.
960
+ num_layers (int): The number of layers in the ResNet block.
961
+ transformer_layers_per_block (Union[int, Tuple[int]]): The number of transformer layers per block.
962
+ resnet_eps (float): The epsilon value for the ResNet layer.
963
+ resnet_time_scale_shift (str): The type of time scale shift to be applied in the ResNet layer.
964
+ resnet_act_fn (str): The activation function to be used in the ResNet layer.
965
+ resnet_groups (int): The number of groups in the ResNet layer.
966
+ resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet layer.
967
+ num_attention_heads (int): The number of attention heads in the cross-attention layer.
968
+ cross_attention_dim (int): The dimension of the cross-attention layer.
969
+ output_scale_factor (float): The scale factor for the output tensor.
970
+ add_upsample (bool): Whether to add upsampling to the block.
971
+ dual_cross_attention (bool): Whether to use dual cross-attention.
972
+ use_linear_projection (bool): Whether to use linear projection in the cross-attention layer.
973
+ only_cross_attention (bool): Whether to only use cross-attention and no self-attention.
974
+ upcast_attention (bool): Whether to upcast the attention weights.
975
+ attention_type (str): The type of attention to be used in the cross-attention layer.
976
+
977
+ Attributes:
978
+ up_block (nn.Module): The UpBlock module responsible for upsampling the input tensor.
979
+ cross_attn (nn.Module): The cross-attention module that performs attention between
980
+ the decoder's hidden states and the encoder's hidden states.
981
+ resnet_blocks (nn.ModuleList): A list of ResNet blocks that make up the ResNet portion of the block.
982
+ """
983
+
984
+ def __init__(
985
+ self,
986
+ in_channels: int,
987
+ out_channels: int,
988
+ prev_output_channel: int,
989
+ temb_channels: int,
990
+ resolution_idx: Optional[int] = None,
991
+ dropout: float = 0.0,
992
+ num_layers: int = 1,
993
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
994
+ resnet_eps: float = 1e-6,
995
+ resnet_time_scale_shift: str = "default",
996
+ resnet_act_fn: str = "swish",
997
+ resnet_groups: int = 32,
998
+ resnet_pre_norm: bool = True,
999
+ num_attention_heads: int = 1,
1000
+ cross_attention_dim: int = 1280,
1001
+ output_scale_factor: float = 1.0,
1002
+ add_upsample: bool = True,
1003
+ dual_cross_attention: bool = False,
1004
+ use_linear_projection: bool = False,
1005
+ only_cross_attention: bool = False,
1006
+ upcast_attention: bool = False,
1007
+ attention_type: str = "default",
1008
+ ):
1009
+ super().__init__()
1010
+ resnets = []
1011
+ attentions = []
1012
+
1013
+ self.has_cross_attention = True
1014
+ self.num_attention_heads = num_attention_heads
1015
+
1016
+ if isinstance(transformer_layers_per_block, int):
1017
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
1018
+
1019
+ for i in range(num_layers):
1020
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1021
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1022
+
1023
+ resnets.append(
1024
+ ResnetBlock2D(
1025
+ in_channels=resnet_in_channels + res_skip_channels,
1026
+ out_channels=out_channels,
1027
+ temb_channels=temb_channels,
1028
+ eps=resnet_eps,
1029
+ groups=resnet_groups,
1030
+ dropout=dropout,
1031
+ time_embedding_norm=resnet_time_scale_shift,
1032
+ non_linearity=resnet_act_fn,
1033
+ output_scale_factor=output_scale_factor,
1034
+ pre_norm=resnet_pre_norm,
1035
+ )
1036
+ )
1037
+ if not dual_cross_attention:
1038
+ attentions.append(
1039
+ Transformer2DModel(
1040
+ num_attention_heads,
1041
+ out_channels // num_attention_heads,
1042
+ in_channels=out_channels,
1043
+ num_layers=transformer_layers_per_block[i],
1044
+ cross_attention_dim=cross_attention_dim,
1045
+ norm_num_groups=resnet_groups,
1046
+ use_linear_projection=use_linear_projection,
1047
+ only_cross_attention=only_cross_attention,
1048
+ upcast_attention=upcast_attention,
1049
+ attention_type=attention_type,
1050
+ )
1051
+ )
1052
+ else:
1053
+ attentions.append(
1054
+ DualTransformer2DModel(
1055
+ num_attention_heads,
1056
+ out_channels // num_attention_heads,
1057
+ in_channels=out_channels,
1058
+ num_layers=1,
1059
+ cross_attention_dim=cross_attention_dim,
1060
+ norm_num_groups=resnet_groups,
1061
+ )
1062
+ )
1063
+ self.attentions = nn.ModuleList(attentions)
1064
+ self.resnets = nn.ModuleList(resnets)
1065
+
1066
+ if add_upsample:
1067
+ self.upsamplers = nn.ModuleList(
1068
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1069
+ )
1070
+ else:
1071
+ self.upsamplers = None
1072
+
1073
+ self.gradient_checkpointing = False
1074
+ self.resolution_idx = resolution_idx
1075
+
1076
+ def forward(
1077
+ self,
1078
+ hidden_states: torch.FloatTensor,
1079
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1080
+ temb: Optional[torch.FloatTensor] = None,
1081
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1082
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1083
+ upsample_size: Optional[int] = None,
1084
+ attention_mask: Optional[torch.FloatTensor] = None,
1085
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1086
+ ) -> torch.FloatTensor:
1087
+ """
1088
+ Forward pass for the CrossAttnUpBlock2D class.
1089
+
1090
+ Args:
1091
+ self (CrossAttnUpBlock2D): An instance of the CrossAttnUpBlock2D class.
1092
+ hidden_states (torch.FloatTensor): The input hidden states tensor.
1093
+ res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states tensors.
1094
+ temb (Optional[torch.FloatTensor], optional): The token embeddings tensor. Defaults to None.
1095
+ encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states tensor. Defaults to None.
1096
+ cross_attention_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for cross attention. Defaults to None.
1097
+ upsample_size (Optional[int], optional): The upsample size. Defaults to None.
1098
+ attention_mask (Optional[torch.FloatTensor], optional): The attention mask tensor. Defaults to None.
1099
+ encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask tensor. Defaults to None.
1100
+
1101
+ Returns:
1102
+ torch.FloatTensor: The output tensor after passing through the block.
1103
+ """
1104
+ lora_scale = (
1105
+ cross_attention_kwargs.get("scale", 1.0)
1106
+ if cross_attention_kwargs is not None
1107
+ else 1.0
1108
+ )
1109
+ is_freeu_enabled = (
1110
+ getattr(self, "s1", None)
1111
+ and getattr(self, "s2", None)
1112
+ and getattr(self, "b1", None)
1113
+ and getattr(self, "b2", None)
1114
+ )
1115
+
1116
+ for resnet, attn in zip(self.resnets, self.attentions):
1117
+ # pop res hidden states
1118
+ res_hidden_states = res_hidden_states_tuple[-1]
1119
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1120
+
1121
+ # FreeU: Only operate on the first two stages
1122
+ if is_freeu_enabled:
1123
+ hidden_states, res_hidden_states = apply_freeu(
1124
+ self.resolution_idx,
1125
+ hidden_states,
1126
+ res_hidden_states,
1127
+ s1=self.s1,
1128
+ s2=self.s2,
1129
+ b1=self.b1,
1130
+ b2=self.b2,
1131
+ )
1132
+
1133
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1134
+
1135
+ if self.training and self.gradient_checkpointing:
1136
+
1137
+ def create_custom_forward(module, return_dict=None):
1138
+ def custom_forward(*inputs):
1139
+ if return_dict is not None:
1140
+ return module(*inputs, return_dict=return_dict)
1141
+
1142
+ return module(*inputs)
1143
+
1144
+ return custom_forward
1145
+
1146
+ ckpt_kwargs: Dict[str, Any] = (
1147
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1148
+ )
1149
+ hidden_states = torch.utils.checkpoint.checkpoint(
1150
+ create_custom_forward(resnet),
1151
+ hidden_states,
1152
+ temb,
1153
+ **ckpt_kwargs,
1154
+ )
1155
+ hidden_states, _ref_feature = attn(
1156
+ hidden_states,
1157
+ encoder_hidden_states=encoder_hidden_states,
1158
+ cross_attention_kwargs=cross_attention_kwargs,
1159
+ attention_mask=attention_mask,
1160
+ encoder_attention_mask=encoder_attention_mask,
1161
+ return_dict=False,
1162
+ )
1163
+ else:
1164
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1165
+ hidden_states, _ref_feature = attn(
1166
+ hidden_states,
1167
+ encoder_hidden_states=encoder_hidden_states,
1168
+ cross_attention_kwargs=cross_attention_kwargs,
1169
+ attention_mask=attention_mask,
1170
+ encoder_attention_mask=encoder_attention_mask,
1171
+ return_dict=False,
1172
+ )
1173
+
1174
+ if self.upsamplers is not None:
1175
+ for upsampler in self.upsamplers:
1176
+ hidden_states = upsampler(
1177
+ hidden_states, upsample_size, scale=lora_scale
1178
+ )
1179
+
1180
+ return hidden_states
1181
+
1182
+
1183
+ class UpBlock2D(nn.Module):
1184
+ """
1185
+ UpBlock2D is a class that represents a 2D upsampling block in a neural network.
1186
+
1187
+ This block is used for upsampling the input tensor by a factor of 2 in both dimensions.
1188
+ It takes the previous output channel, input channels, and output channels as input
1189
+ and applies a series of convolutional layers, batch normalization, and activation
1190
+ functions to produce the upsampled tensor.
1191
+
1192
+ Args:
1193
+ in_channels (int): The number of input channels in the tensor.
1194
+ prev_output_channel (int): The number of channels in the previous output tensor.
1195
+ out_channels (int): The number of output channels in the tensor.
1196
+ temb_channels (int): The number of channels in the time embedding tensor.
1197
+ resolution_idx (Optional[int], optional): The index of the resolution in the sequence of resolutions. Defaults to None.
1198
+ dropout (float, optional): The dropout rate to be applied to the convolutional layers. Defaults to 0.0.
1199
+ num_layers (int, optional): The number of convolutional layers in the block. Defaults to 1.
1200
+ resnet_eps (float, optional): The epsilon value used in the batch normalization layer. Defaults to 1e-6.
1201
+ resnet_time_scale_shift (str, optional): The type of activation function to be applied after the convolutional layers. Defaults to "default".
1202
+ resnet_act_fn (str, optional): The activation function to be applied after the batch normalization layer. Defaults to "swish".
1203
+ resnet_groups (int, optional): The number of groups in the group normalization layer. Defaults to 32.
1204
+ resnet_pre_norm (bool, optional): A flag indicating whether to apply layer normalization before the activation function. Defaults to True.
1205
+ output_scale_factor (float, optional): The scale factor to be applied to the output tensor. Defaults to 1.0.
1206
+ add_upsample (bool, optional): A flag indicating whether to add an upsampling layer to the block. Defaults to True.
1207
+
1208
+ Attributes:
1209
+ layers (nn.ModuleList): A list of nn.Module objects representing the convolutional layers in the block.
1210
+ upsample (nn.Module): The upsampling layer in the block, if add_upsample is True.
1211
+
1212
+ """
1213
+
1214
+ def __init__(
1215
+ self,
1216
+ in_channels: int,
1217
+ prev_output_channel: int,
1218
+ out_channels: int,
1219
+ temb_channels: int,
1220
+ resolution_idx: Optional[int] = None,
1221
+ dropout: float = 0.0,
1222
+ num_layers: int = 1,
1223
+ resnet_eps: float = 1e-6,
1224
+ resnet_time_scale_shift: str = "default",
1225
+ resnet_act_fn: str = "swish",
1226
+ resnet_groups: int = 32,
1227
+ resnet_pre_norm: bool = True,
1228
+ output_scale_factor: float = 1.0,
1229
+ add_upsample: bool = True,
1230
+ ):
1231
+ super().__init__()
1232
+ resnets = []
1233
+
1234
+ for i in range(num_layers):
1235
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1236
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1237
+
1238
+ resnets.append(
1239
+ ResnetBlock2D(
1240
+ in_channels=resnet_in_channels + res_skip_channels,
1241
+ out_channels=out_channels,
1242
+ temb_channels=temb_channels,
1243
+ eps=resnet_eps,
1244
+ groups=resnet_groups,
1245
+ dropout=dropout,
1246
+ time_embedding_norm=resnet_time_scale_shift,
1247
+ non_linearity=resnet_act_fn,
1248
+ output_scale_factor=output_scale_factor,
1249
+ pre_norm=resnet_pre_norm,
1250
+ )
1251
+ )
1252
+
1253
+ self.resnets = nn.ModuleList(resnets)
1254
+
1255
+ if add_upsample:
1256
+ self.upsamplers = nn.ModuleList(
1257
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1258
+ )
1259
+ else:
1260
+ self.upsamplers = None
1261
+
1262
+ self.gradient_checkpointing = False
1263
+ self.resolution_idx = resolution_idx
1264
+
1265
+ def forward(
1266
+ self,
1267
+ hidden_states: torch.FloatTensor,
1268
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1269
+ temb: Optional[torch.FloatTensor] = None,
1270
+ upsample_size: Optional[int] = None,
1271
+ scale: float = 1.0,
1272
+ ) -> torch.FloatTensor:
1273
+
1274
+ """
1275
+ Forward pass for the UpBlock2D class.
1276
+
1277
+ Args:
1278
+ self (UpBlock2D): An instance of the UpBlock2D class.
1279
+ hidden_states (torch.FloatTensor): The input tensor to the block.
1280
+ res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states.
1281
+ temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None.
1282
+ upsample_size (Optional[int], optional): The size to upsample the input tensor to. Defaults to None.
1283
+ scale (float, optional): The scale factor to apply to the input tensor. Defaults to 1.0.
1284
+
1285
+ Returns:
1286
+ torch.FloatTensor: The output tensor after passing through the block.
1287
+ """
1288
+ is_freeu_enabled = (
1289
+ getattr(self, "s1", None)
1290
+ and getattr(self, "s2", None)
1291
+ and getattr(self, "b1", None)
1292
+ and getattr(self, "b2", None)
1293
+ )
1294
+
1295
+ for resnet in self.resnets:
1296
+ # pop res hidden states
1297
+ res_hidden_states = res_hidden_states_tuple[-1]
1298
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1299
+
1300
+ # FreeU: Only operate on the first two stages
1301
+ if is_freeu_enabled:
1302
+ hidden_states, res_hidden_states = apply_freeu(
1303
+ self.resolution_idx,
1304
+ hidden_states,
1305
+ res_hidden_states,
1306
+ s1=self.s1,
1307
+ s2=self.s2,
1308
+ b1=self.b1,
1309
+ b2=self.b2,
1310
+ )
1311
+
1312
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1313
+
1314
+ if self.training and self.gradient_checkpointing:
1315
+
1316
+ def create_custom_forward(module):
1317
+ def custom_forward(*inputs):
1318
+ return module(*inputs)
1319
+
1320
+ return custom_forward
1321
+
1322
+ if is_torch_version(">=", "1.11.0"):
1323
+ hidden_states = torch.utils.checkpoint.checkpoint(
1324
+ create_custom_forward(resnet),
1325
+ hidden_states,
1326
+ temb,
1327
+ use_reentrant=False,
1328
+ )
1329
+ else:
1330
+ hidden_states = torch.utils.checkpoint.checkpoint(
1331
+ create_custom_forward(resnet), hidden_states, temb
1332
+ )
1333
+ else:
1334
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1335
+
1336
+ if self.upsamplers is not None:
1337
+ for upsampler in self.upsamplers:
1338
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1339
+
1340
+ return hidden_states
joyhallo/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the `UNet2DConditionModel`,
3
+ a variant of the 2D U-Net architecture designed for conditional image generation tasks.
4
+ The model is capable of taking a noisy input sample and conditioning it based on additional information such as class labels,
5
+ time steps, and encoder hidden states to produce a denoised output.
6
+
7
+ The `UNet2DConditionModel` leverages various components such as time embeddings,
8
+ class embeddings, and cross-attention mechanisms to integrate the conditioning information effectively.
9
+ It is built upon several sub-blocks including down-blocks, a middle block, and up-blocks,
10
+ each responsible for different stages of the U-Net's downsampling and upsampling process.
11
+
12
+ Key Features:
13
+ - Support for multiple types of down and up blocks, including those with cross-attention capabilities.
14
+ - Flexible configuration of the model's layers, including the number of layers per block and the output channels for each block.
15
+ - Integration of time embeddings and class embeddings to condition the model's output on additional information.
16
+ - Implementation of cross-attention to leverage encoder hidden states for conditional generation.
17
+ - The model supports gradient checkpointing to reduce memory usage during training.
18
+
19
+ The module also includes utility functions and classes such as `UNet2DConditionOutput` for structured output
20
+ and `load_change_cross_attention_dim` for loading and modifying pre-trained models.
21
+
22
+ Example Usage:
23
+ >>> import torch
24
+ >>> from unet_2d_condition_model import UNet2DConditionModel
25
+ >>> model = UNet2DConditionModel(
26
+ ... sample_size=(64, 64),
27
+ ... in_channels=3,
28
+ ... out_channels=3,
29
+ ... encoder_hid_dim=512,
30
+ ... cross_attention_dim=1024,
31
+ ... )
32
+ >>> # Prepare input tensors
33
+ >>> sample = torch.randn(1, 3, 64, 64)
34
+ >>> timestep = 0
35
+ >>> encoder_hidden_states = torch.randn(1, 14, 512)
36
+ >>> # Forward pass through the model
37
+ >>> output = model(sample, timestep, encoder_hidden_states)
38
+
39
+ This module is part of a larger ecosystem of diffusion models and can be used for various conditional image generation tasks.
40
+ """
41
+
42
+ from dataclasses import dataclass
43
+ from os import PathLike
44
+ from pathlib import Path
45
+ from typing import Any, Dict, List, Optional, Tuple, Union
46
+
47
+ import torch
48
+ import torch.utils.checkpoint
49
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
50
+ from diffusers.loaders import UNet2DConditionLoadersMixin
51
+ from diffusers.models.activations import get_activation
52
+ from diffusers.models.attention_processor import (
53
+ ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS,
54
+ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
55
+ from diffusers.models.embeddings import (GaussianFourierProjection,
56
+ GLIGENTextBoundingboxProjection,
57
+ ImageHintTimeEmbedding,
58
+ ImageProjection, ImageTimeEmbedding,
59
+ TextImageProjection,
60
+ TextImageTimeEmbedding,
61
+ TextTimeEmbedding, TimestepEmbedding,
62
+ Timesteps)
63
+ from diffusers.models.modeling_utils import ModelMixin
64
+ from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND,
65
+ WEIGHTS_NAME, BaseOutput, deprecate, logging,
66
+ scale_lora_layers, unscale_lora_layers)
67
+ from safetensors.torch import load_file
68
+ from torch import nn
69
+
70
+ from .unet_2d_blocks import (UNetMidBlock2D, UNetMidBlock2DCrossAttn,
71
+ get_down_block, get_up_block)
72
+
73
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
74
+
75
+ @dataclass
76
+ class UNet2DConditionOutput(BaseOutput):
77
+ """
78
+ The output of [`UNet2DConditionModel`].
79
+
80
+ Args:
81
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
82
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
83
+ """
84
+
85
+ sample: torch.FloatTensor = None
86
+ ref_features: Tuple[torch.FloatTensor] = None
87
+
88
+
89
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
90
+ r"""
91
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
92
+ shaped output.
93
+
94
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
95
+ for all models (such as downloading or saving).
96
+
97
+ Parameters:
98
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
99
+ Height and width of input/output sample.
100
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
101
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
102
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
103
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
104
+ Whether to flip the sin to cos in the time embedding.
105
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
106
+ down_block_types (`Tuple[str]`, *optional*, defaults to
107
+ `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
108
+ The tuple of downsample blocks to use.
109
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
110
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
111
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
112
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
113
+ The tuple of upsample blocks to use.
114
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
115
+ Whether to include self-attention in the basic transformer blocks, see
116
+ [`~models.attention.BasicTransformerBlock`].
117
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
118
+ The tuple of output channels for each block.
119
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
120
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
121
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
122
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
123
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
124
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
125
+ If `None`, normalization and activation layers is skipped in post-processing.
126
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
127
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
128
+ The dimension of the cross attention features.
129
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
130
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
131
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
132
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
133
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
134
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
135
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
136
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
137
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
138
+ encoder_hid_dim (`int`, *optional*, defaults to None):
139
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
140
+ dimension to `cross_attention_dim`.
141
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
142
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
143
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
144
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
145
+ num_attention_heads (`int`, *optional*):
146
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
147
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
148
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
149
+ class_embed_type (`str`, *optional*, defaults to `None`):
150
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
151
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
152
+ addition_embed_type (`str`, *optional*, defaults to `None`):
153
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
154
+ "text". "text" will use the `TextTimeEmbedding` layer.
155
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
156
+ Dimension for the timestep embeddings.
157
+ num_class_embeds (`int`, *optional*, defaults to `None`):
158
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
159
+ class conditioning with `class_embed_type` equal to `None`.
160
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
161
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
162
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
163
+ An optional override for the dimension of the projected time embedding.
164
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
165
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
166
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
167
+ timestep_post_act (`str`, *optional*, defaults to `None`):
168
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
169
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
170
+ The dimension of `cond_proj` layer in the timestep embedding.
171
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
172
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
173
+ *optional*): The dimension of the `class_labels` input when
174
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
175
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
176
+ embeddings with the class embeddings.
177
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
178
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
179
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
180
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
181
+ otherwise.
182
+ """
183
+
184
+ _supports_gradient_checkpointing = True
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ sample_size: Optional[int] = None,
190
+ in_channels: int = 4,
191
+ _out_channels: int = 4,
192
+ _center_input_sample: bool = False,
193
+ flip_sin_to_cos: bool = True,
194
+ freq_shift: int = 0,
195
+ down_block_types: Tuple[str] = (
196
+ "CrossAttnDownBlock2D",
197
+ "CrossAttnDownBlock2D",
198
+ "CrossAttnDownBlock2D",
199
+ "DownBlock2D",
200
+ ),
201
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
202
+ up_block_types: Tuple[str] = (
203
+ "UpBlock2D",
204
+ "CrossAttnUpBlock2D",
205
+ "CrossAttnUpBlock2D",
206
+ "CrossAttnUpBlock2D",
207
+ ),
208
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
209
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
210
+ layers_per_block: Union[int, Tuple[int]] = 2,
211
+ downsample_padding: int = 1,
212
+ mid_block_scale_factor: float = 1,
213
+ dropout: float = 0.0,
214
+ act_fn: str = "silu",
215
+ norm_num_groups: Optional[int] = 32,
216
+ norm_eps: float = 1e-5,
217
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
218
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
219
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
220
+ encoder_hid_dim: Optional[int] = None,
221
+ encoder_hid_dim_type: Optional[str] = None,
222
+ attention_head_dim: Union[int, Tuple[int]] = 8,
223
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
224
+ dual_cross_attention: bool = False,
225
+ use_linear_projection: bool = False,
226
+ class_embed_type: Optional[str] = None,
227
+ addition_embed_type: Optional[str] = None,
228
+ addition_time_embed_dim: Optional[int] = None,
229
+ num_class_embeds: Optional[int] = None,
230
+ upcast_attention: bool = False,
231
+ resnet_time_scale_shift: str = "default",
232
+ time_embedding_type: str = "positional",
233
+ time_embedding_dim: Optional[int] = None,
234
+ time_embedding_act_fn: Optional[str] = None,
235
+ timestep_post_act: Optional[str] = None,
236
+ time_cond_proj_dim: Optional[int] = None,
237
+ conv_in_kernel: int = 3,
238
+ projection_class_embeddings_input_dim: Optional[int] = None,
239
+ attention_type: str = "default",
240
+ class_embeddings_concat: bool = False,
241
+ mid_block_only_cross_attention: Optional[bool] = None,
242
+ addition_embed_type_num_heads=64,
243
+ _landmark_net=False,
244
+ ):
245
+ super().__init__()
246
+
247
+ self.sample_size = sample_size
248
+
249
+ if num_attention_heads is not None:
250
+ raise ValueError(
251
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
252
+ "because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131."
253
+ "Passing `num_attention_heads` will only be supported in diffusers v0.19."
254
+ )
255
+
256
+ # If `num_attention_heads` is not defined (which is the case for most models)
257
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
258
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
259
+ # when this library was created. The incorrect naming was only discovered much later in
260
+ # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
261
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
262
+ # which is why we correct for the naming here.
263
+ num_attention_heads = num_attention_heads or attention_head_dim
264
+
265
+ # Check inputs
266
+ if len(down_block_types) != len(up_block_types):
267
+ raise ValueError(
268
+ "Must provide the same number of `down_block_types` as `up_block_types`."
269
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
270
+ )
271
+
272
+ if len(block_out_channels) != len(down_block_types):
273
+ raise ValueError(
274
+ "Must provide the same number of `block_out_channels` as `down_block_types`."
275
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(only_cross_attention, bool) and len(
279
+ only_cross_attention
280
+ ) != len(down_block_types):
281
+ raise ValueError(
282
+ "Must provide the same number of `only_cross_attention` as `down_block_types`."
283
+ f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
284
+ )
285
+
286
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
287
+ down_block_types
288
+ ):
289
+ raise ValueError(
290
+ "Must provide the same number of `num_attention_heads` as `down_block_types`."
291
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
292
+ )
293
+
294
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
295
+ down_block_types
296
+ ):
297
+ raise ValueError(
298
+ "Must provide the same number of `attention_head_dim` as `down_block_types`."
299
+ f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
300
+ )
301
+
302
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
303
+ down_block_types
304
+ ):
305
+ raise ValueError(
306
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`."
307
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
308
+ )
309
+
310
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
311
+ down_block_types
312
+ ):
313
+ raise ValueError(
314
+ "Must provide the same number of `layers_per_block` as `down_block_types`."
315
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
316
+ )
317
+ if (
318
+ isinstance(transformer_layers_per_block, list)
319
+ and reverse_transformer_layers_per_block is None
320
+ ):
321
+ for layer_number_per_block in transformer_layers_per_block:
322
+ if isinstance(layer_number_per_block, list):
323
+ raise ValueError(
324
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
325
+ )
326
+
327
+ # input
328
+ conv_in_padding = (conv_in_kernel - 1) // 2
329
+ self.conv_in = nn.Conv2d(
330
+ in_channels,
331
+ block_out_channels[0],
332
+ kernel_size=conv_in_kernel,
333
+ padding=conv_in_padding,
334
+ )
335
+
336
+ # time
337
+ if time_embedding_type == "fourier":
338
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
339
+ if time_embed_dim % 2 != 0:
340
+ raise ValueError(
341
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
342
+ )
343
+ self.time_proj = GaussianFourierProjection(
344
+ time_embed_dim // 2,
345
+ set_W_to_weight=False,
346
+ log=False,
347
+ flip_sin_to_cos=flip_sin_to_cos,
348
+ )
349
+ timestep_input_dim = time_embed_dim
350
+ elif time_embedding_type == "positional":
351
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
352
+
353
+ self.time_proj = Timesteps(
354
+ block_out_channels[0], flip_sin_to_cos, freq_shift
355
+ )
356
+ timestep_input_dim = block_out_channels[0]
357
+ else:
358
+ raise ValueError(
359
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
360
+ )
361
+
362
+ self.time_embedding = TimestepEmbedding(
363
+ timestep_input_dim,
364
+ time_embed_dim,
365
+ act_fn=act_fn,
366
+ post_act_fn=timestep_post_act,
367
+ cond_proj_dim=time_cond_proj_dim,
368
+ )
369
+
370
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
371
+ encoder_hid_dim_type = "text_proj"
372
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
373
+ logger.info(
374
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
375
+ )
376
+
377
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
378
+ raise ValueError(
379
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
380
+ )
381
+
382
+ if encoder_hid_dim_type == "text_proj":
383
+ self.encoder_hid_proj = nn.Linear(
384
+ encoder_hid_dim, cross_attention_dim)
385
+ elif encoder_hid_dim_type == "text_image_proj":
386
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
387
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
388
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
389
+ self.encoder_hid_proj = TextImageProjection(
390
+ text_embed_dim=encoder_hid_dim,
391
+ image_embed_dim=cross_attention_dim,
392
+ cross_attention_dim=cross_attention_dim,
393
+ )
394
+ elif encoder_hid_dim_type == "image_proj":
395
+ # Kandinsky 2.2
396
+ self.encoder_hid_proj = ImageProjection(
397
+ image_embed_dim=encoder_hid_dim,
398
+ cross_attention_dim=cross_attention_dim,
399
+ )
400
+ elif encoder_hid_dim_type is not None:
401
+ raise ValueError(
402
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
403
+ )
404
+ else:
405
+ self.encoder_hid_proj = None
406
+
407
+ # class embedding
408
+ if class_embed_type is None and num_class_embeds is not None:
409
+ self.class_embedding = nn.Embedding(
410
+ num_class_embeds, time_embed_dim)
411
+ elif class_embed_type == "timestep":
412
+ self.class_embedding = TimestepEmbedding(
413
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
414
+ )
415
+ elif class_embed_type == "identity":
416
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
417
+ elif class_embed_type == "projection":
418
+ if projection_class_embeddings_input_dim is None:
419
+ raise ValueError(
420
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
421
+ )
422
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
423
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
424
+ # 2. it projects from an arbitrary input dimension.
425
+ #
426
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
427
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
428
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
429
+ self.class_embedding = TimestepEmbedding(
430
+ projection_class_embeddings_input_dim, time_embed_dim
431
+ )
432
+ elif class_embed_type == "simple_projection":
433
+ if projection_class_embeddings_input_dim is None:
434
+ raise ValueError(
435
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
436
+ )
437
+ self.class_embedding = nn.Linear(
438
+ projection_class_embeddings_input_dim, time_embed_dim
439
+ )
440
+ else:
441
+ self.class_embedding = None
442
+
443
+ if addition_embed_type == "text":
444
+ if encoder_hid_dim is not None:
445
+ text_time_embedding_from_dim = encoder_hid_dim
446
+ else:
447
+ text_time_embedding_from_dim = cross_attention_dim
448
+
449
+ self.add_embedding = TextTimeEmbedding(
450
+ text_time_embedding_from_dim,
451
+ time_embed_dim,
452
+ num_heads=addition_embed_type_num_heads,
453
+ )
454
+ elif addition_embed_type == "text_image":
455
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
456
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
457
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
458
+ self.add_embedding = TextImageTimeEmbedding(
459
+ text_embed_dim=cross_attention_dim,
460
+ image_embed_dim=cross_attention_dim,
461
+ time_embed_dim=time_embed_dim,
462
+ )
463
+ elif addition_embed_type == "text_time":
464
+ self.add_time_proj = Timesteps(
465
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
466
+ )
467
+ self.add_embedding = TimestepEmbedding(
468
+ projection_class_embeddings_input_dim, time_embed_dim
469
+ )
470
+ elif addition_embed_type == "image":
471
+ # Kandinsky 2.2
472
+ self.add_embedding = ImageTimeEmbedding(
473
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
474
+ )
475
+ elif addition_embed_type == "image_hint":
476
+ # Kandinsky 2.2 ControlNet
477
+ self.add_embedding = ImageHintTimeEmbedding(
478
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
479
+ )
480
+ elif addition_embed_type is not None:
481
+ raise ValueError(
482
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
483
+ )
484
+
485
+ if time_embedding_act_fn is None:
486
+ self.time_embed_act = None
487
+ else:
488
+ self.time_embed_act = get_activation(time_embedding_act_fn)
489
+
490
+ self.down_blocks = nn.ModuleList([])
491
+ self.up_blocks = nn.ModuleList([])
492
+
493
+ if isinstance(only_cross_attention, bool):
494
+ if mid_block_only_cross_attention is None:
495
+ mid_block_only_cross_attention = only_cross_attention
496
+
497
+ only_cross_attention = [
498
+ only_cross_attention] * len(down_block_types)
499
+
500
+ if mid_block_only_cross_attention is None:
501
+ mid_block_only_cross_attention = False
502
+
503
+ if isinstance(num_attention_heads, int):
504
+ num_attention_heads = (num_attention_heads,) * \
505
+ len(down_block_types)
506
+
507
+ if isinstance(attention_head_dim, int):
508
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
509
+
510
+ if isinstance(cross_attention_dim, int):
511
+ cross_attention_dim = (cross_attention_dim,) * \
512
+ len(down_block_types)
513
+
514
+ if isinstance(layers_per_block, int):
515
+ layers_per_block = [layers_per_block] * len(down_block_types)
516
+
517
+ if isinstance(transformer_layers_per_block, int):
518
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
519
+ down_block_types
520
+ )
521
+
522
+ if class_embeddings_concat:
523
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
524
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
525
+ # regular time embeddings
526
+ blocks_time_embed_dim = time_embed_dim * 2
527
+ else:
528
+ blocks_time_embed_dim = time_embed_dim
529
+
530
+ # down
531
+ output_channel = block_out_channels[0]
532
+ for i, down_block_type in enumerate(down_block_types):
533
+ input_channel = output_channel
534
+ output_channel = block_out_channels[i]
535
+ is_final_block = i == len(block_out_channels) - 1
536
+
537
+ down_block = get_down_block(
538
+ down_block_type,
539
+ num_layers=layers_per_block[i],
540
+ transformer_layers_per_block=transformer_layers_per_block[i],
541
+ in_channels=input_channel,
542
+ out_channels=output_channel,
543
+ temb_channels=blocks_time_embed_dim,
544
+ add_downsample=not is_final_block,
545
+ resnet_eps=norm_eps,
546
+ resnet_act_fn=act_fn,
547
+ resnet_groups=norm_num_groups,
548
+ cross_attention_dim=cross_attention_dim[i],
549
+ num_attention_heads=num_attention_heads[i],
550
+ downsample_padding=downsample_padding,
551
+ dual_cross_attention=dual_cross_attention,
552
+ use_linear_projection=use_linear_projection,
553
+ only_cross_attention=only_cross_attention[i],
554
+ upcast_attention=upcast_attention,
555
+ resnet_time_scale_shift=resnet_time_scale_shift,
556
+ attention_type=attention_type,
557
+ attention_head_dim=(
558
+ attention_head_dim[i]
559
+ if attention_head_dim[i] is not None
560
+ else output_channel
561
+ ),
562
+ dropout=dropout,
563
+ )
564
+ self.down_blocks.append(down_block)
565
+
566
+ # mid
567
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
568
+ self.mid_block = UNetMidBlock2DCrossAttn(
569
+ transformer_layers_per_block=transformer_layers_per_block[-1],
570
+ in_channels=block_out_channels[-1],
571
+ temb_channels=blocks_time_embed_dim,
572
+ dropout=dropout,
573
+ resnet_eps=norm_eps,
574
+ resnet_act_fn=act_fn,
575
+ output_scale_factor=mid_block_scale_factor,
576
+ resnet_time_scale_shift=resnet_time_scale_shift,
577
+ cross_attention_dim=cross_attention_dim[-1],
578
+ num_attention_heads=num_attention_heads[-1],
579
+ resnet_groups=norm_num_groups,
580
+ dual_cross_attention=dual_cross_attention,
581
+ use_linear_projection=use_linear_projection,
582
+ upcast_attention=upcast_attention,
583
+ attention_type=attention_type,
584
+ )
585
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
586
+ raise NotImplementedError(
587
+ f"Unsupport mid_block_type: {mid_block_type}")
588
+ elif mid_block_type == "UNetMidBlock2D":
589
+ self.mid_block = UNetMidBlock2D(
590
+ in_channels=block_out_channels[-1],
591
+ temb_channels=blocks_time_embed_dim,
592
+ dropout=dropout,
593
+ num_layers=0,
594
+ resnet_eps=norm_eps,
595
+ resnet_act_fn=act_fn,
596
+ output_scale_factor=mid_block_scale_factor,
597
+ resnet_groups=norm_num_groups,
598
+ resnet_time_scale_shift=resnet_time_scale_shift,
599
+ add_attention=False,
600
+ )
601
+ elif mid_block_type is None:
602
+ self.mid_block = None
603
+ else:
604
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
605
+
606
+ # count how many layers upsample the images
607
+ self.num_upsamplers = 0
608
+
609
+ # up
610
+ reversed_block_out_channels = list(reversed(block_out_channels))
611
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
612
+ reversed_layers_per_block = list(reversed(layers_per_block))
613
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
614
+ reversed_transformer_layers_per_block = (
615
+ list(reversed(transformer_layers_per_block))
616
+ if reverse_transformer_layers_per_block is None
617
+ else reverse_transformer_layers_per_block
618
+ )
619
+ only_cross_attention = list(reversed(only_cross_attention))
620
+
621
+ output_channel = reversed_block_out_channels[0]
622
+ for i, up_block_type in enumerate(up_block_types):
623
+ is_final_block = i == len(block_out_channels) - 1
624
+
625
+ prev_output_channel = output_channel
626
+ output_channel = reversed_block_out_channels[i]
627
+ input_channel = reversed_block_out_channels[
628
+ min(i + 1, len(block_out_channels) - 1)
629
+ ]
630
+
631
+ # add upsample block for all BUT final layer
632
+ if not is_final_block:
633
+ add_upsample = True
634
+ self.num_upsamplers += 1
635
+ else:
636
+ add_upsample = False
637
+
638
+ up_block = get_up_block(
639
+ up_block_type,
640
+ num_layers=reversed_layers_per_block[i] + 1,
641
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
642
+ in_channels=input_channel,
643
+ out_channels=output_channel,
644
+ prev_output_channel=prev_output_channel,
645
+ temb_channels=blocks_time_embed_dim,
646
+ add_upsample=add_upsample,
647
+ resnet_eps=norm_eps,
648
+ resnet_act_fn=act_fn,
649
+ resolution_idx=i,
650
+ resnet_groups=norm_num_groups,
651
+ cross_attention_dim=reversed_cross_attention_dim[i],
652
+ num_attention_heads=reversed_num_attention_heads[i],
653
+ dual_cross_attention=dual_cross_attention,
654
+ use_linear_projection=use_linear_projection,
655
+ only_cross_attention=only_cross_attention[i],
656
+ upcast_attention=upcast_attention,
657
+ resnet_time_scale_shift=resnet_time_scale_shift,
658
+ attention_type=attention_type,
659
+ attention_head_dim=(
660
+ attention_head_dim[i]
661
+ if attention_head_dim[i] is not None
662
+ else output_channel
663
+ ),
664
+ dropout=dropout,
665
+ )
666
+ self.up_blocks.append(up_block)
667
+ prev_output_channel = output_channel
668
+
669
+ # out
670
+ if norm_num_groups is not None:
671
+ self.conv_norm_out = nn.GroupNorm(
672
+ num_channels=block_out_channels[0],
673
+ num_groups=norm_num_groups,
674
+ eps=norm_eps,
675
+ )
676
+
677
+ self.conv_act = get_activation(act_fn)
678
+
679
+ else:
680
+ self.conv_norm_out = None
681
+ self.conv_act = None
682
+ self.conv_norm_out = None
683
+
684
+ if attention_type in ["gated", "gated-text-image"]:
685
+ positive_len = 768
686
+ if isinstance(cross_attention_dim, int):
687
+ positive_len = cross_attention_dim
688
+ elif isinstance(cross_attention_dim, (tuple, list)):
689
+ positive_len = cross_attention_dim[0]
690
+
691
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
692
+ self.position_net = GLIGENTextBoundingboxProjection(
693
+ positive_len=positive_len,
694
+ out_dim=cross_attention_dim,
695
+ feature_type=feature_type,
696
+ )
697
+
698
+ @property
699
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
700
+ r"""
701
+ Returns:
702
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
703
+ indexed by its weight name.
704
+ """
705
+ # set recursively
706
+ processors = {}
707
+
708
+ def fn_recursive_add_processors(
709
+ name: str,
710
+ module: torch.nn.Module,
711
+ processors: Dict[str, AttentionProcessor],
712
+ ):
713
+ if hasattr(module, "get_processor"):
714
+ processors[f"{name}.processor"] = module.get_processor(
715
+ return_deprecated_lora=True
716
+ )
717
+
718
+ for sub_name, child in module.named_children():
719
+ fn_recursive_add_processors(
720
+ f"{name}.{sub_name}", child, processors)
721
+
722
+ return processors
723
+
724
+ for name, module in self.named_children():
725
+ fn_recursive_add_processors(name, module, processors)
726
+
727
+ return processors
728
+
729
+ def set_attn_processor(
730
+ self,
731
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
732
+ _remove_lora=False,
733
+ ):
734
+ r"""
735
+ Sets the attention processor to use to compute attention.
736
+
737
+ Parameters:
738
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
739
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
740
+ for **all** `Attention` layers.
741
+
742
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
743
+ processor. This is strongly recommended when setting trainable attention processors.
744
+
745
+ """
746
+ count = len(self.attn_processors.keys())
747
+
748
+ if isinstance(processor, dict) and len(processor) != count:
749
+ raise ValueError(
750
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
751
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
752
+ )
753
+
754
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
755
+ if hasattr(module, "set_processor"):
756
+ if not isinstance(processor, dict):
757
+ module.set_processor(processor, _remove_lora=_remove_lora)
758
+ else:
759
+ module.set_processor(
760
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
761
+ )
762
+
763
+ for sub_name, child in module.named_children():
764
+ fn_recursive_attn_processor(
765
+ f"{name}.{sub_name}", child, processor)
766
+
767
+ for name, module in self.named_children():
768
+ fn_recursive_attn_processor(name, module, processor)
769
+
770
+ def set_default_attn_processor(self):
771
+ """
772
+ Disables custom attention processors and sets the default attention implementation.
773
+ """
774
+ if all(
775
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
776
+ for proc in self.attn_processors.values()
777
+ ):
778
+ processor = AttnAddedKVProcessor()
779
+ elif all(
780
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
781
+ for proc in self.attn_processors.values()
782
+ ):
783
+ processor = AttnProcessor()
784
+ else:
785
+ raise ValueError(
786
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
787
+ )
788
+
789
+ self.set_attn_processor(processor, _remove_lora=True)
790
+
791
+ def set_attention_slice(self, slice_size):
792
+ r"""
793
+ Enable sliced attention computation.
794
+
795
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
796
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
797
+
798
+ Args:
799
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
800
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
801
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
802
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
803
+ must be a multiple of `slice_size`.
804
+ """
805
+ sliceable_head_dims = []
806
+
807
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
808
+ if hasattr(module, "set_attention_slice"):
809
+ sliceable_head_dims.append(module.sliceable_head_dim)
810
+
811
+ for child in module.children():
812
+ fn_recursive_retrieve_sliceable_dims(child)
813
+
814
+ # retrieve number of attention layers
815
+ for module in self.children():
816
+ fn_recursive_retrieve_sliceable_dims(module)
817
+
818
+ num_sliceable_layers = len(sliceable_head_dims)
819
+
820
+ if slice_size == "auto":
821
+ # half the attention head size is usually a good trade-off between
822
+ # speed and memory
823
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
824
+ elif slice_size == "max":
825
+ # make smallest slice possible
826
+ slice_size = num_sliceable_layers * [1]
827
+
828
+ slice_size = (
829
+ num_sliceable_layers * [slice_size]
830
+ if not isinstance(slice_size, list)
831
+ else slice_size
832
+ )
833
+
834
+ if len(slice_size) != len(sliceable_head_dims):
835
+ raise ValueError(
836
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
837
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
838
+ )
839
+
840
+ for i, size in enumerate(slice_size):
841
+ dim = sliceable_head_dims[i]
842
+ if size is not None and size > dim:
843
+ raise ValueError(
844
+ f"size {size} has to be smaller or equal to {dim}.")
845
+
846
+ # Recursively walk through all the children.
847
+ # Any children which exposes the set_attention_slice method
848
+ # gets the message
849
+ def fn_recursive_set_attention_slice(
850
+ module: torch.nn.Module, slice_size: List[int]
851
+ ):
852
+ if hasattr(module, "set_attention_slice"):
853
+ module.set_attention_slice(slice_size.pop())
854
+
855
+ for child in module.children():
856
+ fn_recursive_set_attention_slice(child, slice_size)
857
+
858
+ reversed_slice_size = list(reversed(slice_size))
859
+ for module in self.children():
860
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
861
+
862
+ def _set_gradient_checkpointing(self, module, value=False):
863
+ if hasattr(module, "gradient_checkpointing"):
864
+ module.gradient_checkpointing = value
865
+
866
+ def enable_freeu(self, s1, s2, b1, b2):
867
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
868
+
869
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
870
+
871
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
872
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
873
+
874
+ Args:
875
+ s1 (`float`):
876
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
877
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
878
+ s2 (`float`):
879
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
880
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
881
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
882
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
883
+ """
884
+ for _, upsample_block in enumerate(self.up_blocks):
885
+ setattr(upsample_block, "s1", s1)
886
+ setattr(upsample_block, "s2", s2)
887
+ setattr(upsample_block, "b1", b1)
888
+ setattr(upsample_block, "b2", b2)
889
+
890
+ def disable_freeu(self):
891
+ """Disables the FreeU mechanism."""
892
+ freeu_keys = {"s1", "s2", "b1", "b2"}
893
+ for _, upsample_block in enumerate(self.up_blocks):
894
+ for k in freeu_keys:
895
+ if (
896
+ hasattr(upsample_block, k)
897
+ or getattr(upsample_block, k, None) is not None
898
+ ):
899
+ setattr(upsample_block, k, None)
900
+
901
+ def forward(
902
+ self,
903
+ sample: torch.FloatTensor,
904
+ timestep: Union[torch.Tensor, float, int],
905
+ encoder_hidden_states: torch.Tensor,
906
+ cond_tensor: torch.FloatTensor=None,
907
+ class_labels: Optional[torch.Tensor] = None,
908
+ timestep_cond: Optional[torch.Tensor] = None,
909
+ attention_mask: Optional[torch.Tensor] = None,
910
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
911
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
912
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
913
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
914
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
915
+ encoder_attention_mask: Optional[torch.Tensor] = None,
916
+ return_dict: bool = True,
917
+ post_process: bool = False,
918
+ ) -> Union[UNet2DConditionOutput, Tuple]:
919
+ r"""
920
+ The [`UNet2DConditionModel`] forward method.
921
+
922
+ Args:
923
+ sample (`torch.FloatTensor`):
924
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
925
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
926
+ encoder_hidden_states (`torch.FloatTensor`):
927
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
928
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
929
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
930
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
931
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
932
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
933
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
934
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
935
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
936
+ negative values to the attention scores corresponding to "discard" tokens.
937
+ cross_attention_kwargs (`dict`, *optional*):
938
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
939
+ `self.processor` in
940
+ [diffusers.models.attention_processor]
941
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
942
+ added_cond_kwargs: (`dict`, *optional*):
943
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
944
+ are passed along to the UNet blocks.
945
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
946
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
947
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
948
+ A tensor that if specified is added to the residual of the middle unet block.
949
+ encoder_attention_mask (`torch.Tensor`):
950
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
951
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
952
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
953
+ return_dict (`bool`, *optional*, defaults to `True`):
954
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
955
+ tuple.
956
+ cross_attention_kwargs (`dict`, *optional*):
957
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
958
+ added_cond_kwargs: (`dict`, *optional*):
959
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
960
+ are passed along to the UNet blocks.
961
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
962
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
963
+ example from ControlNet side model(s)
964
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
965
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
966
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
967
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
968
+
969
+ Returns:
970
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
971
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
972
+ a `tuple` is returned where the first element is the sample tensor.
973
+ """
974
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
975
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
976
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
977
+ # on the fly if necessary.
978
+ default_overall_up_factor = 2**self.num_upsamplers
979
+
980
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
981
+ forward_upsample_size = False
982
+ upsample_size = None
983
+
984
+ for dim in sample.shape[-2:]:
985
+ if dim % default_overall_up_factor != 0:
986
+ # Forward upsample size to force interpolation output size.
987
+ forward_upsample_size = True
988
+ break
989
+
990
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
991
+ # expects mask of shape:
992
+ # [batch, key_tokens]
993
+ # adds singleton query_tokens dimension:
994
+ # [batch, 1, key_tokens]
995
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
996
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
997
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
998
+ if attention_mask is not None:
999
+ # assume that mask is expressed as:
1000
+ # (1 = keep, 0 = discard)
1001
+ # convert mask into a bias that can be added to attention scores:
1002
+ # (keep = +0, discard = -10000.0)
1003
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1004
+ attention_mask = attention_mask.unsqueeze(1)
1005
+
1006
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1007
+ if encoder_attention_mask is not None:
1008
+ encoder_attention_mask = (
1009
+ 1 - encoder_attention_mask.to(sample.dtype)
1010
+ ) * -10000.0
1011
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1012
+
1013
+ # 0. center input if necessary
1014
+ if self.config.center_input_sample:
1015
+ sample = 2 * sample - 1.0
1016
+
1017
+ # 1. time
1018
+ timesteps = timestep
1019
+ if not torch.is_tensor(timesteps):
1020
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1021
+ # This would be a good case for the `match` statement (Python 3.10+)
1022
+ is_mps = sample.device.type == "mps"
1023
+ if isinstance(timestep, float):
1024
+ dtype = torch.float32 if is_mps else torch.float64
1025
+ else:
1026
+ dtype = torch.int32 if is_mps else torch.int64
1027
+ timesteps = torch.tensor(
1028
+ [timesteps], dtype=dtype, device=sample.device)
1029
+ elif len(timesteps.shape) == 0:
1030
+ timesteps = timesteps[None].to(sample.device)
1031
+
1032
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1033
+ timesteps = timesteps.expand(sample.shape[0])
1034
+
1035
+ t_emb = self.time_proj(timesteps)
1036
+
1037
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1038
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1039
+ # there might be better ways to encapsulate this.
1040
+ t_emb = t_emb.to(dtype=sample.dtype)
1041
+
1042
+ emb = self.time_embedding(t_emb, timestep_cond)
1043
+ aug_emb = None
1044
+
1045
+ if self.class_embedding is not None:
1046
+ if class_labels is None:
1047
+ raise ValueError(
1048
+ "class_labels should be provided when num_class_embeds > 0"
1049
+ )
1050
+
1051
+ if self.config.class_embed_type == "timestep":
1052
+ class_labels = self.time_proj(class_labels)
1053
+
1054
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1055
+ # there might be better ways to encapsulate this.
1056
+ class_labels = class_labels.to(dtype=sample.dtype)
1057
+
1058
+ class_emb = self.class_embedding(
1059
+ class_labels).to(dtype=sample.dtype)
1060
+
1061
+ if self.config.class_embeddings_concat:
1062
+ emb = torch.cat([emb, class_emb], dim=-1)
1063
+ else:
1064
+ emb = emb + class_emb
1065
+
1066
+ if self.config.addition_embed_type == "text":
1067
+ aug_emb = self.add_embedding(encoder_hidden_states)
1068
+ elif self.config.addition_embed_type == "text_image":
1069
+ # Kandinsky 2.1 - style
1070
+ if "image_embeds" not in added_cond_kwargs:
1071
+ raise ValueError(
1072
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image'"
1073
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1074
+ )
1075
+
1076
+ image_embs = added_cond_kwargs.get("image_embeds")
1077
+ text_embs = added_cond_kwargs.get(
1078
+ "text_embeds", encoder_hidden_states)
1079
+ aug_emb = self.add_embedding(text_embs, image_embs)
1080
+ elif self.config.addition_embed_type == "text_time":
1081
+ # SDXL - style
1082
+ if "text_embeds" not in added_cond_kwargs:
1083
+ raise ValueError(
1084
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
1085
+ "which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1086
+ )
1087
+ text_embeds = added_cond_kwargs.get("text_embeds")
1088
+ if "time_ids" not in added_cond_kwargs:
1089
+ raise ValueError(
1090
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'"
1091
+ "which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1092
+ )
1093
+ time_ids = added_cond_kwargs.get("time_ids")
1094
+ time_embeds = self.add_time_proj(time_ids.flatten())
1095
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1096
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1097
+ add_embeds = add_embeds.to(emb.dtype)
1098
+ aug_emb = self.add_embedding(add_embeds)
1099
+ elif self.config.addition_embed_type == "image":
1100
+ # Kandinsky 2.2 - style
1101
+ if "image_embeds" not in added_cond_kwargs:
1102
+ raise ValueError(
1103
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image'"
1104
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1105
+ )
1106
+ image_embs = added_cond_kwargs.get("image_embeds")
1107
+ aug_emb = self.add_embedding(image_embs)
1108
+ elif self.config.addition_embed_type == "image_hint":
1109
+ # Kandinsky 2.2 - style
1110
+ if (
1111
+ "image_embeds" not in added_cond_kwargs
1112
+ or "hint" not in added_cond_kwargs
1113
+ ):
1114
+ raise ValueError(
1115
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint'"
1116
+ "which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1117
+ )
1118
+ image_embs = added_cond_kwargs.get("image_embeds")
1119
+ hint = added_cond_kwargs.get("hint")
1120
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1121
+ sample = torch.cat([sample, hint], dim=1)
1122
+
1123
+ emb = emb + aug_emb if aug_emb is not None else emb
1124
+
1125
+ if self.time_embed_act is not None:
1126
+ emb = self.time_embed_act(emb)
1127
+
1128
+ if (
1129
+ self.encoder_hid_proj is not None
1130
+ and self.config.encoder_hid_dim_type == "text_proj"
1131
+ ):
1132
+ encoder_hidden_states = self.encoder_hid_proj(
1133
+ encoder_hidden_states)
1134
+ elif (
1135
+ self.encoder_hid_proj is not None
1136
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1137
+ ):
1138
+ # Kadinsky 2.1 - style
1139
+ if "image_embeds" not in added_cond_kwargs:
1140
+ raise ValueError(
1141
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj'"
1142
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1143
+ )
1144
+
1145
+ image_embeds = added_cond_kwargs.get("image_embeds")
1146
+ encoder_hidden_states = self.encoder_hid_proj(
1147
+ encoder_hidden_states, image_embeds
1148
+ )
1149
+ elif (
1150
+ self.encoder_hid_proj is not None
1151
+ and self.config.encoder_hid_dim_type == "image_proj"
1152
+ ):
1153
+ # Kandinsky 2.2 - style
1154
+ if "image_embeds" not in added_cond_kwargs:
1155
+ raise ValueError(
1156
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj'"
1157
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1158
+ )
1159
+ image_embeds = added_cond_kwargs.get("image_embeds")
1160
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1161
+ elif (
1162
+ self.encoder_hid_proj is not None
1163
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1164
+ ):
1165
+ if "image_embeds" not in added_cond_kwargs:
1166
+ raise ValueError(
1167
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj'"
1168
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1169
+ )
1170
+ image_embeds = added_cond_kwargs.get("image_embeds")
1171
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1172
+ encoder_hidden_states.dtype
1173
+ )
1174
+ encoder_hidden_states = torch.cat(
1175
+ [encoder_hidden_states, image_embeds], dim=1
1176
+ )
1177
+
1178
+ # 2. pre-process
1179
+ sample = self.conv_in(sample)
1180
+ if cond_tensor is not None:
1181
+ sample = sample + cond_tensor
1182
+
1183
+ # 2.5 GLIGEN position net
1184
+ if (
1185
+ cross_attention_kwargs is not None
1186
+ and cross_attention_kwargs.get("gligen", None) is not None
1187
+ ):
1188
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1189
+ gligen_args = cross_attention_kwargs.pop("gligen")
1190
+ cross_attention_kwargs["gligen"] = {
1191
+ "objs": self.position_net(**gligen_args)
1192
+ }
1193
+
1194
+ # 3. down
1195
+ lora_scale = (
1196
+ cross_attention_kwargs.get("scale", 1.0)
1197
+ if cross_attention_kwargs is not None
1198
+ else 1.0
1199
+ )
1200
+ if USE_PEFT_BACKEND:
1201
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1202
+ scale_lora_layers(self, lora_scale)
1203
+
1204
+ is_controlnet = (
1205
+ mid_block_additional_residual is not None
1206
+ and down_block_additional_residuals is not None
1207
+ )
1208
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1209
+ is_adapter = down_intrablock_additional_residuals is not None
1210
+ # maintain backward compatibility for legacy usage, where
1211
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1212
+ # but can only use one or the other
1213
+ if (
1214
+ not is_adapter
1215
+ and mid_block_additional_residual is None
1216
+ and down_block_additional_residuals is not None
1217
+ ):
1218
+ deprecate(
1219
+ "T2I should not use down_block_additional_residuals",
1220
+ "1.3.0",
1221
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1222
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1223
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1224
+ standard_warn=False,
1225
+ )
1226
+ down_intrablock_additional_residuals = down_block_additional_residuals
1227
+ is_adapter = True
1228
+
1229
+ down_block_res_samples = (sample,)
1230
+ for downsample_block in self.down_blocks:
1231
+ if (
1232
+ hasattr(downsample_block, "has_cross_attention")
1233
+ and downsample_block.has_cross_attention
1234
+ ):
1235
+ # For t2i-adapter CrossAttnDownBlock2D
1236
+ additional_residuals = {}
1237
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1238
+ additional_residuals["additional_residuals"] = (
1239
+ down_intrablock_additional_residuals.pop(0)
1240
+ )
1241
+
1242
+ sample, res_samples = downsample_block(
1243
+ hidden_states=sample,
1244
+ temb=emb,
1245
+ encoder_hidden_states=encoder_hidden_states,
1246
+ attention_mask=attention_mask,
1247
+ cross_attention_kwargs=cross_attention_kwargs,
1248
+ encoder_attention_mask=encoder_attention_mask,
1249
+ **additional_residuals,
1250
+ )
1251
+ else:
1252
+ sample, res_samples = downsample_block(
1253
+ hidden_states=sample, temb=emb, scale=lora_scale
1254
+ )
1255
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1256
+ sample += down_intrablock_additional_residuals.pop(0)
1257
+
1258
+ down_block_res_samples += res_samples
1259
+
1260
+ if is_controlnet:
1261
+ new_down_block_res_samples = ()
1262
+
1263
+ for down_block_res_sample, down_block_additional_residual in zip(
1264
+ down_block_res_samples, down_block_additional_residuals
1265
+ ):
1266
+ down_block_res_sample = (
1267
+ down_block_res_sample + down_block_additional_residual
1268
+ )
1269
+ new_down_block_res_samples = new_down_block_res_samples + (
1270
+ down_block_res_sample,
1271
+ )
1272
+
1273
+ down_block_res_samples = new_down_block_res_samples
1274
+
1275
+ # 4. mid
1276
+ if self.mid_block is not None:
1277
+ if (
1278
+ hasattr(self.mid_block, "has_cross_attention")
1279
+ and self.mid_block.has_cross_attention
1280
+ ):
1281
+ sample = self.mid_block(
1282
+ sample,
1283
+ emb,
1284
+ encoder_hidden_states=encoder_hidden_states,
1285
+ attention_mask=attention_mask,
1286
+ cross_attention_kwargs=cross_attention_kwargs,
1287
+ encoder_attention_mask=encoder_attention_mask,
1288
+ )
1289
+ else:
1290
+ sample = self.mid_block(sample, emb)
1291
+
1292
+ # To support T2I-Adapter-XL
1293
+ if (
1294
+ is_adapter
1295
+ and len(down_intrablock_additional_residuals) > 0
1296
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1297
+ ):
1298
+ sample += down_intrablock_additional_residuals.pop(0)
1299
+
1300
+ if is_controlnet:
1301
+ sample = sample + mid_block_additional_residual
1302
+
1303
+ # 5. up
1304
+ for i, upsample_block in enumerate(self.up_blocks):
1305
+ is_final_block = i == len(self.up_blocks) - 1
1306
+
1307
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
1308
+ down_block_res_samples = down_block_res_samples[
1309
+ : -len(upsample_block.resnets)
1310
+ ]
1311
+
1312
+ # if we have not reached the final block and need to forward the
1313
+ # upsample size, we do it here
1314
+ if not is_final_block and forward_upsample_size:
1315
+ upsample_size = down_block_res_samples[-1].shape[2:]
1316
+
1317
+ if (
1318
+ hasattr(upsample_block, "has_cross_attention")
1319
+ and upsample_block.has_cross_attention
1320
+ ):
1321
+ sample = upsample_block(
1322
+ hidden_states=sample,
1323
+ temb=emb,
1324
+ res_hidden_states_tuple=res_samples,
1325
+ encoder_hidden_states=encoder_hidden_states,
1326
+ cross_attention_kwargs=cross_attention_kwargs,
1327
+ upsample_size=upsample_size,
1328
+ attention_mask=attention_mask,
1329
+ encoder_attention_mask=encoder_attention_mask,
1330
+ )
1331
+ else:
1332
+ sample = upsample_block(
1333
+ hidden_states=sample,
1334
+ temb=emb,
1335
+ res_hidden_states_tuple=res_samples,
1336
+ upsample_size=upsample_size,
1337
+ scale=lora_scale,
1338
+ )
1339
+
1340
+ # 6. post-process
1341
+ if post_process:
1342
+ if self.conv_norm_out:
1343
+ sample = self.conv_norm_out(sample)
1344
+ sample = self.conv_act(sample)
1345
+ sample = self.conv_out(sample)
1346
+
1347
+ if USE_PEFT_BACKEND:
1348
+ # remove `lora_scale` from each PEFT layer
1349
+ unscale_lora_layers(self, lora_scale)
1350
+
1351
+ if not return_dict:
1352
+ return (sample,)
1353
+
1354
+ return UNet2DConditionOutput(sample=sample)
1355
+
1356
+ @classmethod
1357
+ def load_change_cross_attention_dim(
1358
+ cls,
1359
+ pretrained_model_path: PathLike,
1360
+ subfolder=None,
1361
+ # unet_additional_kwargs=None,
1362
+ ):
1363
+ """
1364
+ Load or change the cross-attention dimension of a pre-trained model.
1365
+
1366
+ Parameters:
1367
+ pretrained_model_name_or_path (:class:`~typing.Union[str, :class:`~pathlib.Path`]`):
1368
+ The identifier of the pre-trained model or the path to the local folder containing the model.
1369
+ force_download (:class:`~bool`):
1370
+ If True, re-download the model even if it is already cached.
1371
+ resume_download (:class:`~bool`):
1372
+ If True, resume the download of the model if partially downloaded.
1373
+ proxies (:class:`~dict`):
1374
+ A dictionary of proxy servers to use for downloading the model.
1375
+ cache_dir (:class:`~Optional[str]`):
1376
+ The path to the cache directory for storing downloaded models.
1377
+ use_auth_token (:class:`~bool`):
1378
+ If True, use the authentication token for private models.
1379
+ revision (:class:`~str`):
1380
+ The specific model version to use.
1381
+ use_safetensors (:class:`~bool`):
1382
+ If True, use the SafeTensors format for loading the model weights.
1383
+ **kwargs (:class:`~dict`):
1384
+ Additional keyword arguments passed to the model.
1385
+
1386
+ """
1387
+ pretrained_model_path = Path(pretrained_model_path)
1388
+ if subfolder is not None:
1389
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
1390
+ config_file = pretrained_model_path / "config.json"
1391
+ if not (config_file.exists() and config_file.is_file()):
1392
+ raise RuntimeError(
1393
+ f"{config_file} does not exist or is not a file")
1394
+
1395
+ unet_config = cls.load_config(config_file)
1396
+ unet_config["cross_attention_dim"] = 1024
1397
+
1398
+ model = cls.from_config(unet_config)
1399
+ # load the vanilla weights
1400
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
1401
+ logger.debug(
1402
+ f"loading safeTensors weights from {pretrained_model_path} ..."
1403
+ )
1404
+ state_dict = load_file(
1405
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
1406
+ )
1407
+
1408
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
1409
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
1410
+ state_dict = torch.load(
1411
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
1412
+ map_location="cpu",
1413
+ weights_only=True,
1414
+ )
1415
+ else:
1416
+ raise FileNotFoundError(
1417
+ f"no weights file found in {pretrained_model_path}")
1418
+
1419
+ model_state_dict = model.state_dict()
1420
+ for k in state_dict:
1421
+ if k in model_state_dict:
1422
+ if state_dict[k].shape != model_state_dict[k].shape:
1423
+ state_dict[k] = model_state_dict[k]
1424
+ # load the weights into the model
1425
+ m, u = model.load_state_dict(state_dict, strict=False)
1426
+ print(m, u)
1427
+
1428
+ return model
joyhallo/models/unet_3d.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is the main file for the UNet3DConditionModel, which defines the UNet3D model architecture.
3
+
4
+ The UNet3D model is a 3D convolutional neural network designed for image segmentation and
5
+ other computer vision tasks. It consists of an encoder, a decoder, and skip connections between
6
+ the corresponding layers of the encoder and decoder. The model can handle 3D data and
7
+ performs well on tasks such as image segmentation, object detection, and video analysis.
8
+
9
+ This file contains the necessary imports, the main UNet3DConditionModel class, and its
10
+ methods for setting attention slice, setting gradient checkpointing, setting attention
11
+ processor, and the forward method for model inference.
12
+
13
+ The module provides a comprehensive solution for 3D image segmentation tasks and can be
14
+ easily extended for other computer vision tasks as well.
15
+ """
16
+
17
+ from collections import OrderedDict
18
+ from dataclasses import dataclass
19
+ from os import PathLike
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.utils.checkpoint
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.models.attention_processor import AttentionProcessor
28
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME,
31
+ BaseOutput, logging)
32
+ from safetensors.torch import load_file
33
+
34
+ from .resnet import InflatedConv3d, InflatedGroupNorm
35
+ from .unet_3d_blocks import (UNetMidBlock3DCrossAttn, get_down_block,
36
+ get_up_block)
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ @dataclass
42
+ class UNet3DConditionOutput(BaseOutput):
43
+ """
44
+ Data class that serves as the output of the UNet3DConditionModel.
45
+
46
+ Attributes:
47
+ sample (`torch.FloatTensor`):
48
+ A tensor representing the processed sample. The shape and nature of this tensor will depend on the
49
+ specific configuration of the model and the input data.
50
+ """
51
+ sample: torch.FloatTensor
52
+
53
+
54
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
55
+ """
56
+ A 3D UNet model designed to handle conditional image and video generation tasks. This model is particularly
57
+ suited for tasks that require the generation of 3D data, such as volumetric medical imaging or 3D video
58
+ generation, while incorporating additional conditioning information.
59
+
60
+ The model consists of an encoder-decoder structure with skip connections. It utilizes a series of downsampling
61
+ and upsampling blocks, with a middle block for further processing. Each block can be customized with different
62
+ types of layers and attention mechanisms.
63
+
64
+ Parameters:
65
+ sample_size (`int`, optional): The size of the input sample.
66
+ in_channels (`int`, defaults to 8): The number of input channels.
67
+ out_channels (`int`, defaults to 8): The number of output channels.
68
+ center_input_sample (`bool`, defaults to False): Whether to center the input sample.
69
+ flip_sin_to_cos (`bool`, defaults to True): Whether to flip the sine to cosine in the time embedding.
70
+ freq_shift (`int`, defaults to 0): The frequency shift for the time embedding.
71
+ down_block_types (`Tuple[str]`): A tuple of strings specifying the types of downsampling blocks.
72
+ mid_block_type (`str`): The type of middle block.
73
+ up_block_types (`Tuple[str]`): A tuple of strings specifying the types of upsampling blocks.
74
+ only_cross_attention (`Union[bool, Tuple[bool]]`): Whether to use only cross-attention.
75
+ block_out_channels (`Tuple[int]`): A tuple of integers specifying the output channels for each block.
76
+ layers_per_block (`int`, defaults to 2): The number of layers per block.
77
+ downsample_padding (`int`, defaults to 1): The padding used in downsampling.
78
+ mid_block_scale_factor (`float`, defaults to 1): The scale factor for the middle block.
79
+ act_fn (`str`, defaults to 'silu'): The activation function to be used.
80
+ norm_num_groups (`int`, defaults to 32): The number of groups for normalization.
81
+ norm_eps (`float`, defaults to 1e-5): The epsilon for normalization.
82
+ cross_attention_dim (`int`, defaults to 1280): The dimension for cross-attention.
83
+ attention_head_dim (`Union[int, Tuple[int]]`): The dimension for attention heads.
84
+ dual_cross_attention (`bool`, defaults to False): Whether to use dual cross-attention.
85
+ use_linear_projection (`bool`, defaults to False): Whether to use linear projection.
86
+ class_embed_type (`str`, optional): The type of class embedding.
87
+ num_class_embeds (`int`, optional): The number of class embeddings.
88
+ upcast_attention (`bool`, defaults to False): Whether to upcast attention.
89
+ resnet_time_scale_shift (`str`, defaults to 'default'): The time scale shift for the ResNet.
90
+ use_inflated_groupnorm (`bool`, defaults to False): Whether to use inflated group normalization.
91
+ use_motion_module (`bool`, defaults to False): Whether to use a motion module.
92
+ motion_module_resolutions (`Tuple[int]`): A tuple of resolutions for the motion module.
93
+ motion_module_mid_block (`bool`, defaults to False): Whether to use a motion module in the middle block.
94
+ motion_module_decoder_only (`bool`, defaults to False): Whether to use the motion module only in the decoder.
95
+ motion_module_type (`str`, optional): The type of motion module.
96
+ motion_module_kwargs (`dict`): Keyword arguments for the motion module.
97
+ unet_use_cross_frame_attention (`bool`, optional): Whether to use cross-frame attention in the UNet.
98
+ unet_use_temporal_attention (`bool`, optional): Whether to use temporal attention in the UNet.
99
+ use_audio_module (`bool`, defaults to False): Whether to use an audio module.
100
+ audio_attention_dim (`int`, defaults to 768): The dimension for audio attention.
101
+
102
+ The model supports various features such as gradient checkpointing, attention processors, and sliced attention
103
+ computation, making it flexible and efficient for different computational requirements and use cases.
104
+
105
+ The forward method of the model accepts a sample, timestep, and encoder hidden states as input, and it returns
106
+ the processed sample as output. The method also supports additional conditioning information such as class
107
+ labels, audio embeddings, and masks for specialized tasks.
108
+
109
+ The from_pretrained_2d class method allows loading a pre-trained 2D UNet model and adapting it for 3D tasks by
110
+ incorporating motion modules and other 3D specific features.
111
+ """
112
+
113
+ _supports_gradient_checkpointing = True
114
+
115
+ @register_to_config
116
+ def __init__(
117
+ self,
118
+ sample_size: Optional[int] = None,
119
+ in_channels: int = 8,
120
+ out_channels: int = 8,
121
+ flip_sin_to_cos: bool = True,
122
+ freq_shift: int = 0,
123
+ down_block_types: Tuple[str] = (
124
+ "CrossAttnDownBlock3D",
125
+ "CrossAttnDownBlock3D",
126
+ "CrossAttnDownBlock3D",
127
+ "DownBlock3D",
128
+ ),
129
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
130
+ up_block_types: Tuple[str] = (
131
+ "UpBlock3D",
132
+ "CrossAttnUpBlock3D",
133
+ "CrossAttnUpBlock3D",
134
+ "CrossAttnUpBlock3D",
135
+ ),
136
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
137
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
138
+ layers_per_block: int = 2,
139
+ downsample_padding: int = 1,
140
+ mid_block_scale_factor: float = 1,
141
+ act_fn: str = "silu",
142
+ norm_num_groups: int = 32,
143
+ norm_eps: float = 1e-5,
144
+ cross_attention_dim: int = 1280,
145
+ attention_head_dim: Union[int, Tuple[int]] = 8,
146
+ dual_cross_attention: bool = False,
147
+ use_linear_projection: bool = False,
148
+ class_embed_type: Optional[str] = None,
149
+ num_class_embeds: Optional[int] = None,
150
+ upcast_attention: bool = False,
151
+ resnet_time_scale_shift: str = "default",
152
+ use_inflated_groupnorm=False,
153
+ # Additional
154
+ use_motion_module=False,
155
+ motion_module_resolutions=(1, 2, 4, 8),
156
+ motion_module_mid_block=False,
157
+ motion_module_decoder_only=False,
158
+ motion_module_type=None,
159
+ motion_module_kwargs=None,
160
+ unet_use_cross_frame_attention=None,
161
+ unet_use_temporal_attention=None,
162
+ # audio
163
+ use_audio_module=False,
164
+ audio_attention_dim=768,
165
+ stack_enable_blocks_name=None,
166
+ stack_enable_blocks_depth=None,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.sample_size = sample_size
171
+ time_embed_dim = block_out_channels[0] * 4
172
+
173
+ # input
174
+ self.conv_in = InflatedConv3d(
175
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
176
+ )
177
+
178
+ # time
179
+ self.time_proj = Timesteps(
180
+ block_out_channels[0], flip_sin_to_cos, freq_shift)
181
+ timestep_input_dim = block_out_channels[0]
182
+
183
+ self.time_embedding = TimestepEmbedding(
184
+ timestep_input_dim, time_embed_dim)
185
+
186
+ # class embedding
187
+ if class_embed_type is None and num_class_embeds is not None:
188
+ self.class_embedding = nn.Embedding(
189
+ num_class_embeds, time_embed_dim)
190
+ elif class_embed_type == "timestep":
191
+ self.class_embedding = TimestepEmbedding(
192
+ timestep_input_dim, time_embed_dim)
193
+ elif class_embed_type == "identity":
194
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
195
+ else:
196
+ self.class_embedding = None
197
+
198
+ self.down_blocks = nn.ModuleList([])
199
+ self.mid_block = None
200
+ self.up_blocks = nn.ModuleList([])
201
+
202
+ if isinstance(only_cross_attention, bool):
203
+ only_cross_attention = [
204
+ only_cross_attention] * len(down_block_types)
205
+
206
+ if isinstance(attention_head_dim, int):
207
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
208
+
209
+ # down
210
+ output_channel = block_out_channels[0]
211
+ for i, down_block_type in enumerate(down_block_types):
212
+ res = 2**i
213
+ input_channel = output_channel
214
+ output_channel = block_out_channels[i]
215
+ is_final_block = i == len(block_out_channels) - 1
216
+
217
+ down_block = get_down_block(
218
+ down_block_type,
219
+ num_layers=layers_per_block,
220
+ in_channels=input_channel,
221
+ out_channels=output_channel,
222
+ temb_channels=time_embed_dim,
223
+ add_downsample=not is_final_block,
224
+ resnet_eps=norm_eps,
225
+ resnet_act_fn=act_fn,
226
+ resnet_groups=norm_num_groups,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=attention_head_dim[i],
229
+ downsample_padding=downsample_padding,
230
+ dual_cross_attention=dual_cross_attention,
231
+ use_linear_projection=use_linear_projection,
232
+ only_cross_attention=only_cross_attention[i],
233
+ upcast_attention=upcast_attention,
234
+ resnet_time_scale_shift=resnet_time_scale_shift,
235
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
236
+ unet_use_temporal_attention=unet_use_temporal_attention,
237
+ use_inflated_groupnorm=use_inflated_groupnorm,
238
+ use_motion_module=use_motion_module
239
+ and (res in motion_module_resolutions)
240
+ and (not motion_module_decoder_only),
241
+ motion_module_type=motion_module_type,
242
+ motion_module_kwargs=motion_module_kwargs,
243
+ use_audio_module=use_audio_module,
244
+ audio_attention_dim=audio_attention_dim,
245
+ depth=i,
246
+ stack_enable_blocks_name=stack_enable_blocks_name,
247
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
248
+ )
249
+ self.down_blocks.append(down_block)
250
+
251
+ # mid
252
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
253
+ self.mid_block = UNetMidBlock3DCrossAttn(
254
+ in_channels=block_out_channels[-1],
255
+ temb_channels=time_embed_dim,
256
+ resnet_eps=norm_eps,
257
+ resnet_act_fn=act_fn,
258
+ output_scale_factor=mid_block_scale_factor,
259
+ resnet_time_scale_shift=resnet_time_scale_shift,
260
+ cross_attention_dim=cross_attention_dim,
261
+ attn_num_head_channels=attention_head_dim[-1],
262
+ resnet_groups=norm_num_groups,
263
+ dual_cross_attention=dual_cross_attention,
264
+ use_linear_projection=use_linear_projection,
265
+ upcast_attention=upcast_attention,
266
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
267
+ unet_use_temporal_attention=unet_use_temporal_attention,
268
+ use_inflated_groupnorm=use_inflated_groupnorm,
269
+ use_motion_module=use_motion_module and motion_module_mid_block,
270
+ motion_module_type=motion_module_type,
271
+ motion_module_kwargs=motion_module_kwargs,
272
+ use_audio_module=use_audio_module,
273
+ audio_attention_dim=audio_attention_dim,
274
+ depth=3,
275
+ stack_enable_blocks_name=stack_enable_blocks_name,
276
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
277
+ )
278
+ else:
279
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
280
+
281
+ # count how many layers upsample the videos
282
+ self.num_upsamplers = 0
283
+
284
+ # up
285
+ reversed_block_out_channels = list(reversed(block_out_channels))
286
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
287
+ only_cross_attention = list(reversed(only_cross_attention))
288
+ output_channel = reversed_block_out_channels[0]
289
+ for i, up_block_type in enumerate(up_block_types):
290
+ res = 2 ** (3 - i)
291
+ is_final_block = i == len(block_out_channels) - 1
292
+
293
+ prev_output_channel = output_channel
294
+ output_channel = reversed_block_out_channels[i]
295
+ input_channel = reversed_block_out_channels[
296
+ min(i + 1, len(block_out_channels) - 1)
297
+ ]
298
+
299
+ # add upsample block for all BUT final layer
300
+ if not is_final_block:
301
+ add_upsample = True
302
+ self.num_upsamplers += 1
303
+ else:
304
+ add_upsample = False
305
+
306
+ up_block = get_up_block(
307
+ up_block_type,
308
+ num_layers=layers_per_block + 1,
309
+ in_channels=input_channel,
310
+ out_channels=output_channel,
311
+ prev_output_channel=prev_output_channel,
312
+ temb_channels=time_embed_dim,
313
+ add_upsample=add_upsample,
314
+ resnet_eps=norm_eps,
315
+ resnet_act_fn=act_fn,
316
+ resnet_groups=norm_num_groups,
317
+ cross_attention_dim=cross_attention_dim,
318
+ attn_num_head_channels=reversed_attention_head_dim[i],
319
+ dual_cross_attention=dual_cross_attention,
320
+ use_linear_projection=use_linear_projection,
321
+ only_cross_attention=only_cross_attention[i],
322
+ upcast_attention=upcast_attention,
323
+ resnet_time_scale_shift=resnet_time_scale_shift,
324
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
325
+ unet_use_temporal_attention=unet_use_temporal_attention,
326
+ use_inflated_groupnorm=use_inflated_groupnorm,
327
+ use_motion_module=use_motion_module
328
+ and (res in motion_module_resolutions),
329
+ motion_module_type=motion_module_type,
330
+ motion_module_kwargs=motion_module_kwargs,
331
+ use_audio_module=use_audio_module,
332
+ audio_attention_dim=audio_attention_dim,
333
+ depth=3-i,
334
+ stack_enable_blocks_name=stack_enable_blocks_name,
335
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
336
+ )
337
+ self.up_blocks.append(up_block)
338
+ prev_output_channel = output_channel
339
+
340
+ # out
341
+ if use_inflated_groupnorm:
342
+ self.conv_norm_out = InflatedGroupNorm(
343
+ num_channels=block_out_channels[0],
344
+ num_groups=norm_num_groups,
345
+ eps=norm_eps,
346
+ )
347
+ else:
348
+ self.conv_norm_out = nn.GroupNorm(
349
+ num_channels=block_out_channels[0],
350
+ num_groups=norm_num_groups,
351
+ eps=norm_eps,
352
+ )
353
+ self.conv_act = nn.SiLU()
354
+ self.conv_out = InflatedConv3d(
355
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
356
+ )
357
+
358
+ @property
359
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
360
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
361
+ r"""
362
+ Returns:
363
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
364
+ indexed by its weight name.
365
+ """
366
+ # set recursively
367
+ processors = {}
368
+
369
+ def fn_recursive_add_processors(
370
+ name: str,
371
+ module: torch.nn.Module,
372
+ processors: Dict[str, AttentionProcessor],
373
+ ):
374
+ if hasattr(module, "set_processor"):
375
+ processors[f"{name}.processor"] = module.processor
376
+
377
+ for sub_name, child in module.named_children():
378
+ if "temporal_transformer" not in sub_name:
379
+ fn_recursive_add_processors(
380
+ f"{name}.{sub_name}", child, processors)
381
+
382
+ return processors
383
+
384
+ for name, module in self.named_children():
385
+ if "temporal_transformer" not in name:
386
+ fn_recursive_add_processors(name, module, processors)
387
+
388
+ return processors
389
+
390
+ def set_attention_slice(self, slice_size):
391
+ r"""
392
+ Enable sliced attention computation.
393
+
394
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
395
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
396
+
397
+ Args:
398
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
399
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
400
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
401
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
402
+ must be a multiple of `slice_size`.
403
+ """
404
+ sliceable_head_dims = []
405
+
406
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
407
+ if hasattr(module, "set_attention_slice"):
408
+ sliceable_head_dims.append(module.sliceable_head_dim)
409
+
410
+ for child in module.children():
411
+ fn_recursive_retrieve_slicable_dims(child)
412
+
413
+ # retrieve number of attention layers
414
+ for module in self.children():
415
+ fn_recursive_retrieve_slicable_dims(module)
416
+
417
+ num_slicable_layers = len(sliceable_head_dims)
418
+
419
+ if slice_size == "auto":
420
+ # half the attention head size is usually a good trade-off between
421
+ # speed and memory
422
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
423
+ elif slice_size == "max":
424
+ # make smallest slice possible
425
+ slice_size = num_slicable_layers * [1]
426
+
427
+ slice_size = (
428
+ num_slicable_layers * [slice_size]
429
+ if not isinstance(slice_size, list)
430
+ else slice_size
431
+ )
432
+
433
+ if len(slice_size) != len(sliceable_head_dims):
434
+ raise ValueError(
435
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
436
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
437
+ )
438
+
439
+ for i, size in enumerate(slice_size):
440
+ dim = sliceable_head_dims[i]
441
+ if size is not None and size > dim:
442
+ raise ValueError(
443
+ f"size {size} has to be smaller or equal to {dim}.")
444
+
445
+ # Recursively walk through all the children.
446
+ # Any children which exposes the set_attention_slice method
447
+ # gets the message
448
+ def fn_recursive_set_attention_slice(
449
+ module: torch.nn.Module, slice_size: List[int]
450
+ ):
451
+ if hasattr(module, "set_attention_slice"):
452
+ module.set_attention_slice(slice_size.pop())
453
+
454
+ for child in module.children():
455
+ fn_recursive_set_attention_slice(child, slice_size)
456
+
457
+ reversed_slice_size = list(reversed(slice_size))
458
+ for module in self.children():
459
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
460
+
461
+ def _set_gradient_checkpointing(self, module, value=False):
462
+ if hasattr(module, "gradient_checkpointing"):
463
+ module.gradient_checkpointing = value
464
+
465
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
466
+ def set_attn_processor(
467
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
468
+ ):
469
+ r"""
470
+ Sets the attention processor to use to compute attention.
471
+
472
+ Parameters:
473
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
474
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
475
+ for **all** `Attention` layers.
476
+
477
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
478
+ processor. This is strongly recommended when setting trainable attention processors.
479
+
480
+ """
481
+ count = len(self.attn_processors.keys())
482
+
483
+ if isinstance(processor, dict) and len(processor) != count:
484
+ raise ValueError(
485
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
486
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
487
+ )
488
+
489
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
490
+ if hasattr(module, "set_processor"):
491
+ if not isinstance(processor, dict):
492
+ module.set_processor(processor)
493
+ else:
494
+ module.set_processor(processor.pop(f"{name}.processor"))
495
+
496
+ for sub_name, child in module.named_children():
497
+ if "temporal_transformer" not in sub_name:
498
+ fn_recursive_attn_processor(
499
+ f"{name}.{sub_name}", child, processor)
500
+
501
+ for name, module in self.named_children():
502
+ if "temporal_transformer" not in name:
503
+ fn_recursive_attn_processor(name, module, processor)
504
+
505
+ def forward(
506
+ self,
507
+ sample: torch.FloatTensor,
508
+ timestep: Union[torch.Tensor, float, int],
509
+ encoder_hidden_states: torch.Tensor,
510
+ audio_embedding: Optional[torch.Tensor] = None,
511
+ class_labels: Optional[torch.Tensor] = None,
512
+ mask_cond_fea: Optional[torch.Tensor] = None,
513
+ attention_mask: Optional[torch.Tensor] = None,
514
+ full_mask: Optional[torch.Tensor] = None,
515
+ face_mask: Optional[torch.Tensor] = None,
516
+ lip_mask: Optional[torch.Tensor] = None,
517
+ motion_scale: Optional[torch.Tensor] = None,
518
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
519
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
520
+ return_dict: bool = True,
521
+ # start: bool = False,
522
+ ) -> Union[UNet3DConditionOutput, Tuple]:
523
+ r"""
524
+ Args:
525
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
526
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
527
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states, face_emb
528
+ return_dict (`bool`, *optional*, defaults to `True`):
529
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
530
+
531
+ mask_cond_fea (`torch.FloatTensor`, *optional*): mask_feature tensor
532
+ audio_embedding (`torch.FloatTensor`, *optional*): audio embedding tensor, audio_emb
533
+ full_mask (`torch.FloatTensor`, *optional*): full mask tensor, full_mask
534
+ face_mask (`torch.FloatTensor`, *optional*): face mask tensor, face_mask
535
+ lip_mask (`torch.FloatTensor`, *optional*): lip mask tensor, lip_mask
536
+
537
+ Returns:
538
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
539
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
540
+ returning a tuple, the first element is the sample tensor.
541
+ """
542
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
543
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
544
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
545
+ # on the fly if necessary.
546
+ default_overall_up_factor = 2**self.num_upsamplers
547
+
548
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
549
+ forward_upsample_size = False
550
+ upsample_size = None
551
+
552
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
553
+ logger.info(
554
+ "Forward upsample size to force interpolation output size.")
555
+ forward_upsample_size = True
556
+
557
+ # prepare attention_mask
558
+ if attention_mask is not None:
559
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
560
+ attention_mask = attention_mask.unsqueeze(1)
561
+
562
+ # center input if necessary
563
+ if self.config.center_input_sample:
564
+ sample = 2 * sample - 1.0
565
+
566
+ # time
567
+ timesteps = timestep
568
+ if not torch.is_tensor(timesteps):
569
+ # This would be a good case for the `match` statement (Python 3.10+)
570
+ is_mps = sample.device.type == "mps"
571
+ if isinstance(timestep, float):
572
+ dtype = torch.float32 if is_mps else torch.float64
573
+ else:
574
+ dtype = torch.int32 if is_mps else torch.int64
575
+ timesteps = torch.tensor(
576
+ [timesteps], dtype=dtype, device=sample.device)
577
+ elif len(timesteps.shape) == 0:
578
+ timesteps = timesteps[None].to(sample.device)
579
+
580
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
581
+ timesteps = timesteps.expand(sample.shape[0])
582
+
583
+ t_emb = self.time_proj(timesteps)
584
+
585
+ # timesteps does not contain any weights and will always return f32 tensors
586
+ # but time_embedding might actually be running in fp16. so we need to cast here.
587
+ # there might be better ways to encapsulate this.
588
+ t_emb = t_emb.to(dtype=self.dtype)
589
+ emb = self.time_embedding(t_emb)
590
+
591
+ if self.class_embedding is not None:
592
+ if class_labels is None:
593
+ raise ValueError(
594
+ "class_labels should be provided when num_class_embeds > 0"
595
+ )
596
+
597
+ if self.config.class_embed_type == "timestep":
598
+ class_labels = self.time_proj(class_labels)
599
+
600
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
601
+ emb = emb + class_emb
602
+
603
+ # pre-process
604
+ sample = self.conv_in(sample)
605
+ if mask_cond_fea is not None:
606
+ sample = sample + mask_cond_fea
607
+
608
+ # down
609
+ down_block_res_samples = (sample,)
610
+ for downsample_block in self.down_blocks:
611
+ if (
612
+ hasattr(downsample_block, "has_cross_attention")
613
+ and downsample_block.has_cross_attention
614
+ ):
615
+ sample, res_samples = downsample_block(
616
+ hidden_states=sample,
617
+ temb=emb,
618
+ encoder_hidden_states=encoder_hidden_states,
619
+ attention_mask=attention_mask,
620
+ full_mask=full_mask,
621
+ face_mask=face_mask,
622
+ lip_mask=lip_mask,
623
+ audio_embedding=audio_embedding,
624
+ motion_scale=motion_scale,
625
+ )
626
+ # print("")
627
+ else:
628
+ sample, res_samples = downsample_block(
629
+ hidden_states=sample,
630
+ temb=emb,
631
+ encoder_hidden_states=encoder_hidden_states,
632
+ # audio_embedding=audio_embedding,
633
+ )
634
+ # print("")
635
+
636
+ down_block_res_samples += res_samples
637
+
638
+ if down_block_additional_residuals is not None:
639
+ new_down_block_res_samples = ()
640
+
641
+ for down_block_res_sample, down_block_additional_residual in zip(
642
+ down_block_res_samples, down_block_additional_residuals
643
+ ):
644
+ down_block_res_sample = (
645
+ down_block_res_sample + down_block_additional_residual
646
+ )
647
+ new_down_block_res_samples += (down_block_res_sample,)
648
+
649
+ down_block_res_samples = new_down_block_res_samples
650
+
651
+ # mid
652
+ sample = self.mid_block(
653
+ sample,
654
+ emb,
655
+ encoder_hidden_states=encoder_hidden_states,
656
+ attention_mask=attention_mask,
657
+ full_mask=full_mask,
658
+ face_mask=face_mask,
659
+ lip_mask=lip_mask,
660
+ audio_embedding=audio_embedding,
661
+ motion_scale=motion_scale,
662
+ )
663
+
664
+ if mid_block_additional_residual is not None:
665
+ sample = sample + mid_block_additional_residual
666
+
667
+ # up
668
+ for i, upsample_block in enumerate(self.up_blocks):
669
+ is_final_block = i == len(self.up_blocks) - 1
670
+
671
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
672
+ down_block_res_samples = down_block_res_samples[
673
+ : -len(upsample_block.resnets)
674
+ ]
675
+
676
+ # if we have not reached the final block and need to forward the
677
+ # upsample size, we do it here
678
+ if not is_final_block and forward_upsample_size:
679
+ upsample_size = down_block_res_samples[-1].shape[2:]
680
+
681
+ if (
682
+ hasattr(upsample_block, "has_cross_attention")
683
+ and upsample_block.has_cross_attention
684
+ ):
685
+ sample = upsample_block(
686
+ hidden_states=sample,
687
+ temb=emb,
688
+ res_hidden_states_tuple=res_samples,
689
+ encoder_hidden_states=encoder_hidden_states,
690
+ upsample_size=upsample_size,
691
+ attention_mask=attention_mask,
692
+ full_mask=full_mask,
693
+ face_mask=face_mask,
694
+ lip_mask=lip_mask,
695
+ audio_embedding=audio_embedding,
696
+ motion_scale=motion_scale,
697
+ )
698
+ else:
699
+ sample = upsample_block(
700
+ hidden_states=sample,
701
+ temb=emb,
702
+ res_hidden_states_tuple=res_samples,
703
+ upsample_size=upsample_size,
704
+ encoder_hidden_states=encoder_hidden_states,
705
+ # audio_embedding=audio_embedding,
706
+ )
707
+
708
+ # post-process
709
+ sample = self.conv_norm_out(sample)
710
+ sample = self.conv_act(sample)
711
+ sample = self.conv_out(sample)
712
+
713
+ if not return_dict:
714
+ return (sample,)
715
+
716
+ return UNet3DConditionOutput(sample=sample)
717
+
718
+ @classmethod
719
+ def from_pretrained_2d(
720
+ cls,
721
+ pretrained_model_path: PathLike,
722
+ motion_module_path: PathLike,
723
+ subfolder=None,
724
+ unet_additional_kwargs=None,
725
+ mm_zero_proj_out=False,
726
+ use_landmark=True,
727
+ ):
728
+ """
729
+ Load a pre-trained 2D UNet model from a given directory.
730
+
731
+ Parameters:
732
+ pretrained_model_path (`str` or `PathLike`):
733
+ Path to the directory containing a pre-trained 2D UNet model.
734
+ dtype (`torch.dtype`, *optional*):
735
+ The data type of the loaded model. If not provided, the default data type is used.
736
+ device (`torch.device`, *optional*):
737
+ The device on which the loaded model will be placed. If not provided, the default device is used.
738
+ **kwargs (`Any`):
739
+ Additional keyword arguments passed to the model.
740
+
741
+ Returns:
742
+ `UNet3DConditionModel`:
743
+ The loaded 2D UNet model.
744
+ """
745
+ pretrained_model_path = Path(pretrained_model_path)
746
+ motion_module_path = Path(motion_module_path)
747
+ if subfolder is not None:
748
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
749
+ logger.info(
750
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
751
+ )
752
+
753
+ config_file = pretrained_model_path / "config.json"
754
+ if not (config_file.exists() and config_file.is_file()):
755
+ raise RuntimeError(
756
+ f"{config_file} does not exist or is not a file")
757
+
758
+ unet_config = cls.load_config(config_file)
759
+ unet_config["_class_name"] = cls.__name__
760
+ unet_config["down_block_types"] = [
761
+ "CrossAttnDownBlock3D",
762
+ "CrossAttnDownBlock3D",
763
+ "CrossAttnDownBlock3D",
764
+ "DownBlock3D",
765
+ ]
766
+ unet_config["up_block_types"] = [
767
+ "UpBlock3D",
768
+ "CrossAttnUpBlock3D",
769
+ "CrossAttnUpBlock3D",
770
+ "CrossAttnUpBlock3D",
771
+ ]
772
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
773
+ if use_landmark:
774
+ unet_config["in_channels"] = 8
775
+ unet_config["out_channels"] = 8
776
+
777
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
778
+ # load the vanilla weights
779
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
780
+ logger.debug(
781
+ f"loading safeTensors weights from {pretrained_model_path} ..."
782
+ )
783
+ state_dict = load_file(
784
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
785
+ )
786
+
787
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
788
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
789
+ state_dict = torch.load(
790
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
791
+ map_location="cpu",
792
+ weights_only=True,
793
+ )
794
+ else:
795
+ raise FileNotFoundError(
796
+ f"no weights file found in {pretrained_model_path}")
797
+
798
+ # load the motion module weights
799
+ if motion_module_path.exists() and motion_module_path.is_file():
800
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
801
+ print(
802
+ f"Load motion module params from {motion_module_path}")
803
+ motion_state_dict = torch.load(
804
+ motion_module_path, map_location="cpu", weights_only=True
805
+ )
806
+ elif motion_module_path.suffix.lower() == ".safetensors":
807
+ motion_state_dict = load_file(motion_module_path, device="cpu")
808
+ else:
809
+ raise RuntimeError(
810
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
811
+ )
812
+ if mm_zero_proj_out:
813
+ logger.info(
814
+ "Zero initialize proj_out layers in motion module...")
815
+ new_motion_state_dict = OrderedDict()
816
+ for k in motion_state_dict:
817
+ if "proj_out" in k:
818
+ continue
819
+ new_motion_state_dict[k] = motion_state_dict[k]
820
+ motion_state_dict = new_motion_state_dict
821
+
822
+ # merge the state dicts
823
+ state_dict.update(motion_state_dict)
824
+
825
+ model_state_dict = model.state_dict()
826
+ for k in state_dict:
827
+ if k in model_state_dict:
828
+ if state_dict[k].shape != model_state_dict[k].shape:
829
+ state_dict[k] = model_state_dict[k]
830
+ # load the weights into the model
831
+ m, u = model.load_state_dict(state_dict, strict=False)
832
+ logger.debug(
833
+ f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
834
+
835
+ params = [
836
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
837
+ ]
838
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
839
+
840
+ return model
joyhallo/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines various 3D UNet blocks used in the video model.
3
+
4
+ The blocks include:
5
+ - UNetMidBlock3DCrossAttn: The middle block of the UNet with cross attention.
6
+ - CrossAttnDownBlock3D: The downsampling block with cross attention.
7
+ - DownBlock3D: The standard downsampling block without cross attention.
8
+ - CrossAttnUpBlock3D: The upsampling block with cross attention.
9
+ - UpBlock3D: The standard upsampling block without cross attention.
10
+
11
+ These blocks are used to construct the 3D UNet architecture for video-related tasks.
12
+ """
13
+
14
+ import torch
15
+ from einops import rearrange
16
+ from torch import nn
17
+
18
+ from .motion_module import get_motion_module
19
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
20
+ from .transformer_3d import Transformer3DModel
21
+
22
+
23
+ def get_down_block(
24
+ down_block_type,
25
+ num_layers,
26
+ in_channels,
27
+ out_channels,
28
+ temb_channels,
29
+ add_downsample,
30
+ resnet_eps,
31
+ resnet_act_fn,
32
+ attn_num_head_channels,
33
+ resnet_groups=None,
34
+ cross_attention_dim=None,
35
+ audio_attention_dim=None,
36
+ downsample_padding=None,
37
+ dual_cross_attention=False,
38
+ use_linear_projection=False,
39
+ only_cross_attention=False,
40
+ upcast_attention=False,
41
+ resnet_time_scale_shift="default",
42
+ unet_use_cross_frame_attention=None,
43
+ unet_use_temporal_attention=None,
44
+ use_inflated_groupnorm=None,
45
+ use_motion_module=None,
46
+ motion_module_type=None,
47
+ motion_module_kwargs=None,
48
+ use_audio_module=None,
49
+ depth=0,
50
+ stack_enable_blocks_name=None,
51
+ stack_enable_blocks_depth=None,
52
+ ):
53
+ """
54
+ Factory function to instantiate a down-block module for the 3D UNet architecture.
55
+
56
+ Down blocks are used in the downsampling part of the U-Net to reduce the spatial dimensions
57
+ of the feature maps while increasing the depth. This function can create blocks with or without
58
+ cross attention based on the specified parameters.
59
+
60
+ Parameters:
61
+ - down_block_type (str): The type of down block to instantiate.
62
+ - num_layers (int): The number of layers in the block.
63
+ - in_channels (int): The number of input channels.
64
+ - out_channels (int): The number of output channels.
65
+ - temb_channels (int): The number of token embedding channels.
66
+ - add_downsample (bool): Flag to add a downsampling layer.
67
+ - resnet_eps (float): Epsilon for residual block stability.
68
+ - resnet_act_fn (callable): Activation function for the residual block.
69
+ - ... (remaining parameters): Additional parameters for configuring the block.
70
+
71
+ Returns:
72
+ - nn.Module: An instance of a down-sampling block module.
73
+ """
74
+ down_block_type = (
75
+ down_block_type[7:]
76
+ if down_block_type.startswith("UNetRes")
77
+ else down_block_type
78
+ )
79
+ if down_block_type == "DownBlock3D":
80
+ return DownBlock3D(
81
+ num_layers=num_layers,
82
+ in_channels=in_channels,
83
+ out_channels=out_channels,
84
+ temb_channels=temb_channels,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ resnet_time_scale_shift=resnet_time_scale_shift,
91
+ use_inflated_groupnorm=use_inflated_groupnorm,
92
+ use_motion_module=use_motion_module,
93
+ motion_module_type=motion_module_type,
94
+ motion_module_kwargs=motion_module_kwargs,
95
+ )
96
+
97
+ if down_block_type == "CrossAttnDownBlock3D":
98
+ if cross_attention_dim is None:
99
+ raise ValueError(
100
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
101
+ )
102
+ return CrossAttnDownBlock3D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ temb_channels=temb_channels,
107
+ add_downsample=add_downsample,
108
+ resnet_eps=resnet_eps,
109
+ resnet_act_fn=resnet_act_fn,
110
+ resnet_groups=resnet_groups,
111
+ downsample_padding=downsample_padding,
112
+ cross_attention_dim=cross_attention_dim,
113
+ audio_attention_dim=audio_attention_dim,
114
+ attn_num_head_channels=attn_num_head_channels,
115
+ dual_cross_attention=dual_cross_attention,
116
+ use_linear_projection=use_linear_projection,
117
+ only_cross_attention=only_cross_attention,
118
+ upcast_attention=upcast_attention,
119
+ resnet_time_scale_shift=resnet_time_scale_shift,
120
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
121
+ unet_use_temporal_attention=unet_use_temporal_attention,
122
+ use_inflated_groupnorm=use_inflated_groupnorm,
123
+ use_motion_module=use_motion_module,
124
+ motion_module_type=motion_module_type,
125
+ motion_module_kwargs=motion_module_kwargs,
126
+ use_audio_module=use_audio_module,
127
+ depth=depth,
128
+ stack_enable_blocks_name=stack_enable_blocks_name,
129
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
130
+ )
131
+ raise ValueError(f"{down_block_type} does not exist.")
132
+
133
+
134
+ def get_up_block(
135
+ up_block_type,
136
+ num_layers,
137
+ in_channels,
138
+ out_channels,
139
+ prev_output_channel,
140
+ temb_channels,
141
+ add_upsample,
142
+ resnet_eps,
143
+ resnet_act_fn,
144
+ attn_num_head_channels,
145
+ resnet_groups=None,
146
+ cross_attention_dim=None,
147
+ audio_attention_dim=None,
148
+ dual_cross_attention=False,
149
+ use_linear_projection=False,
150
+ only_cross_attention=False,
151
+ upcast_attention=False,
152
+ resnet_time_scale_shift="default",
153
+ unet_use_cross_frame_attention=None,
154
+ unet_use_temporal_attention=None,
155
+ use_inflated_groupnorm=None,
156
+ use_motion_module=None,
157
+ motion_module_type=None,
158
+ motion_module_kwargs=None,
159
+ use_audio_module=None,
160
+ depth=0,
161
+ stack_enable_blocks_name=None,
162
+ stack_enable_blocks_depth=None,
163
+ ):
164
+ """
165
+ Factory function to instantiate an up-block module for the 3D UNet architecture.
166
+
167
+ Up blocks are used in the upsampling part of the U-Net to increase the spatial dimensions
168
+ of the feature maps while decreasing the depth. This function can create blocks with or without
169
+ cross attention based on the specified parameters.
170
+
171
+ Parameters:
172
+ - up_block_type (str): The type of up block to instantiate.
173
+ - num_layers (int): The number of layers in the block.
174
+ - in_channels (int): The number of input channels.
175
+ - out_channels (int): The number of output channels.
176
+ - prev_output_channel (int): The number of channels from the previous layer's output.
177
+ - temb_channels (int): The number of token embedding channels.
178
+ - add_upsample (bool): Flag to add an upsampling layer.
179
+ - resnet_eps (float): Epsilon for residual block stability.
180
+ - resnet_act_fn (callable): Activation function for the residual block.
181
+ - ... (remaining parameters): Additional parameters for configuring the block.
182
+
183
+ Returns:
184
+ - nn.Module: An instance of an up-sampling block module.
185
+ """
186
+ up_block_type = (
187
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
188
+ )
189
+ if up_block_type == "UpBlock3D":
190
+ return UpBlock3D(
191
+ num_layers=num_layers,
192
+ in_channels=in_channels,
193
+ out_channels=out_channels,
194
+ prev_output_channel=prev_output_channel,
195
+ temb_channels=temb_channels,
196
+ add_upsample=add_upsample,
197
+ resnet_eps=resnet_eps,
198
+ resnet_act_fn=resnet_act_fn,
199
+ resnet_groups=resnet_groups,
200
+ resnet_time_scale_shift=resnet_time_scale_shift,
201
+ use_inflated_groupnorm=use_inflated_groupnorm,
202
+ use_motion_module=use_motion_module,
203
+ motion_module_type=motion_module_type,
204
+ motion_module_kwargs=motion_module_kwargs,
205
+ )
206
+
207
+ if up_block_type == "CrossAttnUpBlock3D":
208
+ if cross_attention_dim is None:
209
+ raise ValueError(
210
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
211
+ )
212
+ return CrossAttnUpBlock3D(
213
+ num_layers=num_layers,
214
+ in_channels=in_channels,
215
+ out_channels=out_channels,
216
+ prev_output_channel=prev_output_channel,
217
+ temb_channels=temb_channels,
218
+ add_upsample=add_upsample,
219
+ resnet_eps=resnet_eps,
220
+ resnet_act_fn=resnet_act_fn,
221
+ resnet_groups=resnet_groups,
222
+ cross_attention_dim=cross_attention_dim,
223
+ audio_attention_dim=audio_attention_dim,
224
+ attn_num_head_channels=attn_num_head_channels,
225
+ dual_cross_attention=dual_cross_attention,
226
+ use_linear_projection=use_linear_projection,
227
+ only_cross_attention=only_cross_attention,
228
+ upcast_attention=upcast_attention,
229
+ resnet_time_scale_shift=resnet_time_scale_shift,
230
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
231
+ unet_use_temporal_attention=unet_use_temporal_attention,
232
+ use_inflated_groupnorm=use_inflated_groupnorm,
233
+ use_motion_module=use_motion_module,
234
+ motion_module_type=motion_module_type,
235
+ motion_module_kwargs=motion_module_kwargs,
236
+ use_audio_module=use_audio_module,
237
+ depth=depth,
238
+ stack_enable_blocks_name=stack_enable_blocks_name,
239
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
240
+ )
241
+ raise ValueError(f"{up_block_type} does not exist.")
242
+
243
+
244
+ class UNetMidBlock3DCrossAttn(nn.Module):
245
+ """
246
+ A 3D UNet middle block with cross attention mechanism. This block is part of the U-Net architecture
247
+ and is used for feature extraction in the middle of the downsampling path.
248
+
249
+ Parameters:
250
+ - in_channels (int): Number of input channels.
251
+ - temb_channels (int): Number of token embedding channels.
252
+ - dropout (float): Dropout rate.
253
+ - num_layers (int): Number of layers in the block.
254
+ - resnet_eps (float): Epsilon for residual block.
255
+ - resnet_time_scale_shift (str): Time scale shift for time embedding normalization.
256
+ - resnet_act_fn (str): Activation function for the residual block.
257
+ - resnet_groups (int): Number of groups for the convolutions in the residual block.
258
+ - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
259
+ - attn_num_head_channels (int): Number of attention heads.
260
+ - cross_attention_dim (int): Dimensionality of the cross attention layers.
261
+ - audio_attention_dim (int): Dimensionality of the audio attention layers.
262
+ - dual_cross_attention (bool): Whether to use dual cross attention.
263
+ - use_linear_projection (bool): Whether to use linear projection in attention.
264
+ - upcast_attention (bool): Whether to upcast attention to the original input dimension.
265
+ - unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net.
266
+ - unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net.
267
+ - use_inflated_groupnorm (bool): Whether to use inflated group normalization.
268
+ - use_motion_module (bool): Whether to use motion module.
269
+ - motion_module_type (str): Type of motion module.
270
+ - motion_module_kwargs (dict): Keyword arguments for the motion module.
271
+ - use_audio_module (bool): Whether to use audio module.
272
+ - depth (int): Depth of the block in the network.
273
+ - stack_enable_blocks_name (str): Name of the stack enable blocks.
274
+ - stack_enable_blocks_depth (int): Depth of the stack enable blocks.
275
+
276
+ Forward method:
277
+ The forward method applies the residual blocks, cross attention, and optional motion and audio modules
278
+ to the input hidden states. It returns the transformed hidden states.
279
+ """
280
+ def __init__(
281
+ self,
282
+ in_channels: int,
283
+ temb_channels: int,
284
+ dropout: float = 0.0,
285
+ num_layers: int = 1,
286
+ resnet_eps: float = 1e-6,
287
+ resnet_time_scale_shift: str = "default",
288
+ resnet_act_fn: str = "swish",
289
+ resnet_groups: int = 32,
290
+ resnet_pre_norm: bool = True,
291
+ attn_num_head_channels=1,
292
+ output_scale_factor=1.0,
293
+ cross_attention_dim=1280,
294
+ audio_attention_dim=1024,
295
+ dual_cross_attention=False,
296
+ use_linear_projection=False,
297
+ upcast_attention=False,
298
+ unet_use_cross_frame_attention=None,
299
+ unet_use_temporal_attention=None,
300
+ use_inflated_groupnorm=None,
301
+ use_motion_module=None,
302
+ motion_module_type=None,
303
+ motion_module_kwargs=None,
304
+ use_audio_module=None,
305
+ depth=0,
306
+ stack_enable_blocks_name=None,
307
+ stack_enable_blocks_depth=None,
308
+ ):
309
+ super().__init__()
310
+
311
+ self.has_cross_attention = True
312
+ self.attn_num_head_channels = attn_num_head_channels
313
+ resnet_groups = (
314
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
315
+ )
316
+
317
+ # there is always at least one resnet
318
+ resnets = [
319
+ ResnetBlock3D(
320
+ in_channels=in_channels,
321
+ out_channels=in_channels,
322
+ temb_channels=temb_channels,
323
+ eps=resnet_eps,
324
+ groups=resnet_groups,
325
+ dropout=dropout,
326
+ time_embedding_norm=resnet_time_scale_shift,
327
+ non_linearity=resnet_act_fn,
328
+ output_scale_factor=output_scale_factor,
329
+ pre_norm=resnet_pre_norm,
330
+ use_inflated_groupnorm=use_inflated_groupnorm,
331
+ )
332
+ ]
333
+ attentions = []
334
+ motion_modules = []
335
+ audio_modules = []
336
+
337
+ for _ in range(num_layers):
338
+ if dual_cross_attention:
339
+ raise NotImplementedError
340
+ attentions.append(
341
+ Transformer3DModel(
342
+ attn_num_head_channels,
343
+ in_channels // attn_num_head_channels,
344
+ in_channels=in_channels,
345
+ num_layers=1,
346
+ cross_attention_dim=cross_attention_dim,
347
+ norm_num_groups=resnet_groups,
348
+ use_linear_projection=use_linear_projection,
349
+ upcast_attention=upcast_attention,
350
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
351
+ unet_use_temporal_attention=unet_use_temporal_attention,
352
+ )
353
+ )
354
+ audio_modules.append(
355
+ Transformer3DModel(
356
+ attn_num_head_channels,
357
+ in_channels // attn_num_head_channels,
358
+ in_channels=in_channels,
359
+ num_layers=1,
360
+ cross_attention_dim=audio_attention_dim,
361
+ norm_num_groups=resnet_groups,
362
+ use_linear_projection=use_linear_projection,
363
+ upcast_attention=upcast_attention,
364
+ use_audio_module=use_audio_module,
365
+ depth=depth,
366
+ unet_block_name="mid",
367
+ stack_enable_blocks_name=stack_enable_blocks_name,
368
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
369
+ )
370
+ if use_audio_module
371
+ else None
372
+ )
373
+
374
+ motion_modules.append(
375
+ get_motion_module(
376
+ in_channels=in_channels,
377
+ motion_module_type=motion_module_type,
378
+ motion_module_kwargs=motion_module_kwargs,
379
+ )
380
+ if use_motion_module
381
+ else None
382
+ )
383
+ resnets.append(
384
+ ResnetBlock3D(
385
+ in_channels=in_channels,
386
+ out_channels=in_channels,
387
+ temb_channels=temb_channels,
388
+ eps=resnet_eps,
389
+ groups=resnet_groups,
390
+ dropout=dropout,
391
+ time_embedding_norm=resnet_time_scale_shift,
392
+ non_linearity=resnet_act_fn,
393
+ output_scale_factor=output_scale_factor,
394
+ pre_norm=resnet_pre_norm,
395
+ use_inflated_groupnorm=use_inflated_groupnorm,
396
+ )
397
+ )
398
+
399
+ self.attentions = nn.ModuleList(attentions)
400
+ self.resnets = nn.ModuleList(resnets)
401
+ self.audio_modules = nn.ModuleList(audio_modules)
402
+ self.motion_modules = nn.ModuleList(motion_modules)
403
+
404
+ def forward(
405
+ self,
406
+ hidden_states,
407
+ temb=None,
408
+ encoder_hidden_states=None,
409
+ attention_mask=None,
410
+ full_mask=None,
411
+ face_mask=None,
412
+ lip_mask=None,
413
+ audio_embedding=None,
414
+ motion_scale=None,
415
+ ):
416
+ """
417
+ Forward pass for the UNetMidBlock3DCrossAttn class.
418
+
419
+ Args:
420
+ self (UNetMidBlock3DCrossAttn): An instance of the UNetMidBlock3DCrossAttn class.
421
+ hidden_states (Tensor): The input hidden states tensor.
422
+ temb (Tensor, optional): The input temporal embedding tensor. Defaults to None.
423
+ encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
424
+ attention_mask (Tensor, optional): The attention mask tensor. Defaults to None.
425
+ full_mask (Tensor, optional): The full mask tensor. Defaults to None.
426
+ face_mask (Tensor, optional): The face mask tensor. Defaults to None.
427
+ lip_mask (Tensor, optional): The lip mask tensor. Defaults to None.
428
+ audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None.
429
+
430
+ Returns:
431
+ Tensor: The output tensor after passing through the UNetMidBlock3DCrossAttn layers.
432
+ """
433
+ hidden_states = self.resnets[0](hidden_states, temb)
434
+ for attn, resnet, audio_module, motion_module in zip(
435
+ self.attentions, self.resnets[1:], self.audio_modules, self.motion_modules
436
+ ):
437
+ hidden_states, motion_frame = attn(
438
+ hidden_states,
439
+ encoder_hidden_states=encoder_hidden_states,
440
+ return_dict=False,
441
+ ) # .sample
442
+ if len(motion_frame[0]) > 0:
443
+ # if motion_frame[0][0].numel() > 0:
444
+ motion_frames = motion_frame[0][0]
445
+ motion_frames = rearrange(
446
+ motion_frames,
447
+ "b f (d1 d2) c -> b c f d1 d2",
448
+ d1=hidden_states.size(-1),
449
+ )
450
+
451
+ else:
452
+ motion_frames = torch.zeros(
453
+ hidden_states.shape[0],
454
+ hidden_states.shape[1],
455
+ 4,
456
+ hidden_states.shape[3],
457
+ hidden_states.shape[4],
458
+ )
459
+
460
+ n_motion_frames = motion_frames.size(2)
461
+ if audio_module is not None:
462
+ hidden_states = (
463
+ audio_module(
464
+ hidden_states,
465
+ encoder_hidden_states=audio_embedding,
466
+ attention_mask=attention_mask,
467
+ full_mask=full_mask,
468
+ face_mask=face_mask,
469
+ lip_mask=lip_mask,
470
+ motion_scale=motion_scale,
471
+ return_dict=False,
472
+ )
473
+ )[0] # .sample
474
+ if motion_module is not None:
475
+ motion_frames = motion_frames.to(
476
+ device=hidden_states.device, dtype=hidden_states.dtype
477
+ )
478
+
479
+ _hidden_states = (
480
+ torch.cat([motion_frames, hidden_states], dim=2)
481
+ if n_motion_frames > 0
482
+ else hidden_states
483
+ )
484
+ hidden_states = motion_module(
485
+ _hidden_states, encoder_hidden_states=encoder_hidden_states
486
+ )
487
+ hidden_states = hidden_states[:, :, n_motion_frames:]
488
+
489
+ hidden_states = resnet(hidden_states, temb)
490
+
491
+ return hidden_states
492
+
493
+
494
+ class CrossAttnDownBlock3D(nn.Module):
495
+ """
496
+ A 3D downsampling block with cross attention for the U-Net architecture.
497
+
498
+ Parameters:
499
+ - (same as above, refer to the constructor for details)
500
+
501
+ Forward method:
502
+ The forward method downsamples the input hidden states using residual blocks and cross attention.
503
+ It also applies optional motion and audio modules. The method supports gradient checkpointing
504
+ to save memory during training.
505
+ """
506
+ def __init__(
507
+ self,
508
+ in_channels: int,
509
+ out_channels: int,
510
+ temb_channels: int,
511
+ dropout: float = 0.0,
512
+ num_layers: int = 1,
513
+ resnet_eps: float = 1e-6,
514
+ resnet_time_scale_shift: str = "default",
515
+ resnet_act_fn: str = "swish",
516
+ resnet_groups: int = 32,
517
+ resnet_pre_norm: bool = True,
518
+ attn_num_head_channels=1,
519
+ cross_attention_dim=1280,
520
+ audio_attention_dim=1024,
521
+ output_scale_factor=1.0,
522
+ downsample_padding=1,
523
+ add_downsample=True,
524
+ dual_cross_attention=False,
525
+ use_linear_projection=False,
526
+ only_cross_attention=False,
527
+ upcast_attention=False,
528
+ unet_use_cross_frame_attention=None,
529
+ unet_use_temporal_attention=None,
530
+ use_inflated_groupnorm=None,
531
+ use_motion_module=None,
532
+ motion_module_type=None,
533
+ motion_module_kwargs=None,
534
+ use_audio_module=None,
535
+ depth=0,
536
+ stack_enable_blocks_name=None,
537
+ stack_enable_blocks_depth=None,
538
+ ):
539
+ super().__init__()
540
+ resnets = []
541
+ attentions = []
542
+ audio_modules = []
543
+ motion_modules = []
544
+
545
+ self.has_cross_attention = True
546
+ self.attn_num_head_channels = attn_num_head_channels
547
+
548
+ for i in range(num_layers):
549
+ in_channels = in_channels if i == 0 else out_channels
550
+ resnets.append(
551
+ ResnetBlock3D(
552
+ in_channels=in_channels,
553
+ out_channels=out_channels,
554
+ temb_channels=temb_channels,
555
+ eps=resnet_eps,
556
+ groups=resnet_groups,
557
+ dropout=dropout,
558
+ time_embedding_norm=resnet_time_scale_shift,
559
+ non_linearity=resnet_act_fn,
560
+ output_scale_factor=output_scale_factor,
561
+ pre_norm=resnet_pre_norm,
562
+ use_inflated_groupnorm=use_inflated_groupnorm,
563
+ )
564
+ )
565
+ if dual_cross_attention:
566
+ raise NotImplementedError
567
+ attentions.append(
568
+ Transformer3DModel(
569
+ attn_num_head_channels,
570
+ out_channels // attn_num_head_channels,
571
+ in_channels=out_channels,
572
+ num_layers=1,
573
+ cross_attention_dim=cross_attention_dim,
574
+ norm_num_groups=resnet_groups,
575
+ use_linear_projection=use_linear_projection,
576
+ only_cross_attention=only_cross_attention,
577
+ upcast_attention=upcast_attention,
578
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
579
+ unet_use_temporal_attention=unet_use_temporal_attention,
580
+ )
581
+ )
582
+ # TODO:检查维度
583
+ audio_modules.append(
584
+ Transformer3DModel(
585
+ attn_num_head_channels,
586
+ in_channels // attn_num_head_channels,
587
+ in_channels=out_channels,
588
+ num_layers=1,
589
+ cross_attention_dim=audio_attention_dim,
590
+ norm_num_groups=resnet_groups,
591
+ use_linear_projection=use_linear_projection,
592
+ only_cross_attention=only_cross_attention,
593
+ upcast_attention=upcast_attention,
594
+ use_audio_module=use_audio_module,
595
+ depth=depth,
596
+ unet_block_name="down",
597
+ stack_enable_blocks_name=stack_enable_blocks_name,
598
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
599
+ )
600
+ if use_audio_module
601
+ else None
602
+ )
603
+ motion_modules.append(
604
+ get_motion_module(
605
+ in_channels=out_channels,
606
+ motion_module_type=motion_module_type,
607
+ motion_module_kwargs=motion_module_kwargs,
608
+ )
609
+ if use_motion_module
610
+ else None
611
+ )
612
+
613
+ self.attentions = nn.ModuleList(attentions)
614
+ self.resnets = nn.ModuleList(resnets)
615
+ self.audio_modules = nn.ModuleList(audio_modules)
616
+ self.motion_modules = nn.ModuleList(motion_modules)
617
+
618
+ if add_downsample:
619
+ self.downsamplers = nn.ModuleList(
620
+ [
621
+ Downsample3D(
622
+ out_channels,
623
+ use_conv=True,
624
+ out_channels=out_channels,
625
+ padding=downsample_padding,
626
+ name="op",
627
+ )
628
+ ]
629
+ )
630
+ else:
631
+ self.downsamplers = None
632
+
633
+ self.gradient_checkpointing = False
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states,
638
+ temb=None,
639
+ encoder_hidden_states=None,
640
+ attention_mask=None,
641
+ full_mask=None,
642
+ face_mask=None,
643
+ lip_mask=None,
644
+ audio_embedding=None,
645
+ motion_scale=None,
646
+ ):
647
+ """
648
+ Defines the forward pass for the CrossAttnDownBlock3D class.
649
+
650
+ Parameters:
651
+ - hidden_states : torch.Tensor
652
+ The input tensor to the block.
653
+ temb : torch.Tensor, optional
654
+ The token embeddings from the previous block.
655
+ encoder_hidden_states : torch.Tensor, optional
656
+ The hidden states from the encoder.
657
+ attention_mask : torch.Tensor, optional
658
+ The attention mask for the cross-attention mechanism.
659
+ full_mask : torch.Tensor, optional
660
+ The full mask for the cross-attention mechanism.
661
+ face_mask : torch.Tensor, optional
662
+ The face mask for the cross-attention mechanism.
663
+ lip_mask : torch.Tensor, optional
664
+ The lip mask for the cross-attention mechanism.
665
+ audio_embedding : torch.Tensor, optional
666
+ The audio embedding for the cross-attention mechanism.
667
+
668
+ Returns:
669
+ -- torch.Tensor
670
+ The output tensor from the block.
671
+ """
672
+ output_states = ()
673
+
674
+ for _, (resnet, attn, audio_module, motion_module) in enumerate(
675
+ zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
676
+ ):
677
+ # self.gradient_checkpointing = False
678
+ if self.training and self.gradient_checkpointing:
679
+
680
+ def create_custom_forward(module, return_dict=None):
681
+ def custom_forward(*inputs):
682
+ if return_dict is not None:
683
+ return module(*inputs, return_dict=return_dict)
684
+
685
+ return module(*inputs)
686
+
687
+ return custom_forward
688
+
689
+ hidden_states = torch.utils.checkpoint.checkpoint(
690
+ create_custom_forward(resnet), hidden_states, temb
691
+ )
692
+
693
+ motion_frames = []
694
+ hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
695
+ create_custom_forward(attn, return_dict=False),
696
+ hidden_states,
697
+ encoder_hidden_states,
698
+ )
699
+ if len(motion_frame[0]) > 0:
700
+ motion_frames = motion_frame[0][0]
701
+ # motion_frames = torch.cat(motion_frames, dim=0)
702
+ motion_frames = rearrange(
703
+ motion_frames,
704
+ "b f (d1 d2) c -> b c f d1 d2",
705
+ d1=hidden_states.size(-1),
706
+ )
707
+
708
+ else:
709
+ motion_frames = torch.zeros(
710
+ hidden_states.shape[0],
711
+ hidden_states.shape[1],
712
+ 4,
713
+ hidden_states.shape[3],
714
+ hidden_states.shape[4],
715
+ )
716
+
717
+ n_motion_frames = motion_frames.size(2)
718
+
719
+ if audio_module is not None:
720
+ # audio_embedding = audio_embedding
721
+ hidden_states = torch.utils.checkpoint.checkpoint(
722
+ create_custom_forward(audio_module, return_dict=False),
723
+ hidden_states,
724
+ audio_embedding,
725
+ attention_mask,
726
+ full_mask,
727
+ face_mask,
728
+ lip_mask,
729
+ motion_scale,
730
+ )[0]
731
+
732
+ # add motion module
733
+ if motion_module is not None:
734
+ motion_frames = motion_frames.to(
735
+ device=hidden_states.device, dtype=hidden_states.dtype
736
+ )
737
+ _hidden_states = torch.cat(
738
+ [motion_frames, hidden_states], dim=2
739
+ ) # if n_motion_frames > 0 else hidden_states
740
+ hidden_states = torch.utils.checkpoint.checkpoint(
741
+ create_custom_forward(motion_module),
742
+ _hidden_states,
743
+ encoder_hidden_states,
744
+ )
745
+ hidden_states = hidden_states[:, :, n_motion_frames:]
746
+
747
+ else:
748
+ hidden_states = resnet(hidden_states, temb)
749
+ hidden_states = attn(
750
+ hidden_states,
751
+ encoder_hidden_states=encoder_hidden_states,
752
+ ).sample
753
+ if audio_module is not None:
754
+ hidden_states = audio_module(
755
+ hidden_states,
756
+ audio_embedding,
757
+ attention_mask=attention_mask,
758
+ full_mask=full_mask,
759
+ face_mask=face_mask,
760
+ lip_mask=lip_mask,
761
+ return_dict=False,
762
+ )[0]
763
+ # add motion module
764
+ if motion_module is not None:
765
+ hidden_states = motion_module(
766
+ hidden_states, encoder_hidden_states=encoder_hidden_states
767
+ )
768
+
769
+ output_states += (hidden_states,)
770
+
771
+ if self.downsamplers is not None:
772
+ for downsampler in self.downsamplers:
773
+ hidden_states = downsampler(hidden_states)
774
+
775
+ output_states += (hidden_states,)
776
+
777
+ return hidden_states, output_states
778
+
779
+
780
+ class DownBlock3D(nn.Module):
781
+ """
782
+ A 3D downsampling block for the U-Net architecture. This block performs downsampling operations
783
+ using residual blocks and an optional motion module.
784
+
785
+ Parameters:
786
+ - in_channels (int): Number of input channels.
787
+ - out_channels (int): Number of output channels.
788
+ - temb_channels (int): Number of token embedding channels.
789
+ - dropout (float): Dropout rate for the block.
790
+ - num_layers (int): Number of layers in the block.
791
+ - resnet_eps (float): Epsilon for residual block stability.
792
+ - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
793
+ - resnet_act_fn (str): Activation function used in the residual block.
794
+ - resnet_groups (int): Number of groups for the convolutions in the residual block.
795
+ - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
796
+ - output_scale_factor (float): Scaling factor for the block's output.
797
+ - add_downsample (bool): Whether to add a downsampling layer.
798
+ - downsample_padding (int): Padding for the downsampling layer.
799
+ - use_inflated_groupnorm (bool): Whether to use inflated group normalization.
800
+ - use_motion_module (bool): Whether to include a motion module.
801
+ - motion_module_type (str): Type of motion module to use.
802
+ - motion_module_kwargs (dict): Keyword arguments for the motion module.
803
+
804
+ Forward method:
805
+ The forward method processes the input hidden states through the residual blocks and optional
806
+ motion modules, followed by an optional downsampling step. It supports gradient checkpointing
807
+ during training to reduce memory usage.
808
+ """
809
+ def __init__(
810
+ self,
811
+ in_channels: int,
812
+ out_channels: int,
813
+ temb_channels: int,
814
+ dropout: float = 0.0,
815
+ num_layers: int = 1,
816
+ resnet_eps: float = 1e-6,
817
+ resnet_time_scale_shift: str = "default",
818
+ resnet_act_fn: str = "swish",
819
+ resnet_groups: int = 32,
820
+ resnet_pre_norm: bool = True,
821
+ output_scale_factor=1.0,
822
+ add_downsample=True,
823
+ downsample_padding=1,
824
+ use_inflated_groupnorm=None,
825
+ use_motion_module=None,
826
+ motion_module_type=None,
827
+ motion_module_kwargs=None,
828
+ ):
829
+ super().__init__()
830
+ resnets = []
831
+ motion_modules = []
832
+
833
+ # use_motion_module = False
834
+ for i in range(num_layers):
835
+ in_channels = in_channels if i == 0 else out_channels
836
+ resnets.append(
837
+ ResnetBlock3D(
838
+ in_channels=in_channels,
839
+ out_channels=out_channels,
840
+ temb_channels=temb_channels,
841
+ eps=resnet_eps,
842
+ groups=resnet_groups,
843
+ dropout=dropout,
844
+ time_embedding_norm=resnet_time_scale_shift,
845
+ non_linearity=resnet_act_fn,
846
+ output_scale_factor=output_scale_factor,
847
+ pre_norm=resnet_pre_norm,
848
+ use_inflated_groupnorm=use_inflated_groupnorm,
849
+ )
850
+ )
851
+ motion_modules.append(
852
+ get_motion_module(
853
+ in_channels=out_channels,
854
+ motion_module_type=motion_module_type,
855
+ motion_module_kwargs=motion_module_kwargs,
856
+ )
857
+ if use_motion_module
858
+ else None
859
+ )
860
+
861
+ self.resnets = nn.ModuleList(resnets)
862
+ self.motion_modules = nn.ModuleList(motion_modules)
863
+
864
+ if add_downsample:
865
+ self.downsamplers = nn.ModuleList(
866
+ [
867
+ Downsample3D(
868
+ out_channels,
869
+ use_conv=True,
870
+ out_channels=out_channels,
871
+ padding=downsample_padding,
872
+ name="op",
873
+ )
874
+ ]
875
+ )
876
+ else:
877
+ self.downsamplers = None
878
+
879
+ self.gradient_checkpointing = False
880
+
881
+ def forward(
882
+ self,
883
+ hidden_states,
884
+ temb=None,
885
+ encoder_hidden_states=None,
886
+ ):
887
+ """
888
+ forward method for the DownBlock3D class.
889
+
890
+ Args:
891
+ hidden_states (Tensor): The input tensor to the DownBlock3D layer.
892
+ temb (Tensor, optional): The token embeddings, if using transformer.
893
+ encoder_hidden_states (Tensor, optional): The hidden states from the encoder.
894
+
895
+ Returns:
896
+ Tensor: The output tensor after passing through the DownBlock3D layer.
897
+ """
898
+ output_states = ()
899
+
900
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
901
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
902
+ if self.training and self.gradient_checkpointing:
903
+
904
+ def create_custom_forward(module):
905
+ def custom_forward(*inputs):
906
+ return module(*inputs)
907
+
908
+ return custom_forward
909
+
910
+ hidden_states = torch.utils.checkpoint.checkpoint(
911
+ create_custom_forward(resnet), hidden_states, temb
912
+ )
913
+
914
+ else:
915
+ hidden_states = resnet(hidden_states, temb)
916
+
917
+ # add motion module
918
+ hidden_states = (
919
+ motion_module(
920
+ hidden_states, encoder_hidden_states=encoder_hidden_states
921
+ )
922
+ if motion_module is not None
923
+ else hidden_states
924
+ )
925
+
926
+ output_states += (hidden_states,)
927
+
928
+ if self.downsamplers is not None:
929
+ for downsampler in self.downsamplers:
930
+ hidden_states = downsampler(hidden_states)
931
+
932
+ output_states += (hidden_states,)
933
+
934
+ return hidden_states, output_states
935
+
936
+
937
+ class CrossAttnUpBlock3D(nn.Module):
938
+ """
939
+ Standard 3D downsampling block for the U-Net architecture. This block performs downsampling
940
+ operations in the U-Net using residual blocks and an optional motion module.
941
+
942
+ Parameters:
943
+ - in_channels (int): Number of input channels.
944
+ - out_channels (int): Number of output channels.
945
+ - temb_channels (int): Number of channels for the temporal embedding.
946
+ - dropout (float): Dropout rate for the block.
947
+ - num_layers (int): Number of layers in the block.
948
+ - resnet_eps (float): Epsilon for residual block stability.
949
+ - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
950
+ - resnet_act_fn (str): Activation function used in the residual block.
951
+ - resnet_groups (int): Number of groups for the convolutions in the residual block.
952
+ - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
953
+ - output_scale_factor (float): Scaling factor for the block's output.
954
+ - add_downsample (bool): Whether to add a downsampling layer.
955
+ - downsample_padding (int): Padding for the downsampling layer.
956
+ - use_inflated_groupnorm (bool): Whether to use inflated group normalization.
957
+ - use_motion_module (bool): Whether to include a motion module.
958
+ - motion_module_type (str): Type of motion module to use.
959
+ - motion_module_kwargs (dict): Keyword arguments for the motion module.
960
+
961
+ Forward method:
962
+ The forward method processes the input hidden states through the residual blocks and optional
963
+ motion modules, followed by an optional downsampling step. It supports gradient checkpointing
964
+ during training to reduce memory usage.
965
+ """
966
+ def __init__(
967
+ self,
968
+ in_channels: int,
969
+ out_channels: int,
970
+ prev_output_channel: int,
971
+ temb_channels: int,
972
+ dropout: float = 0.0,
973
+ num_layers: int = 1,
974
+ resnet_eps: float = 1e-6,
975
+ resnet_time_scale_shift: str = "default",
976
+ resnet_act_fn: str = "swish",
977
+ resnet_groups: int = 32,
978
+ resnet_pre_norm: bool = True,
979
+ attn_num_head_channels=1,
980
+ cross_attention_dim=1280,
981
+ audio_attention_dim=1024,
982
+ output_scale_factor=1.0,
983
+ add_upsample=True,
984
+ dual_cross_attention=False,
985
+ use_linear_projection=False,
986
+ only_cross_attention=False,
987
+ upcast_attention=False,
988
+ unet_use_cross_frame_attention=None,
989
+ unet_use_temporal_attention=None,
990
+ use_motion_module=None,
991
+ use_inflated_groupnorm=None,
992
+ motion_module_type=None,
993
+ motion_module_kwargs=None,
994
+ use_audio_module=None,
995
+ depth=0,
996
+ stack_enable_blocks_name=None,
997
+ stack_enable_blocks_depth=None,
998
+ ):
999
+ super().__init__()
1000
+ resnets = []
1001
+ attentions = []
1002
+ audio_modules = []
1003
+ motion_modules = []
1004
+
1005
+ self.has_cross_attention = True
1006
+ self.attn_num_head_channels = attn_num_head_channels
1007
+
1008
+ for i in range(num_layers):
1009
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1010
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1011
+
1012
+ resnets.append(
1013
+ ResnetBlock3D(
1014
+ in_channels=resnet_in_channels + res_skip_channels,
1015
+ out_channels=out_channels,
1016
+ temb_channels=temb_channels,
1017
+ eps=resnet_eps,
1018
+ groups=resnet_groups,
1019
+ dropout=dropout,
1020
+ time_embedding_norm=resnet_time_scale_shift,
1021
+ non_linearity=resnet_act_fn,
1022
+ output_scale_factor=output_scale_factor,
1023
+ pre_norm=resnet_pre_norm,
1024
+ use_inflated_groupnorm=use_inflated_groupnorm,
1025
+ )
1026
+ )
1027
+
1028
+ if dual_cross_attention:
1029
+ raise NotImplementedError
1030
+ attentions.append(
1031
+ Transformer3DModel(
1032
+ attn_num_head_channels,
1033
+ out_channels // attn_num_head_channels,
1034
+ in_channels=out_channels,
1035
+ num_layers=1,
1036
+ cross_attention_dim=cross_attention_dim,
1037
+ norm_num_groups=resnet_groups,
1038
+ use_linear_projection=use_linear_projection,
1039
+ only_cross_attention=only_cross_attention,
1040
+ upcast_attention=upcast_attention,
1041
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
1042
+ unet_use_temporal_attention=unet_use_temporal_attention,
1043
+ )
1044
+ )
1045
+ audio_modules.append(
1046
+ Transformer3DModel(
1047
+ attn_num_head_channels,
1048
+ in_channels // attn_num_head_channels,
1049
+ in_channels=out_channels,
1050
+ num_layers=1,
1051
+ cross_attention_dim=audio_attention_dim,
1052
+ norm_num_groups=resnet_groups,
1053
+ use_linear_projection=use_linear_projection,
1054
+ only_cross_attention=only_cross_attention,
1055
+ upcast_attention=upcast_attention,
1056
+ use_audio_module=use_audio_module,
1057
+ depth=depth,
1058
+ unet_block_name="up",
1059
+ stack_enable_blocks_name=stack_enable_blocks_name,
1060
+ stack_enable_blocks_depth=stack_enable_blocks_depth,
1061
+ )
1062
+ if use_audio_module
1063
+ else None
1064
+ )
1065
+ motion_modules.append(
1066
+ get_motion_module(
1067
+ in_channels=out_channels,
1068
+ motion_module_type=motion_module_type,
1069
+ motion_module_kwargs=motion_module_kwargs,
1070
+ )
1071
+ if use_motion_module
1072
+ else None
1073
+ )
1074
+
1075
+ self.attentions = nn.ModuleList(attentions)
1076
+ self.resnets = nn.ModuleList(resnets)
1077
+ self.audio_modules = nn.ModuleList(audio_modules)
1078
+ self.motion_modules = nn.ModuleList(motion_modules)
1079
+
1080
+ if add_upsample:
1081
+ self.upsamplers = nn.ModuleList(
1082
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
1083
+ )
1084
+ else:
1085
+ self.upsamplers = None
1086
+
1087
+ self.gradient_checkpointing = False
1088
+
1089
+ def forward(
1090
+ self,
1091
+ hidden_states,
1092
+ res_hidden_states_tuple,
1093
+ temb=None,
1094
+ encoder_hidden_states=None,
1095
+ upsample_size=None,
1096
+ attention_mask=None,
1097
+ full_mask=None,
1098
+ face_mask=None,
1099
+ lip_mask=None,
1100
+ audio_embedding=None,
1101
+ motion_scale=None,
1102
+ ):
1103
+ """
1104
+ Forward pass for the CrossAttnUpBlock3D class.
1105
+
1106
+ Args:
1107
+ self (CrossAttnUpBlock3D): An instance of the CrossAttnUpBlock3D class.
1108
+ hidden_states (Tensor): The input hidden states tensor.
1109
+ res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors.
1110
+ temb (Tensor, optional): The token embeddings tensor. Defaults to None.
1111
+ encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
1112
+ upsample_size (int, optional): The upsample size. Defaults to None.
1113
+ attention_mask (Tensor, optional): The attention mask tensor. Defaults to None.
1114
+ full_mask (Tensor, optional): The full mask tensor. Defaults to None.
1115
+ face_mask (Tensor, optional): The face mask tensor. Defaults to None.
1116
+ lip_mask (Tensor, optional): The lip mask tensor. Defaults to None.
1117
+ audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None.
1118
+
1119
+ Returns:
1120
+ Tensor: The output tensor after passing through the CrossAttnUpBlock3D.
1121
+ """
1122
+ for _, (resnet, attn, audio_module, motion_module) in enumerate(
1123
+ zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules)
1124
+ ):
1125
+ # pop res hidden states
1126
+ res_hidden_states = res_hidden_states_tuple[-1]
1127
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1128
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1129
+
1130
+ if self.training and self.gradient_checkpointing:
1131
+
1132
+ def create_custom_forward(module, return_dict=None):
1133
+ def custom_forward(*inputs):
1134
+ if return_dict is not None:
1135
+ return module(*inputs, return_dict=return_dict)
1136
+
1137
+ return module(*inputs)
1138
+
1139
+ return custom_forward
1140
+
1141
+ hidden_states = torch.utils.checkpoint.checkpoint(
1142
+ create_custom_forward(resnet), hidden_states, temb
1143
+ )
1144
+
1145
+ motion_frames = []
1146
+ hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
1147
+ create_custom_forward(attn, return_dict=False),
1148
+ hidden_states,
1149
+ encoder_hidden_states,
1150
+ )
1151
+ if len(motion_frame[0]) > 0:
1152
+ motion_frames = motion_frame[0][0]
1153
+ # motion_frames = torch.cat(motion_frames, dim=0)
1154
+ motion_frames = rearrange(
1155
+ motion_frames,
1156
+ "b f (d1 d2) c -> b c f d1 d2",
1157
+ d1=hidden_states.size(-1),
1158
+ )
1159
+ else:
1160
+ motion_frames = torch.zeros(
1161
+ hidden_states.shape[0],
1162
+ hidden_states.shape[1],
1163
+ 4,
1164
+ hidden_states.shape[3],
1165
+ hidden_states.shape[4],
1166
+ )
1167
+
1168
+ n_motion_frames = motion_frames.size(2)
1169
+
1170
+ if audio_module is not None:
1171
+ # audio_embedding = audio_embedding
1172
+ hidden_states = torch.utils.checkpoint.checkpoint(
1173
+ create_custom_forward(audio_module, return_dict=False),
1174
+ hidden_states,
1175
+ audio_embedding,
1176
+ attention_mask,
1177
+ full_mask,
1178
+ face_mask,
1179
+ lip_mask,
1180
+ motion_scale,
1181
+ )[0]
1182
+
1183
+ # add motion module
1184
+ if motion_module is not None:
1185
+ motion_frames = motion_frames.to(
1186
+ device=hidden_states.device, dtype=hidden_states.dtype
1187
+ )
1188
+
1189
+ _hidden_states = (
1190
+ torch.cat([motion_frames, hidden_states], dim=2)
1191
+ if n_motion_frames > 0
1192
+ else hidden_states
1193
+ )
1194
+ hidden_states = torch.utils.checkpoint.checkpoint(
1195
+ create_custom_forward(motion_module),
1196
+ _hidden_states,
1197
+ encoder_hidden_states,
1198
+ )
1199
+ hidden_states = hidden_states[:, :, n_motion_frames:]
1200
+ else:
1201
+ hidden_states = resnet(hidden_states, temb)
1202
+ hidden_states = attn(
1203
+ hidden_states,
1204
+ encoder_hidden_states=encoder_hidden_states,
1205
+ ).sample
1206
+
1207
+ if audio_module is not None:
1208
+
1209
+ hidden_states = (
1210
+ audio_module(
1211
+ hidden_states,
1212
+ encoder_hidden_states=audio_embedding,
1213
+ attention_mask=attention_mask,
1214
+ full_mask=full_mask,
1215
+ face_mask=face_mask,
1216
+ lip_mask=lip_mask,
1217
+ )
1218
+ ).sample
1219
+ # add motion module
1220
+ hidden_states = (
1221
+ motion_module(
1222
+ hidden_states, encoder_hidden_states=encoder_hidden_states
1223
+ )
1224
+ if motion_module is not None
1225
+ else hidden_states
1226
+ )
1227
+
1228
+ if self.upsamplers is not None:
1229
+ for upsampler in self.upsamplers:
1230
+ hidden_states = upsampler(hidden_states, upsample_size)
1231
+
1232
+ return hidden_states
1233
+
1234
+
1235
+ class UpBlock3D(nn.Module):
1236
+ """
1237
+ 3D upsampling block with cross attention for the U-Net architecture. This block performs
1238
+ upsampling operations and incorporates cross attention mechanisms, which allow the model to
1239
+ focus on different parts of the input when upscaling.
1240
+
1241
+ Parameters:
1242
+ - in_channels (int): Number of input channels.
1243
+ - out_channels (int): Number of output channels.
1244
+ - prev_output_channel (int): Number of channels from the previous layer's output.
1245
+ - temb_channels (int): Number of channels for the temporal embedding.
1246
+ - dropout (float): Dropout rate for the block.
1247
+ - num_layers (int): Number of layers in the block.
1248
+ - resnet_eps (float): Epsilon for residual block stability.
1249
+ - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding.
1250
+ - resnet_act_fn (str): Activation function used in the residual block.
1251
+ - resnet_groups (int): Number of groups for the convolutions in the residual block.
1252
+ - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block.
1253
+ - attn_num_head_channels (int): Number of attention heads for the cross attention mechanism.
1254
+ - cross_attention_dim (int): Dimensionality of the cross attention layers.
1255
+ - audio_attention_dim (int): Dimensionality of the audio attention layers.
1256
+ - output_scale_factor (float): Scaling factor for the block's output.
1257
+ - add_upsample (bool): Whether to add an upsampling layer.
1258
+ - dual_cross_attention (bool): Whether to use dual cross attention (not implemented).
1259
+ - use_linear_projection (bool): Whether to use linear projection in the cross attention.
1260
+ - only_cross_attention (bool): Whether to use only cross attention (no self-attention).
1261
+ - upcast_attention (bool): Whether to upcast attention to the original input dimension.
1262
+ - unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net.
1263
+ - unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net.
1264
+ - use_motion_module (bool): Whether to include a motion module.
1265
+ - use_inflated_groupnorm (bool): Whether to use inflated group normalization.
1266
+ - motion_module_type (str): Type of motion module to use.
1267
+ - motion_module_kwargs (dict): Keyword arguments for the motion module.
1268
+ - use_audio_module (bool): Whether to include an audio module.
1269
+ - depth (int): Depth of the block in the network.
1270
+ - stack_enable_blocks_name (str): Name of the stack enable blocks.
1271
+ - stack_enable_blocks_depth (int): Depth of the stack enable blocks.
1272
+
1273
+ Forward method:
1274
+ The forward method upsamples the input hidden states and residual hidden states, processes
1275
+ them through the residual and cross attention blocks, and optional motion and audio modules.
1276
+ It supports gradient checkpointing during training.
1277
+ """
1278
+ def __init__(
1279
+ self,
1280
+ in_channels: int,
1281
+ prev_output_channel: int,
1282
+ out_channels: int,
1283
+ temb_channels: int,
1284
+ dropout: float = 0.0,
1285
+ num_layers: int = 1,
1286
+ resnet_eps: float = 1e-6,
1287
+ resnet_time_scale_shift: str = "default",
1288
+ resnet_act_fn: str = "swish",
1289
+ resnet_groups: int = 32,
1290
+ resnet_pre_norm: bool = True,
1291
+ output_scale_factor=1.0,
1292
+ add_upsample=True,
1293
+ use_inflated_groupnorm=None,
1294
+ use_motion_module=None,
1295
+ motion_module_type=None,
1296
+ motion_module_kwargs=None,
1297
+ ):
1298
+ super().__init__()
1299
+ resnets = []
1300
+ motion_modules = []
1301
+
1302
+ # use_motion_module = False
1303
+ for i in range(num_layers):
1304
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1305
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1306
+
1307
+ resnets.append(
1308
+ ResnetBlock3D(
1309
+ in_channels=resnet_in_channels + res_skip_channels,
1310
+ out_channels=out_channels,
1311
+ temb_channels=temb_channels,
1312
+ eps=resnet_eps,
1313
+ groups=resnet_groups,
1314
+ dropout=dropout,
1315
+ time_embedding_norm=resnet_time_scale_shift,
1316
+ non_linearity=resnet_act_fn,
1317
+ output_scale_factor=output_scale_factor,
1318
+ pre_norm=resnet_pre_norm,
1319
+ use_inflated_groupnorm=use_inflated_groupnorm,
1320
+ )
1321
+ )
1322
+ motion_modules.append(
1323
+ get_motion_module(
1324
+ in_channels=out_channels,
1325
+ motion_module_type=motion_module_type,
1326
+ motion_module_kwargs=motion_module_kwargs,
1327
+ )
1328
+ if use_motion_module
1329
+ else None
1330
+ )
1331
+
1332
+ self.resnets = nn.ModuleList(resnets)
1333
+ self.motion_modules = nn.ModuleList(motion_modules)
1334
+
1335
+ if add_upsample:
1336
+ self.upsamplers = nn.ModuleList(
1337
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
1338
+ )
1339
+ else:
1340
+ self.upsamplers = None
1341
+
1342
+ self.gradient_checkpointing = False
1343
+
1344
+ def forward(
1345
+ self,
1346
+ hidden_states,
1347
+ res_hidden_states_tuple,
1348
+ temb=None,
1349
+ upsample_size=None,
1350
+ encoder_hidden_states=None,
1351
+ ):
1352
+ """
1353
+ Forward pass for the UpBlock3D class.
1354
+
1355
+ Args:
1356
+ self (UpBlock3D): An instance of the UpBlock3D class.
1357
+ hidden_states (Tensor): The input hidden states tensor.
1358
+ res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors.
1359
+ temb (Tensor, optional): The token embeddings tensor. Defaults to None.
1360
+ upsample_size (int, optional): The upsample size. Defaults to None.
1361
+ encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None.
1362
+
1363
+ Returns:
1364
+ Tensor: The output tensor after passing through the UpBlock3D layers.
1365
+ """
1366
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
1367
+ # pop res hidden states
1368
+ res_hidden_states = res_hidden_states_tuple[-1]
1369
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1370
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1371
+
1372
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
1373
+ if self.training and self.gradient_checkpointing:
1374
+
1375
+ def create_custom_forward(module):
1376
+ def custom_forward(*inputs):
1377
+ return module(*inputs)
1378
+
1379
+ return custom_forward
1380
+
1381
+ hidden_states = torch.utils.checkpoint.checkpoint(
1382
+ create_custom_forward(resnet), hidden_states, temb
1383
+ )
1384
+ else:
1385
+ hidden_states = resnet(hidden_states, temb)
1386
+ hidden_states = (
1387
+ motion_module(
1388
+ hidden_states, encoder_hidden_states=encoder_hidden_states
1389
+ )
1390
+ if motion_module is not None
1391
+ else hidden_states
1392
+ )
1393
+
1394
+ if self.upsamplers is not None:
1395
+ for upsampler in self.upsamplers:
1396
+ hidden_states = upsampler(hidden_states, upsample_size)
1397
+
1398
+ return hidden_states
joyhallo/models/wav2vec.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
3
+ It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
4
+ such as feature extraction and encoding.
5
+
6
+ Classes:
7
+ Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
8
+
9
+ Functions:
10
+ linear_interpolation: Interpolates the features based on the sequence length.
11
+ """
12
+
13
+ import torch.nn.functional as F
14
+ from transformers import Wav2Vec2Model
15
+ from transformers.modeling_outputs import BaseModelOutput
16
+
17
+
18
+ class Wav2VecModel(Wav2Vec2Model):
19
+ """
20
+ Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
21
+ It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
22
+ ...
23
+
24
+ Attributes:
25
+ base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
26
+
27
+ Methods:
28
+ forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
29
+ , output_attentions=None, output_hidden_states=None, return_dict=None):
30
+ Forward pass of the Wav2VecModel.
31
+ It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
32
+
33
+ feature_extract(input_values, seq_len):
34
+ Extracts features from the input_values using the base model.
35
+
36
+ encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
37
+ Encodes the extracted features using the base model and returns the encoded features.
38
+ """
39
+ def forward(
40
+ self,
41
+ input_values,
42
+ seq_len,
43
+ attention_mask=None,
44
+ mask_time_indices=None,
45
+ output_attentions=None,
46
+ output_hidden_states=None,
47
+ return_dict=None,
48
+ ):
49
+ """
50
+ Forward pass of the Wav2Vec model.
51
+
52
+ Args:
53
+ self: The instance of the model.
54
+ input_values: The input values (waveform) to the model.
55
+ seq_len: The sequence length of the input values.
56
+ attention_mask: Attention mask to be used for the model.
57
+ mask_time_indices: Mask indices to be used for the model.
58
+ output_attentions: If set to True, returns attentions.
59
+ output_hidden_states: If set to True, returns hidden states.
60
+ return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
61
+
62
+ Returns:
63
+ The output of the Wav2Vec model.
64
+ """
65
+ self.config.output_attentions = True
66
+
67
+ output_hidden_states = (
68
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
69
+ )
70
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
71
+
72
+ extract_features = self.feature_extractor(input_values)
73
+ extract_features = extract_features.transpose(1, 2)
74
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
75
+
76
+ if attention_mask is not None:
77
+ # compute reduced attention_mask corresponding to feature vectors
78
+ attention_mask = self._get_feature_vector_attention_mask(
79
+ extract_features.shape[1], attention_mask, add_adapter=False
80
+ )
81
+
82
+ hidden_states, extract_features = self.feature_projection(extract_features)
83
+ hidden_states = self._mask_hidden_states(
84
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
85
+ )
86
+
87
+ encoder_outputs = self.encoder(
88
+ hidden_states,
89
+ attention_mask=attention_mask,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=output_hidden_states,
92
+ return_dict=return_dict,
93
+ )
94
+
95
+ hidden_states = encoder_outputs[0]
96
+
97
+ if self.adapter is not None:
98
+ hidden_states = self.adapter(hidden_states)
99
+
100
+ if not return_dict:
101
+ return (hidden_states, ) + encoder_outputs[1:]
102
+ return BaseModelOutput(
103
+ last_hidden_state=hidden_states,
104
+ hidden_states=encoder_outputs.hidden_states,
105
+ attentions=encoder_outputs.attentions,
106
+ )
107
+
108
+
109
+ def feature_extract(
110
+ self,
111
+ input_values,
112
+ seq_len,
113
+ ):
114
+ """
115
+ Extracts features from the input values and returns the extracted features.
116
+
117
+ Parameters:
118
+ input_values (torch.Tensor): The input values to be processed.
119
+ seq_len (torch.Tensor): The sequence lengths of the input values.
120
+
121
+ Returns:
122
+ extracted_features (torch.Tensor): The extracted features from the input values.
123
+ """
124
+ extract_features = self.feature_extractor(input_values)
125
+ extract_features = extract_features.transpose(1, 2)
126
+ extract_features = linear_interpolation(extract_features, seq_len=seq_len)
127
+
128
+ return extract_features
129
+
130
+ def encode(
131
+ self,
132
+ extract_features,
133
+ attention_mask=None,
134
+ mask_time_indices=None,
135
+ output_attentions=None,
136
+ output_hidden_states=None,
137
+ return_dict=None,
138
+ ):
139
+ """
140
+ Encodes the input features into the output space.
141
+
142
+ Args:
143
+ extract_features (torch.Tensor): The extracted features from the audio signal.
144
+ attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
145
+ mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
146
+ output_attentions (bool, optional): If set to True, returns the attention weights.
147
+ output_hidden_states (bool, optional): If set to True, returns all hidden states.
148
+ return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
149
+
150
+ Returns:
151
+ The encoded output features.
152
+ """
153
+ self.config.output_attentions = True
154
+
155
+ output_hidden_states = (
156
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
157
+ )
158
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
159
+
160
+ if attention_mask is not None:
161
+ # compute reduced attention_mask corresponding to feature vectors
162
+ attention_mask = self._get_feature_vector_attention_mask(
163
+ extract_features.shape[1], attention_mask, add_adapter=False
164
+ )
165
+
166
+ hidden_states, extract_features = self.feature_projection(extract_features)
167
+ hidden_states = self._mask_hidden_states(
168
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
169
+ )
170
+
171
+ encoder_outputs = self.encoder(
172
+ hidden_states,
173
+ attention_mask=attention_mask,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+
179
+ hidden_states = encoder_outputs[0]
180
+
181
+ if self.adapter is not None:
182
+ hidden_states = self.adapter(hidden_states)
183
+
184
+ if not return_dict:
185
+ return (hidden_states, ) + encoder_outputs[1:]
186
+ return BaseModelOutput(
187
+ last_hidden_state=hidden_states,
188
+ hidden_states=encoder_outputs.hidden_states,
189
+ attentions=encoder_outputs.attentions,
190
+ )
191
+
192
+
193
+ def linear_interpolation(features, seq_len):
194
+ """
195
+ Transpose the features to interpolate linearly.
196
+
197
+ Args:
198
+ features (torch.Tensor): The extracted features to be interpolated.
199
+ seq_len (torch.Tensor): The sequence lengths of the features.
200
+
201
+ Returns:
202
+ torch.Tensor: The interpolated features.
203
+ """
204
+ features = features.transpose(1, 2)
205
+ output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
206
+ return output_features.transpose(1, 2)
joyhallo/utils/__init__.py ADDED
File without changes
joyhallo/utils/config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides utility functions for configuration manipulation.
3
+ """
4
+
5
+ from typing import Dict
6
+
7
+
8
+ def filter_non_none(dict_obj: Dict):
9
+ """
10
+ Filters out key-value pairs from the given dictionary where the value is None.
11
+
12
+ Args:
13
+ dict_obj (Dict): The dictionary to be filtered.
14
+
15
+ Returns:
16
+ Dict: The dictionary with key-value pairs removed where the value was None.
17
+
18
+ This function creates a new dictionary containing only the key-value pairs from
19
+ the original dictionary where the value is not None. It then clears the original
20
+ dictionary and updates it with the filtered key-value pairs.
21
+ """
22
+ non_none_filter = { k: v for k, v in dict_obj.items() if v is not None }
23
+ dict_obj.clear()
24
+ dict_obj.update(non_none_filter)
25
+ return dict_obj
joyhallo/utils/util.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils.py
3
+
4
+ This module provides utility functions for various tasks such as setting random seeds,
5
+ importing modules from files, managing checkpoint files, and saving video files from
6
+ sequences of PIL images.
7
+
8
+ Functions:
9
+ seed_everything(seed)
10
+ import_filename(filename)
11
+ delete_additional_ckpt(base_path, num_keep)
12
+ save_videos_from_pil(pil_images, path, fps=8)
13
+
14
+ Dependencies:
15
+ importlib
16
+ os
17
+ os.path as osp
18
+ random
19
+ shutil
20
+ sys
21
+ pathlib.Path
22
+ av
23
+ cv2
24
+ mediapipe as mp
25
+ numpy as np
26
+ torch
27
+ torchvision
28
+ einops.rearrange
29
+ moviepy.editor.AudioFileClip, VideoClip
30
+ PIL.Image
31
+
32
+ Examples:
33
+ seed_everything(42)
34
+ imported_module = import_filename('path/to/your/module.py')
35
+ delete_additional_ckpt('path/to/checkpoints', 1)
36
+ save_videos_from_pil(pil_images, 'output/video.mp4', fps=12)
37
+
38
+ The functions in this module ensure reproducibility of experiments by seeding random number
39
+ generators, allow dynamic importing of modules, manage checkpoint files by deleting extra ones,
40
+ and provide a way to save sequences of images as video files.
41
+
42
+ Function Details:
43
+ seed_everything(seed)
44
+ Seeds all random number generators to ensure reproducibility.
45
+
46
+ import_filename(filename)
47
+ Imports a module from a given file location.
48
+
49
+ delete_additional_ckpt(base_path, num_keep)
50
+ Deletes additional checkpoint files in the given directory.
51
+
52
+ save_videos_from_pil(pil_images, path, fps=8)
53
+ Saves a sequence of images as a video using the Pillow library.
54
+
55
+ Attributes:
56
+ _ (str): Placeholder for static type checking
57
+ """
58
+
59
+ import importlib
60
+ import os
61
+ import os.path as osp
62
+ import random
63
+ import shutil
64
+ import subprocess
65
+ import sys
66
+ from pathlib import Path
67
+ from typing import List
68
+
69
+ import av
70
+ import cv2
71
+ import mediapipe as mp
72
+ import numpy as np
73
+ import torch
74
+ import torchvision
75
+ from einops import rearrange
76
+ from moviepy.editor import AudioFileClip, VideoClip
77
+ from PIL import Image
78
+
79
+
80
+ def seed_everything(seed):
81
+ """
82
+ Seeds all random number generators to ensure reproducibility.
83
+
84
+ Args:
85
+ seed (int): The seed value to set for all random number generators.
86
+ """
87
+ torch.manual_seed(seed)
88
+ torch.cuda.manual_seed_all(seed)
89
+ np.random.seed(seed % (2**32))
90
+ random.seed(seed)
91
+
92
+
93
+ def import_filename(filename):
94
+ """
95
+ Import a module from a given file location.
96
+
97
+ Args:
98
+ filename (str): The path to the file containing the module to be imported.
99
+
100
+ Returns:
101
+ module: The imported module.
102
+
103
+ Raises:
104
+ ImportError: If the module cannot be imported.
105
+
106
+ Example:
107
+ >>> imported_module = import_filename('path/to/your/module.py')
108
+ """
109
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
110
+ module = importlib.util.module_from_spec(spec)
111
+ sys.modules[spec.name] = module
112
+ spec.loader.exec_module(module)
113
+ return module
114
+
115
+
116
+ def delete_additional_ckpt(base_path, num_keep):
117
+ """
118
+ Deletes additional checkpoint files in the given directory.
119
+
120
+ Args:
121
+ base_path (str): The path to the directory containing the checkpoint files.
122
+ num_keep (int): The number of most recent checkpoint files to keep.
123
+
124
+ Returns:
125
+ None
126
+
127
+ Raises:
128
+ FileNotFoundError: If the base_path does not exist.
129
+
130
+ Example:
131
+ >>> delete_additional_ckpt('path/to/checkpoints', 1)
132
+ # This will delete all but the most recent checkpoint file in 'path/to/checkpoints'.
133
+ """
134
+ dirs = []
135
+ for d in os.listdir(base_path):
136
+ if d.startswith("checkpoint-"):
137
+ dirs.append(d)
138
+ num_tot = len(dirs)
139
+ if num_tot <= num_keep:
140
+ return
141
+ # ensure ckpt is sorted and delete the ealier!
142
+ del_dirs = sorted(dirs, key=lambda x: int(
143
+ x.split("-")[-1]))[: num_tot - num_keep]
144
+ for d in del_dirs:
145
+ path_to_dir = osp.join(base_path, d)
146
+ if osp.exists(path_to_dir):
147
+ shutil.rmtree(path_to_dir)
148
+
149
+
150
+ def save_videos_from_pil(pil_images, path, fps=8):
151
+ """
152
+ Save a sequence of images as a video using the Pillow library.
153
+
154
+ Args:
155
+ pil_images (List[PIL.Image]): A list of PIL.Image objects representing the frames of the video.
156
+ path (str): The output file path for the video.
157
+ fps (int, optional): The frames per second rate of the video. Defaults to 8.
158
+
159
+ Returns:
160
+ None
161
+
162
+ Raises:
163
+ ValueError: If the save format is not supported.
164
+
165
+ This function takes a list of PIL.Image objects and saves them as a video file with a specified frame rate.
166
+ The output file format is determined by the file extension of the provided path. Supported formats include
167
+ .mp4, .avi, and .mkv. The function uses the Pillow library to handle the image processing and video
168
+ creation.
169
+ """
170
+ save_fmt = Path(path).suffix
171
+ os.makedirs(os.path.dirname(path), exist_ok=True)
172
+ width, height = pil_images[0].size
173
+
174
+ if save_fmt == ".mp4":
175
+ codec = "libx264"
176
+ container = av.open(path, "w")
177
+ stream = container.add_stream(codec, rate=fps)
178
+
179
+ stream.width = width
180
+ stream.height = height
181
+
182
+ for pil_image in pil_images:
183
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
184
+ av_frame = av.VideoFrame.from_image(pil_image)
185
+ container.mux(stream.encode(av_frame))
186
+ container.mux(stream.encode())
187
+ container.close()
188
+
189
+ elif save_fmt == ".gif":
190
+ pil_images[0].save(
191
+ fp=path,
192
+ format="GIF",
193
+ append_images=pil_images[1:],
194
+ save_all=True,
195
+ duration=(1 / fps * 1000),
196
+ loop=0,
197
+ )
198
+ else:
199
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
200
+
201
+
202
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
203
+ """
204
+ Save a grid of videos as an animation or video.
205
+
206
+ Args:
207
+ videos (torch.Tensor): A tensor of shape (batch_size, channels, time, height, width)
208
+ containing the videos to save.
209
+ path (str): The path to save the video grid. Supported formats are .mp4, .avi, and .gif.
210
+ rescale (bool, optional): If True, rescale the video to the original resolution.
211
+ Defaults to False.
212
+ n_rows (int, optional): The number of rows in the video grid. Defaults to 6.
213
+ fps (int, optional): The frame rate of the saved video. Defaults to 8.
214
+
215
+ Raises:
216
+ ValueError: If the video format is not supported.
217
+
218
+ Returns:
219
+ None
220
+ """
221
+ videos = rearrange(videos, "b c t h w -> t b c h w")
222
+ # height, width = videos.shape[-2:]
223
+ outputs = []
224
+
225
+ for x in videos:
226
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
227
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
228
+ if rescale:
229
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
230
+ x = (x * 255).numpy().astype(np.uint8)
231
+ x = Image.fromarray(x)
232
+
233
+ outputs.append(x)
234
+
235
+ os.makedirs(os.path.dirname(path), exist_ok=True)
236
+
237
+ save_videos_from_pil(outputs, path, fps)
238
+
239
+
240
+ def read_frames(video_path):
241
+ """
242
+ Reads video frames from a given video file.
243
+
244
+ Args:
245
+ video_path (str): The path to the video file.
246
+
247
+ Returns:
248
+ container (av.container.InputContainer): The input container object
249
+ containing the video stream.
250
+
251
+ Raises:
252
+ FileNotFoundError: If the video file is not found.
253
+ RuntimeError: If there is an error in reading the video stream.
254
+
255
+ The function reads the video frames from the specified video file using the
256
+ Python AV library (av). It returns an input container object that contains
257
+ the video stream. If the video file is not found, it raises a FileNotFoundError,
258
+ and if there is an error in reading the video stream, it raises a RuntimeError.
259
+ """
260
+ container = av.open(video_path)
261
+
262
+ video_stream = next(s for s in container.streams if s.type == "video")
263
+ frames = []
264
+ for packet in container.demux(video_stream):
265
+ for frame in packet.decode():
266
+ image = Image.frombytes(
267
+ "RGB",
268
+ (frame.width, frame.height),
269
+ frame.to_rgb().to_ndarray(),
270
+ )
271
+ frames.append(image)
272
+
273
+ return frames
274
+
275
+
276
+ def get_fps(video_path):
277
+ """
278
+ Get the frame rate (FPS) of a video file.
279
+
280
+ Args:
281
+ video_path (str): The path to the video file.
282
+
283
+ Returns:
284
+ int: The frame rate (FPS) of the video file.
285
+ """
286
+ container = av.open(video_path)
287
+ video_stream = next(s for s in container.streams if s.type == "video")
288
+ fps = video_stream.average_rate
289
+ container.close()
290
+ return fps
291
+
292
+
293
+ def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
294
+ """
295
+ Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file.
296
+
297
+ Args:
298
+ tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w].
299
+ output_video_file (str): The file path where the output video will be saved.
300
+ audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added.
301
+ fps (int): The frame rate of the output video. Default is 25 fps.
302
+ """
303
+ tensor = tensor.permute(1, 2, 3, 0).cpu(
304
+ ).numpy() # convert to [f, h, w, c]
305
+ tensor = np.clip(tensor * 255, 0, 255).astype(
306
+ np.uint8
307
+ ) # to [0, 255]
308
+
309
+ def make_frame(t):
310
+ # get index
311
+ frame_index = min(int(t * fps), tensor.shape[0] - 1)
312
+ return tensor[frame_index]
313
+ new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps)
314
+ audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps)
315
+ new_video_clip = new_video_clip.set_audio(audio_clip)
316
+ new_video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac')
317
+
318
+
319
+ silhouette_ids = [
320
+ 10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
321
+ 397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
322
+ 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109
323
+ ]
324
+ lip_ids = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291,
325
+ 146, 91, 181, 84, 17, 314, 405, 321, 375]
326
+
327
+
328
+ def compute_face_landmarks(detection_result, h, w):
329
+ """
330
+ Compute face landmarks from a detection result.
331
+
332
+ Args:
333
+ detection_result (mediapipe.solutions.face_mesh.FaceMesh): The detection result containing face landmarks.
334
+ h (int): The height of the video frame.
335
+ w (int): The width of the video frame.
336
+
337
+ Returns:
338
+ face_landmarks_list (list): A list of face landmarks.
339
+ """
340
+ face_landmarks_list = detection_result.face_landmarks
341
+ if len(face_landmarks_list) != 1:
342
+ print("#face is invalid:", len(face_landmarks_list))
343
+ return []
344
+ return [[p.x * w, p.y * h] for p in face_landmarks_list[0]]
345
+
346
+
347
+ def get_landmark(file):
348
+ """
349
+ This function takes a file as input and returns the facial landmarks detected in the file.
350
+
351
+ Args:
352
+ file (str): The path to the file containing the video or image to be processed.
353
+
354
+ Returns:
355
+ Tuple[List[float], List[float]]: A tuple containing two lists of floats representing the x and y coordinates of the facial landmarks.
356
+ """
357
+ model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
358
+ BaseOptions = mp.tasks.BaseOptions
359
+ FaceLandmarker = mp.tasks.vision.FaceLandmarker
360
+ FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
361
+ VisionRunningMode = mp.tasks.vision.RunningMode
362
+ # Create a face landmarker instance with the video mode:
363
+ options = FaceLandmarkerOptions(
364
+ base_options=BaseOptions(model_asset_path=model_path),
365
+ running_mode=VisionRunningMode.IMAGE,
366
+ )
367
+
368
+ with FaceLandmarker.create_from_options(options) as landmarker:
369
+ image = mp.Image.create_from_file(str(file))
370
+ height, width = image.height, image.width
371
+ face_landmarker_result = landmarker.detect(image)
372
+ face_landmark = compute_face_landmarks(
373
+ face_landmarker_result, height, width)
374
+
375
+ return np.array(face_landmark), height, width
376
+
377
+
378
+ def get_landmark_overframes(landmark_model, frames_path):
379
+ """
380
+ This function iterate frames and returns the facial landmarks detected in each frame.
381
+
382
+ Args:
383
+ landmark_model: mediapipe landmark model instance
384
+ frames_path (str): The path to the video frames.
385
+
386
+ Returns:
387
+ List[List[float], float, float]: A List containing two lists of floats representing the x and y coordinates of the facial landmarks.
388
+ """
389
+
390
+ face_landmarks = []
391
+
392
+ for file in sorted(os.listdir(frames_path)):
393
+ image = mp.Image.create_from_file(os.path.join(frames_path, file))
394
+ height, width = image.height, image.width
395
+ landmarker_result = landmark_model.detect(image)
396
+ frame_landmark = compute_face_landmarks(
397
+ landmarker_result, height, width)
398
+ face_landmarks.append(frame_landmark)
399
+
400
+ return face_landmarks, height, width
401
+
402
+
403
+ def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0):
404
+ """
405
+ Extracts the lip region from the given landmarks and saves it as an image.
406
+
407
+ Parameters:
408
+ landmarks (numpy.ndarray): Array of facial landmarks.
409
+ height (int): Height of the output lip mask image.
410
+ width (int): Width of the output lip mask image.
411
+ out_path (pathlib.Path): Path to save the lip mask image.
412
+ expand_ratio (float): Expand ratio of mask.
413
+ """
414
+ lip_landmarks = np.take(landmarks, lip_ids, 0)
415
+ min_xy_lip = np.round(np.min(lip_landmarks, 0))
416
+ max_xy_lip = np.round(np.max(lip_landmarks, 0))
417
+ min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region(
418
+ [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio)
419
+ lip_mask = np.zeros((height, width), dtype=np.uint8)
420
+ lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]),
421
+ round(min_xy_lip[0]):round(max_xy_lip[0])] = 255
422
+ if out_path:
423
+ cv2.imwrite(str(out_path), lip_mask)
424
+ return None
425
+
426
+ return lip_mask
427
+
428
+
429
+ def get_union_lip_mask(landmarks, height, width, expand_ratio=1):
430
+ """
431
+ Extracts the lip region from the given landmarks and saves it as an image.
432
+
433
+ Parameters:
434
+ landmarks (numpy.ndarray): Array of facial landmarks.
435
+ height (int): Height of the output lip mask image.
436
+ width (int): Width of the output lip mask image.
437
+ expand_ratio (float): Expand ratio of mask.
438
+ """
439
+ lip_masks = []
440
+ for landmark in landmarks:
441
+ lip_masks.append(get_lip_mask(landmarks=landmark, height=height,
442
+ width=width, expand_ratio=expand_ratio))
443
+ union_mask = get_union_mask(lip_masks)
444
+ return union_mask
445
+
446
+
447
+ def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2):
448
+ """
449
+ Generate a face mask based on the given landmarks.
450
+
451
+ Args:
452
+ landmarks (numpy.ndarray): The landmarks of the face.
453
+ height (int): The height of the output face mask image.
454
+ width (int): The width of the output face mask image.
455
+ out_path (pathlib.Path): The path to save the face mask image.
456
+ expand_ratio (float): Expand ratio of mask.
457
+ Returns:
458
+ None. The face mask image is saved at the specified path.
459
+ """
460
+ face_landmarks = np.take(landmarks, silhouette_ids, 0)
461
+ min_xy_face = np.round(np.min(face_landmarks, 0))
462
+ max_xy_face = np.round(np.max(face_landmarks, 0))
463
+ min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1] = expand_region(
464
+ [min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1]], width, height, expand_ratio)
465
+ face_mask = np.zeros((height, width), dtype=np.uint8)
466
+ face_mask[round(min_xy_face[1]):round(max_xy_face[1]),
467
+ round(min_xy_face[0]):round(max_xy_face[0])] = 255
468
+ if out_path:
469
+ cv2.imwrite(str(out_path), face_mask)
470
+ return None
471
+
472
+ return face_mask
473
+
474
+
475
+ def get_union_face_mask(landmarks, height, width, expand_ratio=1):
476
+ """
477
+ Generate a face mask based on the given landmarks.
478
+
479
+ Args:
480
+ landmarks (numpy.ndarray): The landmarks of the face.
481
+ height (int): The height of the output face mask image.
482
+ width (int): The width of the output face mask image.
483
+ expand_ratio (float): Expand ratio of mask.
484
+ Returns:
485
+ None. The face mask image is saved at the specified path.
486
+ """
487
+ face_masks = []
488
+ for landmark in landmarks:
489
+ face_masks.append(get_face_mask(landmarks=landmark,height=height,width=width,expand_ratio=expand_ratio))
490
+ union_mask = get_union_mask(face_masks)
491
+ return union_mask
492
+
493
+ def get_mask(file, cache_dir, face_expand_raio):
494
+ """
495
+ Generate a face mask based on the given landmarks and save it to the specified cache directory.
496
+
497
+ Args:
498
+ file (str): The path to the file containing the landmarks.
499
+ cache_dir (str): The directory to save the generated face mask.
500
+
501
+ Returns:
502
+ None
503
+ """
504
+ landmarks, height, width = get_landmark(file)
505
+ file_name = os.path.basename(file).split(".")[0]
506
+ get_lip_mask(landmarks, height, width, os.path.join(
507
+ cache_dir, f"{file_name}_lip_mask.png"))
508
+ get_face_mask(landmarks, height, width, os.path.join(
509
+ cache_dir, f"{file_name}_face_mask.png"), face_expand_raio)
510
+ get_blur_mask(os.path.join(
511
+ cache_dir, f"{file_name}_face_mask.png"), os.path.join(
512
+ cache_dir, f"{file_name}_face_mask_blur.png"), kernel_size=(51, 51))
513
+ get_blur_mask(os.path.join(
514
+ cache_dir, f"{file_name}_lip_mask.png"), os.path.join(
515
+ cache_dir, f"{file_name}_sep_lip.png"), kernel_size=(31, 31))
516
+ get_background_mask(os.path.join(
517
+ cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
518
+ cache_dir, f"{file_name}_sep_background.png"))
519
+ get_sep_face_mask(os.path.join(
520
+ cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
521
+ cache_dir, f"{file_name}_sep_lip.png"), os.path.join(
522
+ cache_dir, f"{file_name}_sep_face.png"))
523
+
524
+
525
+ def expand_region(region, image_w, image_h, expand_ratio=1.0):
526
+ """
527
+ Expand the given region by a specified ratio.
528
+ Args:
529
+ region (tuple): A tuple containing the coordinates (min_x, max_x, min_y, max_y) of the region.
530
+ image_w (int): The width of the image.
531
+ image_h (int): The height of the image.
532
+ expand_ratio (float, optional): The ratio by which the region should be expanded. Defaults to 1.0.
533
+
534
+ Returns:
535
+ tuple: A tuple containing the expanded coordinates (min_x, max_x, min_y, max_y) of the region.
536
+ """
537
+
538
+ min_x, max_x, min_y, max_y = region
539
+ mid_x = (max_x + min_x) // 2
540
+ side_len_x = (max_x - min_x) * expand_ratio
541
+ mid_y = (max_y + min_y) // 2
542
+ side_len_y = (max_y - min_y) * expand_ratio
543
+ min_x = mid_x - side_len_x // 2
544
+ max_x = mid_x + side_len_x // 2
545
+ min_y = mid_y - side_len_y // 2
546
+ max_y = mid_y + side_len_y // 2
547
+ if min_x < 0:
548
+ max_x -= min_x
549
+ min_x = 0
550
+ if max_x > image_w:
551
+ min_x -= max_x - image_w
552
+ max_x = image_w
553
+ if min_y < 0:
554
+ max_y -= min_y
555
+ min_y = 0
556
+ if max_y > image_h:
557
+ min_y -= max_y - image_h
558
+ max_y = image_h
559
+
560
+ return round(min_x), round(max_x), round(min_y), round(max_y)
561
+
562
+
563
+ def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)):
564
+ """
565
+ Read, resize, blur, normalize, and save an image.
566
+
567
+ Parameters:
568
+ file_path (str): Path to the input image file.
569
+ output_dir (str): Path to the output directory to save blurred images.
570
+ resize_dim (tuple): Dimensions to resize the images to.
571
+ kernel_size (tuple): Size of the kernel to use for Gaussian blur.
572
+ """
573
+ # Read the mask image
574
+ mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
575
+
576
+ # Check if the image is loaded successfully
577
+ if mask is not None:
578
+ normalized_mask = blur_mask(mask,resize_dim=resize_dim,kernel_size=kernel_size)
579
+ # Save the normalized mask image
580
+ cv2.imwrite(output_file_path, normalized_mask)
581
+ return f"Processed, normalized, and saved: {output_file_path}"
582
+ return f"Failed to load image: {file_path}"
583
+
584
+
585
+ def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)):
586
+ """
587
+ Read, resize, blur, normalize, and save an image.
588
+
589
+ Parameters:
590
+ file_path (str): Path to the input image file.
591
+ resize_dim (tuple): Dimensions to resize the images to.
592
+ kernel_size (tuple): Size of the kernel to use for Gaussian blur.
593
+ """
594
+ # Check if the image is loaded successfully
595
+ normalized_mask = None
596
+ if mask is not None:
597
+ # Resize the mask image
598
+ resized_mask = cv2.resize(mask, resize_dim)
599
+ # Apply Gaussian blur to the resized mask image
600
+ blurred_mask = cv2.GaussianBlur(resized_mask, kernel_size, 0)
601
+ # Normalize the blurred image
602
+ normalized_mask = cv2.normalize(
603
+ blurred_mask, None, 0, 255, cv2.NORM_MINMAX)
604
+ # Save the normalized mask image
605
+ return normalized_mask
606
+
607
+ def get_background_mask(file_path, output_file_path):
608
+ """
609
+ Read an image, invert its values, and save the result.
610
+
611
+ Parameters:
612
+ file_path (str): Path to the input image file.
613
+ output_dir (str): Path to the output directory to save the inverted image.
614
+ """
615
+ # Read the image
616
+ image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
617
+
618
+ if image is None:
619
+ print(f"Failed to load image: {file_path}")
620
+ return
621
+
622
+ # Invert the image
623
+ inverted_image = 1.0 - (
624
+ image / 255.0
625
+ ) # Assuming the image values are in [0, 255] range
626
+ # Convert back to uint8
627
+ inverted_image = (inverted_image * 255).astype(np.uint8)
628
+
629
+ # Save the inverted image
630
+ cv2.imwrite(output_file_path, inverted_image)
631
+ print(f"Processed and saved: {output_file_path}")
632
+
633
+
634
+ def get_sep_face_mask(file_path1, file_path2, output_file_path):
635
+ """
636
+ Read two images, subtract the second one from the first, and save the result.
637
+
638
+ Parameters:
639
+ output_dir (str): Path to the output directory to save the subtracted image.
640
+ """
641
+
642
+ # Read the images
643
+ mask1 = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE)
644
+ mask2 = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE)
645
+
646
+ if mask1 is None or mask2 is None:
647
+ print(f"Failed to load images: {file_path1}")
648
+ return
649
+
650
+ # Ensure the images are the same size
651
+ if mask1.shape != mask2.shape:
652
+ print(
653
+ f"Image shapes do not match for {file_path1}: {mask1.shape} vs {mask2.shape}"
654
+ )
655
+ return
656
+
657
+ # Subtract the second mask from the first
658
+ result_mask = cv2.subtract(mask1, mask2)
659
+
660
+ # Save the result mask image
661
+ cv2.imwrite(output_file_path, result_mask)
662
+ print(f"Processed and saved: {output_file_path}")
663
+
664
+ def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
665
+ p = subprocess.Popen([
666
+ "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
667
+ ])
668
+ ret = p.wait()
669
+ assert ret == 0, "Resample audio failed!"
670
+ return output_audio_file
671
+
672
+ def get_face_region(image_path: str, detector):
673
+ try:
674
+ image = cv2.imread(image_path)
675
+ if image is None:
676
+ print(f"Failed to open image: {image_path}. Skipping...")
677
+ return None, None
678
+
679
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
680
+ detection_result = detector.detect(mp_image)
681
+
682
+ # Adjust mask creation for the three-channel image
683
+ mask = np.zeros_like(image, dtype=np.uint8)
684
+
685
+ for detection in detection_result.detections:
686
+ bbox = detection.bounding_box
687
+ start_point = (int(bbox.origin_x), int(bbox.origin_y))
688
+ end_point = (int(bbox.origin_x + bbox.width),
689
+ int(bbox.origin_y + bbox.height))
690
+ cv2.rectangle(mask, start_point, end_point,
691
+ (255, 255, 255), thickness=-1)
692
+
693
+ save_path = image_path.replace("images", "face_masks")
694
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
695
+ cv2.imwrite(save_path, mask)
696
+ # print(f"Processed and saved {save_path}")
697
+ return image_path, mask
698
+ except Exception as e:
699
+ print(f"Error processing image {image_path}: {e}")
700
+ return None, None
701
+
702
+
703
+ def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
704
+ """
705
+ Save the model's state_dict to a checkpoint file.
706
+
707
+ If `total_limit` is provided, this function will remove the oldest checkpoints
708
+ until the total number of checkpoints is less than the specified limit.
709
+
710
+ Args:
711
+ model (nn.Module): The model whose state_dict is to be saved.
712
+ save_dir (str): The directory where the checkpoint will be saved.
713
+ prefix (str): The prefix for the checkpoint file name.
714
+ ckpt_num (int): The checkpoint number to be saved.
715
+ total_limit (int, optional): The maximum number of checkpoints to keep.
716
+ Defaults to None, in which case no checkpoints will be removed.
717
+
718
+ Raises:
719
+ FileNotFoundError: If the save directory does not exist.
720
+ ValueError: If the checkpoint number is negative.
721
+ OSError: If there is an error saving the checkpoint.
722
+ """
723
+
724
+ if not osp.exists(save_dir):
725
+ raise FileNotFoundError(
726
+ f"The save directory {save_dir} does not exist.")
727
+
728
+ if ckpt_num < 0:
729
+ raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")
730
+
731
+ save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
732
+
733
+ if total_limit > 0:
734
+ checkpoints = os.listdir(save_dir)
735
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
736
+ checkpoints = sorted(
737
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
738
+ )
739
+
740
+ if len(checkpoints) >= total_limit:
741
+ num_to_remove = len(checkpoints) - total_limit + 1
742
+ removing_checkpoints = checkpoints[0:num_to_remove]
743
+ print(
744
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
745
+ )
746
+ print(
747
+ f"Removing checkpoints: {', '.join(removing_checkpoints)}"
748
+ )
749
+
750
+ for removing_checkpoint in removing_checkpoints:
751
+ removing_checkpoint_path = osp.join(
752
+ save_dir, removing_checkpoint)
753
+ try:
754
+ os.remove(removing_checkpoint_path)
755
+ except OSError as e:
756
+ print(
757
+ f"Error removing checkpoint {removing_checkpoint_path}: {e}")
758
+
759
+ state_dict = model.state_dict()
760
+ try:
761
+ torch.save(state_dict, save_path)
762
+ print(f"Checkpoint saved at {save_path}")
763
+ except OSError as e:
764
+ raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e
765
+
766
+
767
+ def init_output_dir(dir_list: List[str]):
768
+ """
769
+ Initialize the output directories.
770
+
771
+ This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing.
772
+
773
+ Args:
774
+ dir_list (List[str]): List of directory paths to create.
775
+ """
776
+ for path in dir_list:
777
+ os.makedirs(path, exist_ok=True)
778
+
779
+
780
+ def load_checkpoint(cfg, save_dir, accelerator):
781
+ """
782
+ Load the most recent checkpoint from the specified directory.
783
+
784
+ This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest".
785
+ If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found,
786
+ it starts training from scratch.
787
+
788
+ Args:
789
+ cfg: The configuration object containing training parameters.
790
+ save_dir (str): The directory where checkpoints are saved.
791
+ accelerator: The accelerator object for distributed training.
792
+
793
+ Returns:
794
+ int: The global step at which to resume training.
795
+ """
796
+ if cfg.resume_from_checkpoint != "latest":
797
+ resume_dir = cfg.resume_from_checkpoint
798
+ else:
799
+ resume_dir = save_dir
800
+ # Get the most recent checkpoint
801
+ dirs = os.listdir(resume_dir)
802
+
803
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
804
+ if len(dirs) > 0:
805
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
806
+ path = dirs[-1]
807
+ accelerator.load_state(os.path.join(resume_dir, path))
808
+ accelerator.print(f"Resuming from checkpoint {path}")
809
+ global_step = int(path.split("-")[1])
810
+ else:
811
+ accelerator.print(
812
+ f"Could not find checkpoint under {resume_dir}, start training from scratch")
813
+ global_step = 0
814
+
815
+ return global_step
816
+
817
+
818
+ def compute_snr(noise_scheduler, timesteps):
819
+ """
820
+ Computes SNR as per
821
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
822
+ 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
823
+ """
824
+ alphas_cumprod = noise_scheduler.alphas_cumprod
825
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
826
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
827
+
828
+ # Expand the tensors.
829
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
830
+ # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
831
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
832
+ timesteps
833
+ ].float()
834
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
835
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
836
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
837
+
838
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
839
+ device=timesteps.device
840
+ )[timesteps].float()
841
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
842
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
843
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
844
+
845
+ # Compute SNR.
846
+ snr = (alpha / sigma) ** 2
847
+ return snr
848
+
849
+
850
+ def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path:
851
+ """
852
+ Extract audio from a video file and save it as a WAV file.
853
+
854
+ This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file
855
+ in the specified output directory.
856
+
857
+ Args:
858
+ video_path (Path): The path to the input video file.
859
+ output_dir (Path): The directory where the extracted audio file will be saved.
860
+
861
+ Returns:
862
+ Path: The path to the extracted audio file.
863
+
864
+ Raises:
865
+ subprocess.CalledProcessError: If the ffmpeg command fails to execute.
866
+ """
867
+ ffmpeg_command = [
868
+ 'ffmpeg', '-y',
869
+ '-i', str(video_path),
870
+ '-vn', '-acodec',
871
+ "pcm_s16le", '-ar', '16000', '-ac', '2',
872
+ str(audio_output_path)
873
+ ]
874
+
875
+ try:
876
+ print(f"Running command: {' '.join(ffmpeg_command)}")
877
+ subprocess.run(ffmpeg_command, check=True)
878
+ except subprocess.CalledProcessError as e:
879
+ print(f"Error extracting audio from video: {e}")
880
+ raise
881
+
882
+ return audio_output_path
883
+
884
+
885
+ def convert_video_to_images(video_path: Path, output_dir: Path) -> Path:
886
+ """
887
+ Convert a video file into a sequence of images.
888
+
889
+ This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved
890
+ in a directory named after the video file stem under the specified output directory.
891
+
892
+ Args:
893
+ video_path (Path): The path to the input video file.
894
+ output_dir (Path): The directory where the extracted images will be saved.
895
+
896
+ Returns:
897
+ Path: The path to the directory containing the extracted images.
898
+
899
+ Raises:
900
+ subprocess.CalledProcessError: If the ffmpeg command fails to execute.
901
+ """
902
+ ffmpeg_command = [
903
+ 'ffmpeg',
904
+ '-i', str(video_path),
905
+ '-vf', 'fps=25',
906
+ str(output_dir / '%04d.png')
907
+ ]
908
+
909
+ try:
910
+ print(f"Running command: {' '.join(ffmpeg_command)}")
911
+ subprocess.run(ffmpeg_command, check=True)
912
+ except subprocess.CalledProcessError as e:
913
+ print(f"Error converting video to images: {e}")
914
+ raise
915
+
916
+ return output_dir
917
+
918
+
919
+ def get_union_mask(masks):
920
+ """
921
+ Compute the union of a list of masks.
922
+
923
+ This function takes a list of masks and computes their union by taking the maximum value at each pixel location.
924
+ Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white.
925
+
926
+ Args:
927
+ masks (list of np.ndarray): List of masks to be combined.
928
+
929
+ Returns:
930
+ np.ndarray: The union of the input masks.
931
+ """
932
+ union_mask = None
933
+ for mask in masks:
934
+ if union_mask is None:
935
+ union_mask = mask
936
+ else:
937
+ union_mask = np.maximum(union_mask, mask)
938
+
939
+ if union_mask is not None:
940
+ # Find the bounding box of the non-zero regions in the mask
941
+ rows = np.any(union_mask, axis=1)
942
+ cols = np.any(union_mask, axis=0)
943
+ try:
944
+ ymin, ymax = np.where(rows)[0][[0, -1]]
945
+ xmin, xmax = np.where(cols)[0][[0, -1]]
946
+ except Exception as e:
947
+ print(str(e))
948
+ return 0.0
949
+
950
+ # Set bounding box area to white
951
+ union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask)
952
+
953
+ return union_mask
954
+
955
+
956
+ def move_final_checkpoint(save_dir, module_dir, prefix):
957
+ """
958
+ Move the final checkpoint file to the save directory.
959
+
960
+ This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory.
961
+
962
+ Args:
963
+ save_dir (str): The directory where the final checkpoint file should be saved.
964
+ module_dir (str): The directory containing the checkpoint files.
965
+ prefix (str): The prefix used to identify checkpoint files.
966
+
967
+ Raises:
968
+ ValueError: If no checkpoint files are found with the specified prefix.
969
+ """
970
+ checkpoints = os.listdir(module_dir)
971
+ checkpoints = [d for d in checkpoints if d.startswith(prefix)]
972
+ checkpoints = sorted(
973
+ checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
974
+ )
975
+ shutil.copy2(os.path.join(
976
+ module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth'))
scripts/inference.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script is a gradio web ui.
3
+
4
+ The script takes an image and an audio clip, and lets you configure all the
5
+ variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc.
6
+
7
+ Usage:
8
+ This script can be run from the command line with the following command:
9
+
10
+ python scripts/app.py
11
+ """
12
+
13
+ import gradio as gr
14
+ import argparse
15
+ import copy
16
+ import logging
17
+ import math
18
+ import os
19
+ import random
20
+ import time
21
+ import warnings
22
+ from datetime import datetime
23
+ from typing import List, Tuple
24
+
25
+ import diffusers
26
+ import mlflow
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import DistributedDataParallelKwargs
34
+ from diffusers import AutoencoderKL, DDIMScheduler
35
+ from diffusers.optimization import get_scheduler
36
+ from diffusers.utils import check_min_version
37
+ from diffusers.utils.import_utils import is_xformers_available
38
+ from einops import rearrange, repeat
39
+ from omegaconf import OmegaConf
40
+ from torch import nn
41
+ from tqdm.auto import tqdm
42
+ import uuid
43
+
44
+ import sys
45
+ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
46
+
47
+ from joyhallo.animate.face_animate import FaceAnimatePipeline
48
+ from joyhallo.datasets.audio_processor import AudioProcessor
49
+ from joyhallo.datasets.image_processor import ImageProcessor
50
+ from joyhallo.datasets.talk_video import TalkingVideoDataset
51
+ from joyhallo.models.audio_proj import AudioProjModel
52
+ from joyhallo.models.face_locator import FaceLocator
53
+ from joyhallo.models.image_proj import ImageProjModel
54
+ from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
55
+ from joyhallo.models.unet_2d_condition import UNet2DConditionModel
56
+ from joyhallo.models.unet_3d import UNet3DConditionModel
57
+ from joyhallo.utils.util import (compute_snr, delete_additional_ckpt,
58
+ import_filename, init_output_dir,
59
+ load_checkpoint, save_checkpoint,
60
+ seed_everything, tensor_to_video)
61
+
62
+ warnings.filterwarnings("ignore")
63
+
64
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
65
+ check_min_version("0.10.0.dev0")
66
+
67
+ logger = get_logger(__name__, log_level="INFO")
68
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
+
70
+
71
+ class Net(nn.Module):
72
+ """
73
+ The Net class defines a neural network model that combines a reference UNet2DConditionModel,
74
+ a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.
75
+
76
+ Args:
77
+ reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
78
+ denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
79
+ face_locator (FaceLocator): The face locator model used for face animation.
80
+ reference_control_writer: The reference control writer component.
81
+ reference_control_reader: The reference control reader component.
82
+ imageproj: The image projection model.
83
+ audioproj: The audio projection model.
84
+
85
+ Forward method:
86
+ noisy_latents (torch.Tensor): The noisy latents tensor.
87
+ timesteps (torch.Tensor): The timesteps tensor.
88
+ ref_image_latents (torch.Tensor): The reference image latents tensor.
89
+ face_emb (torch.Tensor): The face embeddings tensor.
90
+ audio_emb (torch.Tensor): The audio embeddings tensor.
91
+ mask (torch.Tensor): Hard face mask for face locator.
92
+ full_mask (torch.Tensor): Pose Mask.
93
+ face_mask (torch.Tensor): Face Mask
94
+ lip_mask (torch.Tensor): Lip Mask
95
+ uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass.
96
+ uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass.
97
+
98
+ Returns:
99
+ torch.Tensor: The output tensor of the neural network model.
100
+ """
101
+ def __init__(
102
+ self,
103
+ reference_unet: UNet2DConditionModel,
104
+ denoising_unet: UNet3DConditionModel,
105
+ face_locator: FaceLocator,
106
+ reference_control_writer,
107
+ reference_control_reader,
108
+ imageproj,
109
+ audioproj,
110
+ ):
111
+ super().__init__()
112
+ self.reference_unet = reference_unet
113
+ self.denoising_unet = denoising_unet
114
+ self.face_locator = face_locator
115
+ self.reference_control_writer = reference_control_writer
116
+ self.reference_control_reader = reference_control_reader
117
+ self.imageproj = imageproj
118
+ self.audioproj = audioproj
119
+
120
+ def forward(
121
+ self,
122
+ noisy_latents: torch.Tensor,
123
+ timesteps: torch.Tensor,
124
+ ref_image_latents: torch.Tensor,
125
+ face_emb: torch.Tensor,
126
+ audio_emb: torch.Tensor,
127
+ mask: torch.Tensor,
128
+ full_mask: torch.Tensor,
129
+ face_mask: torch.Tensor,
130
+ lip_mask: torch.Tensor,
131
+ uncond_img_fwd: bool = False,
132
+ uncond_audio_fwd: bool = False,
133
+ ):
134
+ """
135
+ simple docstring to prevent pylint error
136
+ """
137
+ face_emb = self.imageproj(face_emb)
138
+ mask = mask.to(device=device)
139
+ mask_feature = self.face_locator(mask)
140
+ audio_emb = audio_emb.to(
141
+ device=self.audioproj.device, dtype=self.audioproj.dtype)
142
+ audio_emb = self.audioproj(audio_emb)
143
+
144
+ # condition forward
145
+ if not uncond_img_fwd:
146
+ ref_timesteps = torch.zeros_like(timesteps)
147
+ ref_timesteps = repeat(
148
+ ref_timesteps,
149
+ "b -> (repeat b)",
150
+ repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
151
+ )
152
+ self.reference_unet(
153
+ ref_image_latents,
154
+ ref_timesteps,
155
+ encoder_hidden_states=face_emb,
156
+ return_dict=False,
157
+ )
158
+ self.reference_control_reader.update(self.reference_control_writer)
159
+
160
+ if uncond_audio_fwd:
161
+ audio_emb = torch.zeros_like(audio_emb).to(
162
+ device=audio_emb.device, dtype=audio_emb.dtype
163
+ )
164
+
165
+ model_pred = self.denoising_unet(
166
+ noisy_latents,
167
+ timesteps,
168
+ mask_cond_fea=mask_feature,
169
+ encoder_hidden_states=face_emb,
170
+ audio_embedding=audio_emb,
171
+ full_mask=full_mask,
172
+ face_mask=face_mask,
173
+ lip_mask=lip_mask
174
+ ).sample
175
+
176
+ return model_pred
177
+
178
+
179
+ def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
180
+ """
181
+ Rearrange the mask tensors to the required format.
182
+
183
+ Args:
184
+ mask (torch.Tensor): The input mask tensor.
185
+ weight_dtype (torch.dtype): The data type for the mask tensor.
186
+
187
+ Returns:
188
+ torch.Tensor: The rearranged mask tensor.
189
+ """
190
+ if isinstance(mask, List):
191
+ _mask = []
192
+ for m in mask:
193
+ _mask.append(
194
+ rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
195
+ return _mask
196
+ mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
197
+ return mask
198
+
199
+
200
+ def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
201
+ """
202
+ Create noise scheduler for training.
203
+
204
+ Args:
205
+ cfg (argparse.Namespace): Configuration object.
206
+
207
+ Returns:
208
+ Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
209
+ """
210
+
211
+ sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
212
+ if cfg.enable_zero_snr:
213
+ sched_kwargs.update(
214
+ rescale_betas_zero_snr=True,
215
+ timestep_spacing="trailing",
216
+ prediction_type="v_prediction",
217
+ )
218
+ val_noise_scheduler = DDIMScheduler(**sched_kwargs)
219
+ sched_kwargs.update({"beta_schedule": "scaled_linear"})
220
+ train_noise_scheduler = DDIMScheduler(**sched_kwargs)
221
+
222
+ return train_noise_scheduler, val_noise_scheduler
223
+
224
+
225
+ def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
226
+ """
227
+ Process the audio embedding to concatenate with other tensors.
228
+
229
+ Parameters:
230
+ audio_emb (torch.Tensor): The audio embedding tensor to process.
231
+
232
+ Returns:
233
+ concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
234
+ """
235
+ concatenated_tensors = []
236
+
237
+ for i in range(audio_emb.shape[0]):
238
+ vectors_to_concat = [
239
+ audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)]for j in range(-2, 3)]
240
+ concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
241
+
242
+ audio_emb = torch.stack(concatenated_tensors, dim=0)
243
+
244
+ return audio_emb
245
+
246
+
247
+ def log_validation(
248
+ accelerator: Accelerator,
249
+ vae: AutoencoderKL,
250
+ net: Net,
251
+ scheduler: DDIMScheduler,
252
+ width: int,
253
+ height: int,
254
+ clip_length: int = 24,
255
+ generator: torch.Generator = None,
256
+ cfg: dict = None,
257
+ save_dir: str = None,
258
+ global_step: int = 0,
259
+ times: int = None,
260
+ face_analysis_model_path: str = "",
261
+ ) -> None:
262
+ """
263
+ Log validation video during the training process.
264
+
265
+ Args:
266
+ accelerator (Accelerator): The accelerator for distributed training.
267
+ vae (AutoencoderKL): The autoencoder model.
268
+ net (Net): The main neural network model.
269
+ scheduler (DDIMScheduler): The scheduler for noise.
270
+ width (int): The width of the input images.
271
+ height (int): The height of the input images.
272
+ clip_length (int): The length of the video clips. Defaults to 24.
273
+ generator (torch.Generator): The random number generator. Defaults to None.
274
+ cfg (dict): The configuration dictionary. Defaults to None.
275
+ save_dir (str): The directory to save validation results. Defaults to None.
276
+ global_step (int): The current global step in training. Defaults to 0.
277
+ times (int): The number of inference times. Defaults to None.
278
+ face_analysis_model_path (str): The path to the face analysis model. Defaults to "".
279
+
280
+ Returns:
281
+ torch.Tensor: The tensor result of the validation.
282
+ """
283
+ ori_net = accelerator.unwrap_model(net)
284
+ reference_unet = ori_net.reference_unet
285
+ denoising_unet = ori_net.denoising_unet
286
+ face_locator = ori_net.face_locator
287
+ imageproj = ori_net.imageproj
288
+ audioproj = ori_net.audioproj
289
+ tmp_denoising_unet = copy.deepcopy(denoising_unet)
290
+
291
+ pipeline = FaceAnimatePipeline(
292
+ vae=vae,
293
+ reference_unet=reference_unet,
294
+ denoising_unet=tmp_denoising_unet,
295
+ face_locator=face_locator,
296
+ image_proj=imageproj,
297
+ scheduler=scheduler,
298
+ )
299
+ pipeline = pipeline.to(device)
300
+
301
+ image_processor = ImageProcessor((width, height), face_analysis_model_path)
302
+ audio_processor = AudioProcessor(
303
+ cfg.data.sample_rate,
304
+ cfg.data.fps,
305
+ cfg.wav2vec_config.model_path,
306
+ cfg.wav2vec_config.features == "last",
307
+ os.path.dirname(cfg.audio_separator.model_path),
308
+ os.path.basename(cfg.audio_separator.model_path),
309
+ os.path.join(save_dir, '.cache', "audio_preprocess"),
310
+ device=device,
311
+ )
312
+ return cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
313
+
314
+
315
+ def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length):
316
+ ref_img_path = cfg.ref_img_path
317
+ audio_path = cfg.audio_path
318
+ source_image_pixels, \
319
+ source_image_face_region, \
320
+ source_image_face_emb, \
321
+ source_image_full_mask, \
322
+ source_image_face_mask, \
323
+ source_image_lip_mask = image_processor.preprocess(
324
+ ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
325
+ audio_emb, audio_length = audio_processor.preprocess(
326
+ audio_path, clip_length)
327
+
328
+ audio_emb = process_audio_emb(audio_emb)
329
+
330
+ source_image_pixels = source_image_pixels.unsqueeze(0)
331
+ source_image_face_region = source_image_face_region.unsqueeze(0)
332
+ source_image_face_emb = source_image_face_emb.reshape(1, -1)
333
+ source_image_face_emb = torch.tensor(source_image_face_emb)
334
+
335
+ source_image_full_mask = [
336
+ (mask.repeat(clip_length, 1))
337
+ for mask in source_image_full_mask
338
+ ]
339
+ source_image_face_mask = [
340
+ (mask.repeat(clip_length, 1))
341
+ for mask in source_image_face_mask
342
+ ]
343
+ source_image_lip_mask = [
344
+ (mask.repeat(clip_length, 1))
345
+ for mask in source_image_lip_mask
346
+ ]
347
+
348
+ times = audio_emb.shape[0] // clip_length
349
+ tensor_result = []
350
+ generator = torch.manual_seed(42)
351
+ for t in range(times):
352
+ print(f"[{t+1}/{times}]")
353
+
354
+ if len(tensor_result) == 0:
355
+ # The first iteration
356
+ motion_zeros = source_image_pixels.repeat(
357
+ cfg.data.n_motion_frames, 1, 1, 1)
358
+ motion_zeros = motion_zeros.to(
359
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
360
+ pixel_values_ref_img = torch.cat(
361
+ [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
362
+ else:
363
+ motion_frames = tensor_result[-1][0]
364
+ motion_frames = motion_frames.permute(1, 0, 2, 3)
365
+ motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
366
+ motion_frames = motion_frames * 2.0 - 1.0
367
+ motion_frames = motion_frames.to(
368
+ dtype=source_image_pixels.dtype, device=source_image_pixels.device)
369
+ pixel_values_ref_img = torch.cat(
370
+ [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
371
+
372
+ pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
373
+
374
+ audio_tensor = audio_emb[
375
+ t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
376
+ ]
377
+ audio_tensor = audio_tensor.unsqueeze(0)
378
+ audio_tensor = audio_tensor.to(
379
+ device=audioproj.device, dtype=audioproj.dtype)
380
+ audio_tensor = audioproj(audio_tensor)
381
+
382
+ pipeline_output = pipeline(
383
+ ref_image=pixel_values_ref_img,
384
+ audio_tensor=audio_tensor,
385
+ face_emb=source_image_face_emb,
386
+ face_mask=source_image_face_region,
387
+ pixel_values_full_mask=source_image_full_mask,
388
+ pixel_values_face_mask=source_image_face_mask,
389
+ pixel_values_lip_mask=source_image_lip_mask,
390
+ width=cfg.data.train_width,
391
+ height=cfg.data.train_height,
392
+ video_length=clip_length,
393
+ num_inference_steps=cfg.inference_steps,
394
+ guidance_scale=cfg.cfg_scale,
395
+ generator=generator,
396
+ )
397
+
398
+ tensor_result.append(pipeline_output.videos)
399
+
400
+ tensor_result = torch.cat(tensor_result, dim=2)
401
+ tensor_result = tensor_result.squeeze(0)
402
+ tensor_result = tensor_result[:, :audio_length]
403
+ output_file = cfg.output
404
+ tensor_to_video(tensor_result, output_file, audio_path)
405
+ return output_file
406
+
407
+
408
+ def get_model(cfg: argparse.Namespace) -> None:
409
+ """
410
+ Trains the model using the given configuration (cfg).
411
+
412
+ Args:
413
+ cfg (dict): The configuration dictionary containing the parameters for training.
414
+
415
+ Notes:
416
+ - This function trains the model using the given configuration.
417
+ - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
418
+ - The training progress is logged and tracked using the accelerator.
419
+ - The trained model is saved after the training is completed.
420
+ """
421
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
422
+ accelerator = Accelerator(
423
+ gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
424
+ mixed_precision=cfg.solver.mixed_precision,
425
+ log_with="mlflow",
426
+ project_dir="./mlruns",
427
+ kwargs_handlers=[kwargs],
428
+ )
429
+
430
+ # Make one log on every process with the configuration for debugging.
431
+ logging.basicConfig(
432
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
433
+ datefmt="%m/%d/%Y %H:%M:%S",
434
+ level=logging.INFO,
435
+ )
436
+ logger.info(accelerator.state, main_process_only=False)
437
+ if accelerator.is_local_main_process:
438
+ transformers.utils.logging.set_verbosity_warning()
439
+ diffusers.utils.logging.set_verbosity_info()
440
+ else:
441
+ transformers.utils.logging.set_verbosity_error()
442
+ diffusers.utils.logging.set_verbosity_error()
443
+
444
+ # If passed along, set the training seed now.
445
+ if cfg.seed is not None:
446
+ seed_everything(cfg.seed)
447
+
448
+ # create output dir for training
449
+ exp_name = cfg.exp_name
450
+ save_dir = f"{cfg.output_dir}/{exp_name}"
451
+ validation_dir = save_dir
452
+ if accelerator.is_main_process:
453
+ init_output_dir([save_dir])
454
+
455
+ accelerator.wait_for_everyone()
456
+
457
+ if cfg.weight_dtype == "fp16":
458
+ weight_dtype = torch.float16
459
+ elif cfg.weight_dtype == "bf16":
460
+ weight_dtype = torch.bfloat16
461
+ elif cfg.weight_dtype == "fp32":
462
+ weight_dtype = torch.float32
463
+ else:
464
+ raise ValueError(
465
+ f"Do not support weight dtype: {cfg.weight_dtype} during training"
466
+ )
467
+
468
+ if not torch.cuda.is_available():
469
+ weight_dtype = torch.float32
470
+
471
+ # Create Models
472
+ vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
473
+ device=device, dtype=weight_dtype
474
+ )
475
+ reference_unet = UNet2DConditionModel.from_pretrained(
476
+ cfg.base_model_path,
477
+ subfolder="unet",
478
+ ).to(device=device, dtype=weight_dtype)
479
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
480
+ cfg.base_model_path,
481
+ cfg.mm_path,
482
+ subfolder="unet",
483
+ unet_additional_kwargs=OmegaConf.to_container(
484
+ cfg.unet_additional_kwargs),
485
+ use_landmark=False
486
+ ).to(device=device, dtype=weight_dtype)
487
+ imageproj = ImageProjModel(
488
+ cross_attention_dim=denoising_unet.config.cross_attention_dim,
489
+ clip_embeddings_dim=512,
490
+ clip_extra_context_tokens=4,
491
+ ).to(device=device, dtype=weight_dtype)
492
+ face_locator = FaceLocator(
493
+ conditioning_embedding_channels=320,
494
+ ).to(device=device, dtype=weight_dtype)
495
+ audioproj = AudioProjModel(
496
+ seq_len=5,
497
+ blocks=12,
498
+ channels=768,
499
+ intermediate_dim=512,
500
+ output_dim=768,
501
+ context_tokens=32,
502
+ ).to(device=device, dtype=weight_dtype)
503
+
504
+ # Freeze
505
+ vae.requires_grad_(False)
506
+ imageproj.requires_grad_(False)
507
+ reference_unet.requires_grad_(False)
508
+ denoising_unet.requires_grad_(False)
509
+ face_locator.requires_grad_(False)
510
+ audioproj.requires_grad_(True)
511
+
512
+ # Set motion module learnable
513
+ trainable_modules = cfg.trainable_para
514
+ for name, module in denoising_unet.named_modules():
515
+ if any(trainable_mod in name for trainable_mod in trainable_modules):
516
+ for params in module.parameters():
517
+ params.requires_grad_(True)
518
+
519
+ reference_control_writer = ReferenceAttentionControl(
520
+ reference_unet,
521
+ do_classifier_free_guidance=False,
522
+ mode="write",
523
+ fusion_blocks="full",
524
+ )
525
+ reference_control_reader = ReferenceAttentionControl(
526
+ denoising_unet,
527
+ do_classifier_free_guidance=False,
528
+ mode="read",
529
+ fusion_blocks="full",
530
+ )
531
+
532
+ net = Net(
533
+ reference_unet,
534
+ denoising_unet,
535
+ face_locator,
536
+ reference_control_writer,
537
+ reference_control_reader,
538
+ imageproj,
539
+ audioproj,
540
+ ).to(dtype=weight_dtype)
541
+
542
+ m,u = net.load_state_dict(
543
+ torch.load(
544
+ cfg.audio_ckpt_dir,
545
+ map_location="cpu",
546
+ ),
547
+ )
548
+ assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint."
549
+ print("loaded weight from ", os.path.join(cfg.audio_ckpt_dir))
550
+
551
+ # get noise scheduler
552
+ _, val_noise_scheduler = get_noise_scheduler(cfg)
553
+
554
+ if cfg.solver.enable_xformers_memory_efficient_attention and torch.cuda.is_available():
555
+ if is_xformers_available():
556
+ reference_unet.enable_xformers_memory_efficient_attention()
557
+ denoising_unet.enable_xformers_memory_efficient_attention()
558
+
559
+ else:
560
+ raise ValueError(
561
+ "xformers is not available. Make sure it is installed correctly"
562
+ )
563
+
564
+ if cfg.solver.gradient_checkpointing:
565
+ reference_unet.enable_gradient_checkpointing()
566
+ denoising_unet.enable_gradient_checkpointing()
567
+
568
+ if cfg.solver.scale_lr:
569
+ learning_rate = (
570
+ cfg.solver.learning_rate
571
+ * cfg.solver.gradient_accumulation_steps
572
+ * cfg.data.train_bs
573
+ * accelerator.num_processes
574
+ )
575
+ else:
576
+ learning_rate = cfg.solver.learning_rate
577
+
578
+ # Initialize the optimizer
579
+ optimizer_cls = torch.optim.AdamW
580
+
581
+ trainable_params = list(
582
+ filter(lambda p: p.requires_grad, net.parameters()))
583
+
584
+ optimizer = optimizer_cls(
585
+ trainable_params,
586
+ lr=learning_rate,
587
+ betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
588
+ weight_decay=cfg.solver.adam_weight_decay,
589
+ eps=cfg.solver.adam_epsilon,
590
+ )
591
+
592
+ # Scheduler
593
+ lr_scheduler = get_scheduler(
594
+ cfg.solver.lr_scheduler,
595
+ optimizer=optimizer,
596
+ num_warmup_steps=cfg.solver.lr_warmup_steps
597
+ * cfg.solver.gradient_accumulation_steps,
598
+ num_training_steps=cfg.solver.max_train_steps
599
+ * cfg.solver.gradient_accumulation_steps,
600
+ )
601
+
602
+ # get data loader
603
+ train_dataset = TalkingVideoDataset(
604
+ img_size=(cfg.data.train_width, cfg.data.train_height),
605
+ sample_rate=cfg.data.sample_rate,
606
+ n_sample_frames=cfg.data.n_sample_frames,
607
+ n_motion_frames=cfg.data.n_motion_frames,
608
+ audio_margin=cfg.data.audio_margin,
609
+ data_meta_paths=cfg.data.train_meta_paths,
610
+ wav2vec_cfg=cfg.wav2vec_config,
611
+ )
612
+ train_dataloader = torch.utils.data.DataLoader(
613
+ train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
614
+ )
615
+
616
+ # Prepare everything with our `accelerator`.
617
+ (
618
+ net,
619
+ optimizer,
620
+ train_dataloader,
621
+ lr_scheduler,
622
+ ) = accelerator.prepare(
623
+ net,
624
+ optimizer,
625
+ train_dataloader,
626
+ lr_scheduler,
627
+ )
628
+
629
+ return accelerator, vae, net, val_noise_scheduler, cfg, validation_dir
630
+
631
+
632
+ def load_config(config_path: str) -> dict:
633
+ """
634
+ Loads the configuration file.
635
+
636
+ Args:
637
+ config_path (str): Path to the configuration file.
638
+
639
+ Returns:
640
+ dict: The configuration dictionary.
641
+ """
642
+
643
+ if config_path.endswith(".yaml"):
644
+ return OmegaConf.load(config_path)
645
+ if config_path.endswith(".py"):
646
+ return import_filename(config_path).cfg
647
+ raise ValueError("Unsupported format for config file")
648
+
649
+ args = argparse.Namespace()
650
+ _config = load_config('configs/inference/inference.yaml')
651
+ for key, value in _config.items():
652
+ setattr(args, key, value)
653
+ accelerator, vae, net, val_noise_scheduler, cfg, validation_dir = get_model(args)
654
+ cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length = log_validation(
655
+ accelerator=accelerator,
656
+ vae=vae,
657
+ net=net,
658
+ scheduler=val_noise_scheduler,
659
+ width=cfg.data.train_width,
660
+ height=cfg.data.train_height,
661
+ clip_length=cfg.data.n_sample_frames,
662
+ cfg=cfg,
663
+ save_dir=validation_dir,
664
+ global_step=0,
665
+ times=cfg.single_inference_times if cfg.single_inference_times is not None else None,
666
+ face_analysis_model_path=cfg.face_analysis_model_path
667
+ )
668
+
669
+ def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
670
+ """
671
+ Create a gradio interface with the configs.
672
+ """
673
+ _ = progress
674
+ unique_id = uuid.uuid4()
675
+ config = {
676
+ 'ref_img_path': image,
677
+ 'audio_path': audio,
678
+ 'pose_weight': pose_weight,
679
+ 'face_weight': face_weight,
680
+ 'lip_weight': lip_weight,
681
+ 'face_expand_ratio': face_expand_ratio,
682
+ 'config': 'configs/inference/inference.yaml',
683
+ 'checkpoint': None,
684
+ 'output': f'output-{unique_id}.mp4'
685
+ }
686
+ global cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
687
+ for key, value in config.items():
688
+ setattr(cfg, key, value)
689
+
690
+ return inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length)