Initialize latents correctly
Browse files
animatediff/pipelines/pipeline_animation.py
CHANGED
@@ -328,17 +328,14 @@ class AnimationPipeline(DiffusionPipeline):
|
|
328 |
|
329 |
latents = latents.to(device)
|
330 |
else:
|
|
|
|
|
331 |
# If init_latents is not None, repeat it for the entire batch
|
332 |
if init_latents is not None:
|
333 |
init_latents = init_latents.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
334 |
for i in range(video_length):
|
335 |
-
|
336 |
-
|
337 |
-
else:
|
338 |
-
init_alpha = (video_length - float(i)) / video_length / 30
|
339 |
-
latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
|
340 |
-
else:
|
341 |
-
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
342 |
|
343 |
else:
|
344 |
if latents.shape != shape:
|
@@ -353,6 +350,7 @@ class AnimationPipeline(DiffusionPipeline):
|
|
353 |
|
354 |
|
355 |
|
|
|
356 |
@torch.no_grad()
|
357 |
def __call__(
|
358 |
self,
|
|
|
328 |
|
329 |
latents = latents.to(device)
|
330 |
else:
|
331 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
332 |
+
|
333 |
# If init_latents is not None, repeat it for the entire batch
|
334 |
if init_latents is not None:
|
335 |
init_latents = init_latents.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
336 |
for i in range(video_length):
|
337 |
+
init_alpha = (video_length - float(i)) / video_length / 30
|
338 |
+
latents[:, :, i, :, :] = init_latents[:, :, i, :, :] * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
else:
|
341 |
if latents.shape != shape:
|
|
|
350 |
|
351 |
|
352 |
|
353 |
+
|
354 |
@torch.no_grad()
|
355 |
def __call__(
|
356 |
self,
|