fffiloni committed on
Commit
dc3cd77
1 Parent(s): 72801fb

Update animatediff/pipelines/pipeline_animation.py

Browse files
animatediff/pipelines/pipeline_animation.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
8
  import torch
9
  from tqdm import tqdm
10
 
 
 
11
  from diffusers.utils import is_accelerate_available
12
  from packaging import version
13
  from transformers import CLIPTextModel, CLIPTokenizer
@@ -28,7 +30,7 @@ from diffusers.utils import deprecate, logging, BaseOutput
28
  from einops import rearrange
29
 
30
  from ..models.unet import UNet3DConditionModel
31
-
32
 
33
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
 
@@ -283,8 +285,29 @@ class AnimationPipeline(DiffusionPipeline):
283
  f" {type(callback_steps)}."
284
  )
285
 
286
- def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
 
287
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  if isinstance(generator, list) and len(generator) != batch_size:
289
  raise ValueError(
290
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -296,6 +319,7 @@ class AnimationPipeline(DiffusionPipeline):
296
  if isinstance(generator, list):
297
  shape = shape
298
  # shape = (1,) + shape[1:]
 
299
  latents = [
300
  torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
301
  for i in range(batch_size)
@@ -303,19 +327,29 @@ class AnimationPipeline(DiffusionPipeline):
303
  latents = torch.cat(latents, dim=0).to(device)
304
  else:
305
  latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
 
 
 
 
 
 
 
306
  else:
307
  if latents.shape != shape:
308
  raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
309
  latents = latents.to(device)
310
 
311
  # scale the initial noise by the standard deviation required by the scheduler
312
- latents = latents * self.scheduler.init_noise_sigma
 
 
313
  return latents
314
 
315
  @torch.no_grad()
316
  def __call__(
317
  self,
318
  prompt: Union[str, List[str]],
 
319
  video_length: Optional[int],
320
  height: Optional[int] = None,
321
  width: Optional[int] = None,
@@ -368,6 +402,7 @@ class AnimationPipeline(DiffusionPipeline):
368
  # Prepare latent variables
369
  num_channels_latents = self.unet.in_channels
370
  latents = self.prepare_latents(
 
371
  batch_size * num_videos_per_prompt,
372
  num_channels_latents,
373
  video_length,
 
8
  import torch
9
  from tqdm import tqdm
10
 
11
+ import PIL
12
+
13
  from diffusers.utils import is_accelerate_available
14
  from packaging import version
15
  from transformers import CLIPTextModel, CLIPTokenizer
 
30
  from einops import rearrange
31
 
32
  from ..models.unet import UNet3DConditionModel
33
+ from ..utils.util import preprocess_image
34
 
35
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
 
 
285
  f" {type(callback_steps)}."
286
  )
287
 
288
+ #def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
289
+ def prepare_latents(self, init_image, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
290
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
291
+
292
+ if init_image is not None:
293
+ image = PIL.Image.open(init_image)
294
+ image = preprocess_image(image)
295
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
296
+ raise ValueError(
297
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
298
+ )
299
+ image = image.to(device=device, dtype=dtype)
300
+ if isinstance(generator, list):
301
+ init_latents = [
302
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
303
+ ]
304
+ init_latents = torch.cat(init_latents, dim=0)
305
+ else:
306
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
307
+ else:
308
+ init_latents = None
309
+
310
+
311
  if isinstance(generator, list) and len(generator) != batch_size:
312
  raise ValueError(
313
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
 
319
  if isinstance(generator, list):
320
  shape = shape
321
  # shape = (1,) + shape[1:]
322
+ # ignore init latents for batch model
323
  latents = [
324
  torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
325
  for i in range(batch_size)
 
327
  latents = torch.cat(latents, dim=0).to(device)
328
  else:
329
  latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
330
+
331
+ if init_latents is not None:
332
+ for i in range(video_length):
333
+ # I just feel dividing by 30 yields stable results, but I don't know why
334
+ # gradually reduce init alpha along video frames (loosen restriction)
335
+ init_alpha = (video_length - float(i)) / video_length / 30
336
+ latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
337
  else:
338
  if latents.shape != shape:
339
  raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
340
  latents = latents.to(device)
341
 
342
  # scale the initial noise by the standard deviation required by the scheduler
343
+ #latents = latents * self.scheduler.init_noise_sigma
344
+ if init_latents is None:
345
+ latents = latents * self.scheduler.init_noise_sigma
346
  return latents
347
 
348
  @torch.no_grad()
349
  def __call__(
350
  self,
351
  prompt: Union[str, List[str]],
352
+ init_image: str = None,
353
  video_length: Optional[int],
354
  height: Optional[int] = None,
355
  width: Optional[int] = None,
 
402
  # Prepare latent variables
403
  num_channels_latents = self.unet.in_channels
404
  latents = self.prepare_latents(
405
+ init_image,
406
  batch_size * num_videos_per_prompt,
407
  num_channels_latents,
408
  video_length,