Update animatediff/pipelines/pipeline_animation.py
Browse files
animatediff/pipelines/pipeline_animation.py
CHANGED
@@ -8,6 +8,8 @@ import numpy as np
|
|
8 |
import torch
|
9 |
from tqdm import tqdm
|
10 |
|
|
|
|
|
11 |
from diffusers.utils import is_accelerate_available
|
12 |
from packaging import version
|
13 |
from transformers import CLIPTextModel, CLIPTokenizer
|
@@ -28,7 +30,7 @@ from diffusers.utils import deprecate, logging, BaseOutput
|
|
28 |
from einops import rearrange
|
29 |
|
30 |
from ..models.unet import UNet3DConditionModel
|
31 |
-
|
32 |
|
33 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
34 |
|
@@ -283,8 +285,29 @@ class AnimationPipeline(DiffusionPipeline):
|
|
283 |
f" {type(callback_steps)}."
|
284 |
)
|
285 |
|
286 |
-
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
|
|
287 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
if isinstance(generator, list) and len(generator) != batch_size:
|
289 |
raise ValueError(
|
290 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
@@ -296,6 +319,7 @@ class AnimationPipeline(DiffusionPipeline):
|
|
296 |
if isinstance(generator, list):
|
297 |
shape = shape
|
298 |
# shape = (1,) + shape[1:]
|
|
|
299 |
latents = [
|
300 |
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
301 |
for i in range(batch_size)
|
@@ -303,19 +327,29 @@ class AnimationPipeline(DiffusionPipeline):
|
|
303 |
latents = torch.cat(latents, dim=0).to(device)
|
304 |
else:
|
305 |
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
else:
|
307 |
if latents.shape != shape:
|
308 |
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
309 |
latents = latents.to(device)
|
310 |
|
311 |
# scale the initial noise by the standard deviation required by the scheduler
|
312 |
-
latents = latents * self.scheduler.init_noise_sigma
|
|
|
|
|
313 |
return latents
|
314 |
|
315 |
@torch.no_grad()
|
316 |
def __call__(
|
317 |
self,
|
318 |
prompt: Union[str, List[str]],
|
|
|
319 |
video_length: Optional[int],
|
320 |
height: Optional[int] = None,
|
321 |
width: Optional[int] = None,
|
@@ -368,6 +402,7 @@ class AnimationPipeline(DiffusionPipeline):
|
|
368 |
# Prepare latent variables
|
369 |
num_channels_latents = self.unet.in_channels
|
370 |
latents = self.prepare_latents(
|
|
|
371 |
batch_size * num_videos_per_prompt,
|
372 |
num_channels_latents,
|
373 |
video_length,
|
|
|
8 |
import torch
|
9 |
from tqdm import tqdm
|
10 |
|
11 |
+
import PIL
|
12 |
+
|
13 |
from diffusers.utils import is_accelerate_available
|
14 |
from packaging import version
|
15 |
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
30 |
from einops import rearrange
|
31 |
|
32 |
from ..models.unet import UNet3DConditionModel
|
33 |
+
from ..utils.util import preprocess_image
|
34 |
|
35 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
|
|
|
285 |
f" {type(callback_steps)}."
|
286 |
)
|
287 |
|
288 |
+
#def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
289 |
+
def prepare_latents(self, init_image, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
290 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
291 |
+
|
292 |
+
if init_image is not None:
|
293 |
+
image = PIL.Image.open(init_image)
|
294 |
+
image = preprocess_image(image)
|
295 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
296 |
+
raise ValueError(
|
297 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
298 |
+
)
|
299 |
+
image = image.to(device=device, dtype=dtype)
|
300 |
+
if isinstance(generator, list):
|
301 |
+
init_latents = [
|
302 |
+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
303 |
+
]
|
304 |
+
init_latents = torch.cat(init_latents, dim=0)
|
305 |
+
else:
|
306 |
+
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
307 |
+
else:
|
308 |
+
init_latents = None
|
309 |
+
|
310 |
+
|
311 |
if isinstance(generator, list) and len(generator) != batch_size:
|
312 |
raise ValueError(
|
313 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
|
|
319 |
if isinstance(generator, list):
|
320 |
shape = shape
|
321 |
# shape = (1,) + shape[1:]
|
322 |
+
# ignore init latents for batch model
|
323 |
latents = [
|
324 |
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
325 |
for i in range(batch_size)
|
|
|
327 |
latents = torch.cat(latents, dim=0).to(device)
|
328 |
else:
|
329 |
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
330 |
+
|
331 |
+
if init_latents is not None:
|
332 |
+
for i in range(video_length):
|
333 |
+
# I just feel dividing by 30 yield stable result but I don't know why
|
334 |
+
# gradully reduce init alpha along video frames (loosen restriction)
|
335 |
+
init_alpha = (video_length - float(i)) / video_length / 30
|
336 |
+
latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
|
337 |
else:
|
338 |
if latents.shape != shape:
|
339 |
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
340 |
latents = latents.to(device)
|
341 |
|
342 |
# scale the initial noise by the standard deviation required by the scheduler
|
343 |
+
#latents = latents * self.scheduler.init_noise_sigma
|
344 |
+
if init_latents is None:
|
345 |
+
latents = latents * self.scheduler.init_noise_sigma
|
346 |
return latents
|
347 |
|
348 |
@torch.no_grad()
|
349 |
def __call__(
|
350 |
self,
|
351 |
prompt: Union[str, List[str]],
|
352 |
video_length: Optional[int],
|
353 |
+
init_image: str = None,
|
354 |
height: Optional[int] = None,
|
355 |
width: Optional[int] = None,
|
|
|
402 |
# Prepare latent variables
|
403 |
num_channels_latents = self.unet.in_channels
|
404 |
latents = self.prepare_latents(
|
405 |
+
init_image,
|
406 |
batch_size * num_videos_per_prompt,
|
407 |
num_channels_latents,
|
408 |
video_length,
|