xxxpo13 committed
Commit bb749ff · verified · 1 Parent(s): 51e9ef1

Create pyramid_dit_for_video_gen_pipeline.py

Files changed (1)
  1. pyramid_dit_for_video_gen_pipeline.py +518 -0
pyramid_dit_for_video_gen_pipeline.py ADDED
@@ -0,0 +1,518 @@
+ import torch
+ import os
+ import sys
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from collections import OrderedDict
+ from einops import rearrange
+ from diffusers.utils.torch_utils import randn_tensor
+ import numpy as np
+ import math
+ import random
+ import PIL
+ from PIL import Image
+ from tqdm import tqdm
+ from torchvision import transforms
+ from copy import deepcopy
+ from typing import Any, Callable, Dict, List, Optional, Union
+ from accelerate import Accelerator
+ from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
+ from video_vae.modeling_causal_vae import CausalVideoVAE
+
+ from trainer_misc import (
+     all_to_all,
+     is_sequence_parallel_initialized,
+     get_sequence_parallel_group,
+     get_sequence_parallel_group_rank,
+     get_sequence_parallel_rank,
+     get_sequence_parallel_world_size,
+     get_rank,
+ )
+
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
+ from .modeling_text_encoder import SD3TextEncoderWithMask
+
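+ # Training-time utility: draws per-sample values u in [0, 1] that control how noise
+ # timesteps are sampled ("logit_normal", "mode", or plain uniform).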
+ def compute_density_for_timestep_sampling(
+     weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+ ):
+     if weighting_scheme == "logit_normal":
+         u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+         u = torch.nn.functional.sigmoid(u)
+     elif weighting_scheme == "mode":
+         u = torch.rand(size=(batch_size,), device="cpu")
+         u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+     else:
+         u = torch.rand(size=(batch_size,), device="cpu")
+     return u
+
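+ # Pyramid flow-matching video pipeline: bundles the MMDiT backbone, the SD3 text encoder
+ # and the causal video VAE. Videos are generated autoregressively one temporal unit at a
+ # time, and each unit is denoised across spatial stages (e.g. 1x -> 2x -> 4x resolution).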
+ class PyramidDiTForVideoGeneration:
+     def __init__(self, model_path, model_dtype='bf16', use_gradient_checkpointing=False, return_log=True,
+         model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
+         sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
+         load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
+         corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], **kwargs,
+     ):
+         super().__init__()
+
+         if model_dtype == 'bf16':
+             torch_dtype = torch.bfloat16
+         elif model_dtype == 'fp16':
+             torch_dtype = torch.float16
+         else:
+             torch_dtype = torch.float32
+
+         self.stages = stages
+         self.sample_ratios = sample_ratios
+         self.corrupt_ratio = corrupt_ratio
+
+         dit_path = os.path.join(model_path, model_variant)
+
+         # The dit
+         if use_mixed_training:
+             print("using mixed precision training, not explicitly casting models")
+             self.dit = PyramidDiffusionMMDiT.from_pretrained(
+                 dit_path, use_gradient_checkpointing=use_gradient_checkpointing,
+                 use_flash_attn=use_flash_attn, use_t5_mask=True,
+                 add_temp_pos_embed=True, temp_pos_embed_type='rope',
+                 use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
+             )
+         else:
+             print("using half precision")
+             self.dit = PyramidDiffusionMMDiT.from_pretrained(
+                 dit_path, torch_dtype=torch_dtype,
+                 use_gradient_checkpointing=use_gradient_checkpointing,
+                 use_flash_attn=use_flash_attn, use_t5_mask=True,
+                 add_temp_pos_embed=True, temp_pos_embed_type='rope',
+                 use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
+             )
+
+         # The text encoder
+         if load_text_encoder:
+             self.text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
+         else:
+             self.text_encoder = None
+
+         # The base video vae decoder
+         if load_vae:
+             self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
+             # Freeze vae
+             for parameter in self.vae.parameters():
+                 parameter.requires_grad = False
+         else:
+             self.vae = None
+
+         # For the image latent
+         self.vae_shift_factor = 0.1490
+         self.vae_scale_factor = 1 / 1.8415
+
+         # For the video latent
+         self.vae_video_shift_factor = -0.2343
+         self.vae_video_scale_factor = 1 / 3.0986
+
+         self.downsample = 8
+
+         # Configure the video training hyper-parameters
+         # The video sequence: one frame + N * unit
+         self.frame_per_unit = frame_per_unit
+         self.max_temporal_length = max_temporal_length
+         assert (max_temporal_length - 1) % frame_per_unit == 0, "max_temporal_length - 1 should be divisible by frame_per_unit"
+         self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))
+
+         self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
+             shift=timestep_shift, stages=len(self.stages),
+             stage_range=stage_range, gamma=scheduler_gamma,
+         )
+         print(f"The start and end sigmas of each stage are Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")
+
+         self.cfg_rate = 0.1
+         self.return_log = return_log
+         self.use_flash_attn = use_flash_attn
+
+         # Initialize scaler for mixed precision
+         self.scaler = torch.cuda.amp.GradScaler()
+
+     # ... [other methods remain the same] ...
+
+     @torch.cuda.amp.autocast()
+     def generate(
+         self,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         temp: int = 1,
+         num_inference_steps: Optional[Union[int, List[int]]] = 28,
+         video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
+         guidance_scale: float = 7.0,
+         video_guidance_scale: float = 7.0,
+         min_guidance_scale: float = 2.0,
+         use_linear_guidance: bool = False,
+         alpha: float = 0.5,
+         negative_prompt: Optional[Union[str, List[str]]] = "cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         output_type: Optional[str] = "pil",
+         save_memory: bool = True,
+         cpu_offloading: bool = False,
+     ):
+         device = self.device if not cpu_offloading else "cuda"
+         dtype = self.dtype
+         if cpu_offloading:
+             if str(self.dit.device) != "cpu":
+                 print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise a second, needless transfer will occur.")
+                 self.dit.to("cpu")
+             if str(self.vae.device) != "cpu":
+                 print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise a second, needless transfer will occur.")
+                 self.vae.to("cpu")
+
+         assert (temp - 1) % self.frame_per_unit == 0, "temp - 1 should be divisible by frame_per_unit"
+
+         if isinstance(prompt, str):
+             batch_size = 1
+             prompt = prompt + ", hyper quality, Ultra HD, 8K"
+         else:
+             assert isinstance(prompt, list)
+             batch_size = len(prompt)
+             prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
+
+         if isinstance(num_inference_steps, int):
+             num_inference_steps = [num_inference_steps] * len(self.stages)
+
+         if isinstance(video_num_inference_steps, int):
+             video_num_inference_steps = [video_num_inference_steps] * len(self.stages)
+
+         negative_prompt = negative_prompt or ""
+
+         # Get the text embeddings
+         if cpu_offloading:
+             self.text_encoder.to("cuda")
+         prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
+         negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
+         if cpu_offloading:
+             self.text_encoder.to("cpu")
+             self.dit.to("cuda")
+
+         if use_linear_guidance:
+             max_guidance_scale = guidance_scale
+             guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
+             print(guidance_scale_list)
+
+         self._guidance_scale = guidance_scale
+         self._video_guidance_scale = video_guidance_scale
+
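+         # Classifier-free guidance: stack the negative and positive text embeddings along
+         # the batch dimension so both branches are predicted in a single DiT forward pass.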
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+         # Create the initial random noise
+         num_channels_latents = self.dit.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             temp,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+         )
+
+         temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
+
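+         # Downsample the initial noise to the lowest pyramid resolution; each stage later
+         # upsamples by 2x. The extra factor of 2 roughly compensates for the variance
+         # reduction of bilinear averaging, keeping the noise close to unit variance.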
+         latents = rearrange(latents, 'b c t h w -> (b t) c h w')
+         for _ in range(len(self.stages) - 1):
+             height //= 2; width //= 2
+             latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
+
+         latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
+
+         num_units = 1 + (temp - 1) // self.frame_per_unit
+         stages = self.stages
+
+         generated_latents_list = []
+         last_generated_latents = None
+
+         for unit_index in tqdm(range(num_units)):
+             if use_linear_guidance:
+                 self._guidance_scale = guidance_scale_list[unit_index]
+                 self._video_guidance_scale = guidance_scale_list[unit_index]
+
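+             # The first unit is the single leading frame, generated text-to-image style with
+             # no temporal history; every later unit is conditioned on the frames generated so far.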
+             if unit_index == 0:
+                 past_condition_latents = [[] for _ in range(len(stages))]
+                 with torch.no_grad():
+                     intermed_latents = self.generate_one_unit(
+                         latents[:, :, :1],
+                         past_condition_latents,
+                         prompt_embeds,
+                         prompt_attention_mask,
+                         pooled_prompt_embeds,
+                         num_inference_steps,
+                         height,
+                         width,
+                         1,
+                         device,
+                         dtype,
+                         generator,
+                         is_first_frame=True,
+                     )
+             else:
+                 past_condition_latents = []
+                 clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
+
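+                 # Build per-stage temporal conditioning from the previously generated latents:
+                 # the most recent unit is kept at each stage's own resolution, while older
+                 # units are taken from progressively coarser pyramid levels.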
+                 for i_s in range(len(stages)):
+                     last_cond_latent = clean_latents_list[i_s][:, :, -(self.frame_per_unit):]
+                     stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
+
+                     cur_unit_num = unit_index
+                     cur_stage = i_s
+                     cur_unit_ptx = 1
+
+                     while cur_unit_ptx < cur_unit_num:
+                         cur_stage = max(cur_stage - 1, 0)
+                         if cur_stage == 0:
+                             break
+                         cur_unit_ptx += 1
+                         cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
+                         stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
+
+                     if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
+                         cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
+                         stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
+
+                     stage_input = list(reversed(stage_input))
+                     past_condition_latents.append(stage_input)
+
+                 with torch.no_grad():
+                     intermed_latents = self.generate_one_unit(
+                         latents[:, :, 1 + (unit_index - 1) * self.frame_per_unit : 1 + unit_index * self.frame_per_unit],
+                         past_condition_latents,
+                         prompt_embeds,
+                         prompt_attention_mask,
+                         pooled_prompt_embeds,
+                         video_num_inference_steps,
+                         height,
+                         width,
+                         self.frame_per_unit,
+                         device,
+                         dtype,
+                         generator,
+                         is_first_frame=False,
+                     )
+
+             generated_latents_list.append(intermed_latents[-1])
+             last_generated_latents = intermed_latents
+
+         torch.cuda.empty_cache()
+
+         generated_latents = torch.cat(generated_latents_list, dim=2)
+
+         if output_type == "latent":
+             image = generated_latents
+         else:
+             if cpu_offloading:
+                 self.dit.to("cpu")
+                 self.vae.to("cuda")
+             image = self.decode_latent(generated_latents, save_memory=save_memory)
+             if cpu_offloading:
+                 self.vae.to("cpu")
+
+         return image
+
+     def decode_latent(self, latents, save_memory=True):
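+         # Undo the latent normalization before decoding: the first frame uses the image-latent
+         # statistics, the remaining frames use the video-latent statistics.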
+         if latents.shape[2] == 1:
+             latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
+         else:
+             latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
+             latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
+
+         with torch.no_grad(), torch.cuda.amp.autocast():
+             if save_memory:
+                 image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=128).sample
+             else:
+                 image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=256).sample
+
+         image = image.float()
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = rearrange(image, "B C T H W -> (B T) C H W")
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         image = self.numpy_to_pil(image)
+         return image
+
+     @staticmethod
+     def numpy_to_pil(images):
+         if images.ndim == 3:
+             images = images[None, ...]
+         images = (images * 255).round().astype("uint8")
+         if images.shape[-1] == 1:
+             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+         else:
+             pil_images = [Image.fromarray(image) for image in images]
+         return pil_images
+
+     @property
+     def device(self):
+         return next(self.dit.parameters()).device
+
+     @property
+     def dtype(self):
+         return next(self.dit.parameters()).dtype
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def video_guidance_scale(self):
+         return self._video_guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 0
+
+     def prepare_latents(
+         self,
+         batch_size,
+         num_channels_latents,
+         temp,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+     ):
+         shape = (
+             batch_size,
+             num_channels_latents,
+             int(temp),
+             int(height) // self.downsample,
+             int(width) // self.downsample,
+         )
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         return latents
+
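+     # Draws noise that is correlated within each 2x2 spatial block (per-block covariance
+     # (1 + gamma) * I - gamma * 11^T), used to re-noise upsampled latents at stage transitions.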
+     def sample_block_noise(self, bs, ch, temp, height, width):
+         gamma = self.scheduler.config.gamma
+         dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
+         block_number = bs * ch * temp * (height // 2) * (width // 2)
+         noise = torch.stack([dist.sample() for _ in range(block_number)])
+         noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)', b=bs, c=ch, t=temp, h=height // 2, w=width // 2, p=2, q=2)
+         return noise
+
+     @torch.no_grad()
+     def generate_one_unit(
+         self,
+         latents,
+         past_conditions,
+         prompt_embeds,
+         prompt_attention_mask,
+         pooled_prompt_embeds,
+         num_inference_steps,
+         height,
+         width,
+         temp,
+         device,
+         dtype,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         is_first_frame: bool = False,
+     ):
+         stages = self.stages
+         intermed_latents = []
+
+         for i_s in range(len(stages)):
+             self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
+             timesteps = self.scheduler.timesteps
+
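+             # Stage transition: upsample the partially denoised latents 2x, then re-noise them
+             # with block-correlated noise. alpha and beta are derived from the stage's original
+             # start sigma so the corrupted latents sit at the noise level this stage expects.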
+             if i_s > 0:
+                 height *= 2; width *= 2
+                 latents = rearrange(latents, 'b c t h w -> (b t) c h w')
+                 latents = F.interpolate(latents, size=(height, width), mode='nearest')
+                 latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
+                 ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]
+                 gamma = self.scheduler.config.gamma
+                 alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
+                 beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
+
+                 bs, ch, temp, height, width = latents.shape
+                 noise = self.sample_block_noise(bs, ch, temp, height, width)
+                 noise = noise.to(device=device, dtype=dtype)
+                 latents = alpha * latents + beta * noise
+
+             for idx, t in enumerate(timesteps):
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                 timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+                 latent_model_input = past_conditions[i_s] + [latent_model_input]
+
+                 noise_pred = self.dit(
+                     sample=[latent_model_input],
+                     timestep_ratio=timestep,
+                     encoder_hidden_states=prompt_embeds,
+                     encoder_attention_mask=prompt_attention_mask,
+                     pooled_projections=pooled_prompt_embeds,
+                 )
+
+                 noise_pred = noise_pred[0]
+
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     if is_first_frame:
+                         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                     else:
+                         noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 latents = self.scheduler.step(
+                     model_output=noise_pred,
+                     timestep=timestep,
+                     sample=latents,
+                     generator=generator,
+                 ).prev_sample
+
+             intermed_latents.append(latents)
+
+         return intermed_latents
+
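+     # Builds a latent pyramid [coarsest, ..., full resolution] by repeatedly halving the
+     # spatial size of x; used to provide history at the resolution of each stage.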
+     def get_pyramid_latent(self, x, stage_num):
+         vae_latent_list = []
+         vae_latent_list.append(x)
+
+         temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
+         for _ in range(stage_num):
+             height //= 2
+             width //= 2
+             x = rearrange(x, 'b c t h w -> (b t) c h w')
+             x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
+             x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
+             vae_latent_list.append(x)
+
+         vae_latent_list = list(reversed(vae_latent_list))
+         return vae_latent_list
+
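+     # Loads a combined training checkpoint, keeping only DiT weights: 'vae.*' and
+     # 'text_encoder.*' entries are skipped and the 'dit.' prefix is stripped.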
+     def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
+         checkpoint = torch.load(checkpoint_path, map_location='cpu')
+         dit_checkpoint = OrderedDict()
+         for key in checkpoint:
+             if key.startswith('vae') or key.startswith('text_encoder'):
+                 continue
+             if key.startswith('dit'):
+                 new_key = key.split('.')
+                 new_key = '.'.join(new_key[1:])
+                 dit_checkpoint[new_key] = checkpoint[key]
+             else:
+                 dit_checkpoint[key] = checkpoint[key]
+
+         load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
+         print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")
+
+     def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
+         checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
+         checkpoint = checkpoint[model_key]
+         loaded_checkpoint = OrderedDict()
+
+         for key in checkpoint.keys():
+             if key.startswith('vae.'):
+                 new_key = key.split('.')
+                 new_key = '.'.join(new_key[1:])
+                 loaded_checkpoint[new_key] = checkpoint[key]
+
+         load_result = self.vae.load_state_dict(loaded_checkpoint)
+         print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
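
A minimal usage sketch for the pipeline above (illustrative only: the checkpoint directory, resolution, frame count and guidance values are placeholder assumptions, and the import path assumes the file sits inside the pyramid_dit package alongside its relative imports):

    from pyramid_dit import PyramidDiTForVideoGeneration  # assumed package layout

    # Placeholder checkpoint directory; it must contain the diffusion_transformer_768p,
    # text-encoder and causal_video_vae subfolders expected by __init__.
    model = PyramidDiTForVideoGeneration("/path/to/pyramid-flow", model_dtype='bf16')

    frames = model.generate(
        prompt="A sailboat crossing a calm bay at sunset",
        height=768, width=1280,                    # pixel size; latents are 8x smaller
        temp=16,                                   # latent frames: first frame + (temp - 1) units
        num_inference_steps=[20, 20, 20],          # per-stage steps for the first frame
        video_num_inference_steps=[10, 10, 10],    # per-stage steps for the later units
        guidance_scale=9.0, video_guidance_scale=5.0,
        output_type="pil",
        cpu_offloading=True,                       # keep only the active component on the GPU
    )

    frames[0].save("frame_000.png")                # generate() returns a list of PIL frames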