fffiloni committed on
Commit 248e487
1 Parent(s): dc3cd77

Update animatediff/utils/util.py

Files changed (1)
  1. animatediff/utils/util.py +66 -0
animatediff/utils/util.py CHANGED
@@ -9,6 +9,10 @@ import torchvision
 from tqdm import tqdm
 from einops import rearrange
 
+import PIL.Image
+import PIL.ImageOps
+from packaging import version
+from PIL import Image
 
 def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
     videos = rearrange(videos, "b c t h w -> t b c h w")
@@ -82,3 +86,65 @@ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
 def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
     ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
     return ddim_latents
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+    PIL_INTERPOLATION = {
+        "linear": PIL.Image.Resampling.BILINEAR,
+        "bilinear": PIL.Image.Resampling.BILINEAR,
+        "bicubic": PIL.Image.Resampling.BICUBIC,
+        "lanczos": PIL.Image.Resampling.LANCZOS,
+        "nearest": PIL.Image.Resampling.NEAREST,
+    }
+else:
+    PIL_INTERPOLATION = {
+        "linear": PIL.Image.LINEAR,
+        "bilinear": PIL.Image.BILINEAR,
+        "bicubic": PIL.Image.BICUBIC,
+        "lanczos": PIL.Image.LANCZOS,
+        "nearest": PIL.Image.NEAREST,
+    }
+
+
+def pt_to_pil(images):
+    """
+    Convert a torch image to a PIL image.
+    """
+    images = (images / 2 + 0.5).clamp(0, 1)
+    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+    images = numpy_to_pil(images)
+    return images
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    if images.shape[-1] == 1:
+        # special case for grayscale (single channel) images
+        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+    else:
+        pil_images = [Image.fromarray(image) for image in images]
+    return pil_images
+
+
+def preprocess_image(image):
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
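
For reference, here is a minimal usage sketch (not part of the commit) exercising the helpers added above. It assumes util.py already imports numpy as np and torch near its top, which preprocess_image relies on, and that the repository root is on the Python path; the test image size is arbitrary.

import numpy as np
import PIL.Image

from animatediff.utils.util import preprocess_image, pt_to_pil

# Dummy 513x387 RGB image; preprocess_image snaps width/height down to multiples of 8.
img = PIL.Image.fromarray(np.random.randint(0, 255, (387, 513, 3), dtype=np.uint8))

x = preprocess_image(img)   # torch float tensor of shape (1, 3, 384, 512), values in [-1, 1]
pils = pt_to_pil(x)         # list with one PIL.Image, mapped back from [-1, 1] to [0, 255]

print(x.shape, x.min().item(), x.max().item())
print(pils[0].size)         # (512, 384)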