fffiloni committed
Commit
f7c49f8
1 Parent(s): bd349c0

Initialize latents correctly

animatediff/pipelines/pipeline_animation.py CHANGED
@@ -328,17 +328,14 @@ class AnimationPipeline(DiffusionPipeline):
 
                 latents = latents.to(device)
             else:
+                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
                 # If init_latents is not None, repeat it for the entire batch
                 if init_latents is not None:
                     init_latents = init_latents.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
                 for i in range(video_length):
-                    if init_latents is None:
-                        latents[:, :, i, :, :] = torch.randn(latents[:, :, i, :, :].shape, device=rand_device, dtype=dtype) * self.scheduler.init_noise_sigma
-                    else:
-                        init_alpha = (video_length - float(i)) / video_length / 30
-                        latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
-                else:
-                    latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+                    init_alpha = (video_length - float(i)) / video_length / 30
+                    latents[:, :, i, :, :] = init_latents[:, :, i, :, :] * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
 
         else:
             if latents.shape != shape:
@@ -353,6 +350,7 @@ class AnimationPipeline(DiffusionPipeline):
 
 
 
+
     @torch.no_grad()
     def __call__(
        self,
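With this change, the pipeline first draws plain Gaussian noise for the whole clip and then, when an init latent is supplied, blends it into every frame with a weight init_alpha = (video_length - i) / video_length / 30 that decays from 1/30 at the first frame toward roughly 0 at the last, so the init image mostly steers the earliest frames. Below is a minimal, self-contained sketch of just that blending loop; the toy tensor sizes and the stand-in init_latents are illustrative assumptions, not code from this repository.

import torch

# Toy sizes; the real pipeline derives `shape` from the UNet/VAE configuration.
batch_size, channels, video_length, height, width = 1, 4, 16, 8, 8
shape = (batch_size, channels, video_length, height, width)

# Base noise for every frame, mirroring the added torch.randn(shape, ...) line.
latents = torch.randn(shape)

# Stand-in for the init latent (in the pipeline it comes from the encoded
# init image); after unsqueeze/repeat it matches `latents` in shape.
init_latents = torch.randn(channels, video_length, height, width)
init_latents = init_latents.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

# Per-frame blend from the diff: init_alpha decays linearly across frames.
for i in range(video_length):
    init_alpha = (video_length - float(i)) / video_length / 30
    latents[:, :, i, :, :] = (
        init_latents[:, :, i, :, :] * init_alpha
        + latents[:, :, i, :, :] * (1 - init_alpha)
    )

# For video_length=16 the weights run 0.0333, 0.0312, ..., 0.0021.
print([round((video_length - float(i)) / video_length / 30, 4) for i in range(video_length)])

Note that the division by 30 keeps even the first frame's weight small (about 3%), so the init latent nudges the noise rather than replacing it.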