Linoy Tsaban committed on
Commit
1a2c8b5
1 Parent(s): 8832b9b

Create preprocess_utils.py

Files changed (1)
  1. preprocess_utils.py +179 -0
preprocess_utils.py ADDED
@@ -0,0 +1,179 @@
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+ # suppress partial model loading warning
+ logging.set_verbosity_error()
+
+ import os
+ from tqdm import tqdm, trange
+ import torch
+ import torch.nn as nn
+ import argparse
+ from torchvision.io import write_video
+ from pathlib import Path
+ from PIL import Image  # needed for Image.Resampling in get_data (unless util already re-exports it)
+ from util import *
+ import torchvision.transforms as T
+
+
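+ # get_timesteps mirrors the img2img-style schedule truncation: with strength < 1 only the
+ # last `strength * num_inference_steps` timesteps of the scheduler are kept.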
+ def get_timesteps(scheduler, num_inference_steps, strength, device):
+     # get the original timestep using init_timestep
+     init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+     t_start = max(num_inference_steps - init_timestep, 0)
+     timesteps = scheduler.timesteps[t_start:]
+
+     return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def decode_latents(pipe, latents):
+     decoded = []
+     batch_size = 8
+     for b in range(0, latents.shape[0], batch_size):
+         # undo the Stable Diffusion VAE scaling factor before decoding
+         latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
+         imgs = pipe.vae.decode(latents_batch).sample
+         # map decoded images from [-1, 1] to [0, 1]
+         imgs = (imgs / 2 + 0.5).clamp(0, 1)
+         decoded.append(imgs)
+     return torch.cat(decoded)
+
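+ # ddim_inversion runs the deterministic DDIM update in reverse: at each step it predicts x0
+ # from the previous (less noisy) latent and re-noises it to the current timestep, so clean
+ # latents are mapped to the noise trajectory that reconstructs them under DDIM sampling.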
+ @torch.no_grad()
+ def ddim_inversion(pipe, cond, latent_frames, batch_size, save_latents=True, timesteps_to_save=None):
+
+     timesteps = reversed(pipe.scheduler.timesteps)
+     timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
+     for i, t in enumerate(tqdm(timesteps)):
+         for b in range(0, latent_frames.shape[0], batch_size):
+             x_batch = latent_frames[b:b + batch_size]
+             model_input = x_batch
+             cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
+             # uncomment the block below to support depth-conditioned models
+             # if self.sd_version == 'depth':
+             #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
+             #     model_input = torch.cat([x_batch, depth_maps], dim=1)
+
+             alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
+             alpha_prod_t_prev = (
+                 pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
+                 if i > 0 else pipe.scheduler.final_alpha_cumprod
+             )
+
+             mu = alpha_prod_t ** 0.5
+             mu_prev = alpha_prod_t_prev ** 0.5
+             sigma = (1 - alpha_prod_t) ** 0.5
+             sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+             # replace the line below with the commented block to support ControlNet
+             eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
+             # if self.sd_version != 'ControlNet':
+             #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
+             # else:
+             #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
+
+             # invert one DDIM step: predict x0 at the previous timestep, then re-noise to t
+             pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
+             latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps
+
+         # if save_latents and t in timesteps_to_save:
+         #     torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
+     return latent_frames
+
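+ # ddim_sample is the forward direction: standard deterministic DDIM denoising from the
+ # (inverted) noisy latents back toward clean latents, useful for checking reconstruction quality.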
+ @torch.no_grad()
+ def ddim_sample(pipe, x, cond, batch_size):
+     timesteps = pipe.scheduler.timesteps
+     for i, t in enumerate(tqdm(timesteps)):
+         for b in range(0, x.shape[0], batch_size):
+             x_batch = x[b:b + batch_size]
+             model_input = x_batch
+             cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
+
+             # uncomment the block below to support depth-conditioned models
+             # if self.sd_version == 'depth':
+             #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
+             #     model_input = torch.cat([x_batch, depth_maps], dim=1)
+
+             alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
+             alpha_prod_t_prev = (
+                 pipe.scheduler.alphas_cumprod[timesteps[i + 1]]
+                 if i < len(timesteps) - 1
+                 else pipe.scheduler.final_alpha_cumprod
+             )
+             mu = alpha_prod_t ** 0.5
+             sigma = (1 - alpha_prod_t) ** 0.5
+             mu_prev = alpha_prod_t_prev ** 0.5
+             sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+             # replace the line below with the commented block to support ControlNet
+             eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
+             # if self.sd_version != 'ControlNet':
+             #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
+             # else:
+             #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
+
+             # one DDIM denoising step: predict x0 at t, then step to the next (less noisy) timestep
+             pred_x0 = (x_batch - sigma * eps) / mu
+             x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
+     return x
+
+
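+ # get_text_embeds returns classifier-free-guidance embeddings laid out as
+ # [unconditional/negative, conditional] along the batch dimension.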
+ @torch.no_grad()
+ def get_text_embeds(pipe, prompt, negative_prompt, batch_size=1, device="cuda"):
+     # Tokenize text and get embeddings
+     text_input = pipe.tokenizer(prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
+                                 truncation=True, return_tensors='pt')
+     text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]
+
+     # Do the same for unconditional embeddings
+     uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
+                                   return_tensors='pt')
+
+     uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]
+
+     # Cat for final embeddings
+     text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
+     return text_embeddings
+
+ @torch.no_grad()
+ def extract_latents(pipe,
+                     num_steps,
+                     latent_frames,
+                     batch_size,
+                     timesteps_to_save,
+                     inversion_prompt=''):
+     pipe.scheduler.set_timesteps(num_steps)
+     # keep only the conditional half of the CFG embeddings for inversion
+     cond = get_text_embeds(pipe, inversion_prompt, "", device=pipe.device)[1].unsqueeze(0)
+     # latent_frames = self.latents
+
+     inverted_latents = ddim_inversion(pipe, cond,
+                                       latent_frames,
+                                       batch_size=batch_size,
+                                       save_latents=False,
+                                       timesteps_to_save=timesteps_to_save)
+
+     # latent_reconstruction = ddim_sample(pipe, inverted_latents, cond, batch_size=batch_size)
+
+     # rgb_reconstruction = decode_latents(pipe, latent_reconstruction)
+
+     # return rgb_reconstruction
+     return inverted_latents
+
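+ # encode_imgs maps images from [0, 1] to [-1, 1], encodes them with the VAE, and applies the
+ # Stable Diffusion latent scaling factor (0.18215); deterministic=True uses the posterior mean.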
+ @torch.no_grad()
+ def encode_imgs(pipe, imgs, batch_size=10, deterministic=True):
+     imgs = 2 * imgs - 1
+     latents = []
+     for i in range(0, len(imgs), batch_size):
+         posterior = pipe.vae.encode(imgs[i:i + batch_size]).latent_dist
+         latent = posterior.mean if deterministic else posterior.sample()
+         latents.append(latent * 0.18215)
+     latents = torch.cat(latents)
+     return latents
+
+ def get_data(pipe, frames, n_frames):
+     """
+     Converts frames to tensors, moves them to the pipeline's device, and encodes them to latents.
+     """
+     frames = frames[:n_frames]
+     # square frames are resized to the 512x512 resolution expected by Stable Diffusion
+     if frames[0].size[0] == frames[0].size[1]:
+         frames = [frame.convert("RGB").resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
+     stacked_tensor_frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(pipe.device)
+     # encode to latents
+     latents = encode_imgs(pipe, stacked_tensor_frames, deterministic=True).to(torch.float16).to(pipe.device)
+     return stacked_tensor_frames, latents
+
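
For context, a minimal sketch of how these helpers could be wired together with a diffusers StableDiffusionPipeline. The model id, frame directory, frame count, and save path are illustrative assumptions, not part of this commit; a DDIMScheduler is assumed because ddim_inversion reads scheduler.final_alpha_cumprod.

import torch
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionPipeline, DDIMScheduler
from preprocess_utils import get_data, extract_latents

device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to(device)
# DDIM inversion relies on DDIMScheduler attributes (alphas_cumprod, final_alpha_cumprod)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# load video frames as PIL images (hypothetical path); frames are assumed to be square
# so that get_data resizes them to 512x512
frames = [Image.open(p) for p in sorted(Path("data/frames").glob("*.png"))]

# encode frames to VAE latents, then invert them with DDIM
_, latents = get_data(pipe, frames, n_frames=24)
inverted = extract_latents(pipe,
                           num_steps=50,
                           latent_frames=latents,
                           batch_size=8,
                           timesteps_to_save=None,
                           inversion_prompt="a photo of a scene")
torch.save(inverted, "inverted_latents.pt")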