Linoy Tsaban commited on
Commit
acc80f0
·
1 Parent(s): a62f21a

Create inversion_utils.py

Browse files
Files changed (1) hide show
  1. inversion_utils.py +291 -0
inversion_utils.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from tqdm import tqdm
4
+ from PIL import Image, ImageDraw ,ImageFont
5
+ from matplotlib import pyplot as plt
6
+ import torchvision.transforms as T
7
+ import os
8
+ import yaml
9
+ import numpy as np
10
+
11
+
12
+ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
13
+ if type(image_path) is str:
14
+ image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
15
+ else:
16
+ image = image_path
17
+ h, w, c = image.shape
18
+ left = min(left, w-1)
19
+ right = min(right, w - left - 1)
20
+ top = min(top, h - left - 1)
21
+ bottom = min(bottom, h - top - 1)
22
+ image = image[top:h-bottom, left:w-right]
23
+ h, w, c = image.shape
24
+ if h < w:
25
+ offset = (w - h) // 2
26
+ image = image[:, offset:offset + h]
27
+ elif w < h:
28
+ offset = (h - w) // 2
29
+ image = image[offset:offset + w]
30
+ image = np.array(Image.fromarray(image).resize((512, 512)))
31
+ image = torch.from_numpy(image).float() / 127.5 - 1
32
+ image = image.permute(2, 0, 1).unsqueeze(0).to(device)
33
+
34
+ return image
35
+
36
+
37
+ def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
38
+ from PIL import Image
39
+ from glob import glob
40
+ if img_name is not None:
41
+ path = os.path.join(folder, img_name)
42
+ else:
43
+ path = glob(folder + "*")[idx]
44
+
45
+ img = Image.open(path).resize((img_size,
46
+ img_size))
47
+
48
+ img = pil_to_tensor(img).to(device)
49
+
50
+ if img.shape[1]== 4:
51
+ img = img[:,:3,:,:]
52
+ return img
53
+
54
+ def mu_tilde(model, xt,x0, timestep):
55
+ "mu_tilde(x_t, x_0) DDPM paper eq. 7"
56
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
57
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
58
+ alpha_t = model.scheduler.alphas[timestep]
59
+ beta_t = 1 - alpha_t
60
+ alpha_bar = model.scheduler.alphas_cumprod[timestep]
61
+ return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + ((alpha_t**0.5 *(1-alpha_prod_t_prev)) / (1- alpha_bar))*xt
62
+
63
+ def sample_xts_from_x0(model, x0, num_inference_steps=50):
64
+ """
65
+ Samples from P(x_1:T|x_0)
66
+ """
67
+ # torch.manual_seed(43256465436)
68
+ alpha_bar = model.scheduler.alphas_cumprod
69
+ sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
70
+ alphas = model.scheduler.alphas
71
+ betas = 1 - alphas
72
+ variance_noise_shape = (
73
+ num_inference_steps,
74
+ model.unet.in_channels,
75
+ model.unet.sample_size,
76
+ model.unet.sample_size)
77
+
78
+ timesteps = model.scheduler.timesteps.to(model.device)
79
+ t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
80
+ xts = torch.zeros(variance_noise_shape).to(x0.device)
81
+ for t in reversed(timesteps):
82
+ idx = t_to_idx[int(t)]
83
+ xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
84
+ xts = torch.cat([xts, x0 ],dim = 0)
85
+
86
+ return xts
87
+
88
+ def encode_text(model, prompts):
89
+ text_input = model.tokenizer(
90
+ prompts,
91
+ padding="max_length",
92
+ max_length=model.tokenizer.model_max_length,
93
+ truncation=True,
94
+ return_tensors="pt",
95
+ )
96
+ with torch.no_grad():
97
+ text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
98
+ return text_encoding
99
+
100
+ def forward_step(model, model_output, timestep, sample):
101
+ next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
102
+ timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
103
+
104
+ # 2. compute alphas, betas
105
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
106
+ # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
107
+
108
+ beta_prod_t = 1 - alpha_prod_t
109
+
110
+ # 3. compute predicted original sample from predicted noise also called
111
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
112
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
113
+
114
+ # 5. TODO: simple noising implementatiom
115
+ next_sample = model.scheduler.add_noise(pred_original_sample,
116
+ model_output,
117
+ torch.LongTensor([next_timestep]))
118
+ return next_sample
119
+
120
+
121
+ def get_variance(model, timestep): #, prev_timestep):
122
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
123
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
124
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
125
+ beta_prod_t = 1 - alpha_prod_t
126
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
127
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
128
+ return variance
129
+
130
+ def inversion_forward_process(model, x0,
131
+ etas = None,
132
+ prog_bar = False,
133
+ prompt = "",
134
+ cfg_scale = 3.5,
135
+ num_inference_steps=50, eps = None):
136
+
137
+ if not prompt=="":
138
+ text_embeddings = encode_text(model, prompt)
139
+ uncond_embedding = encode_text(model, "")
140
+ timesteps = model.scheduler.timesteps.to(model.device)
141
+ variance_noise_shape = (
142
+ num_inference_steps,
143
+ model.unet.in_channels,
144
+ model.unet.sample_size,
145
+ model.unet.sample_size)
146
+ if etas is None or (type(etas) in [int, float] and etas == 0):
147
+ eta_is_zero = True
148
+ zs = None
149
+ else:
150
+ eta_is_zero = False
151
+ if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
152
+ xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
153
+ alpha_bar = model.scheduler.alphas_cumprod
154
+ zs = torch.zeros(size=variance_noise_shape, device=model.device)
155
+
156
+ t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
157
+ xt = x0
158
+ op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
159
+
160
+ for t in op:
161
+ idx = t_to_idx[int(t)]
162
+ # 1. predict noise residual
163
+ if not eta_is_zero:
164
+ xt = xts[idx][None]
165
+
166
+ with torch.no_grad():
167
+ out = model.unet.forward(xt, timestep = t, encoder_hidden_states = uncond_embedding)
168
+ if not prompt=="":
169
+ cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
170
+
171
+ if not prompt=="":
172
+ ## classifier free guidance
173
+ noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
174
+ else:
175
+ noise_pred = out.sample
176
+
177
+ if eta_is_zero:
178
+ # 2. compute more noisy image and set x_t -> x_t+1
179
+ xt = forward_step(model, noise_pred, t, xt)
180
+
181
+ else:
182
+ xtm1 = xts[idx+1][None]
183
+ # pred of x0
184
+ pred_original_sample = (xt - (1-alpha_bar[t]) ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5
185
+
186
+ # direction to xt
187
+ prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
188
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
189
+
190
+ variance = get_variance(model, t)
191
+ pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred
192
+
193
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
194
+
195
+ z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 )
196
+ zs[idx] = z
197
+
198
+ # correction to avoid error accumulation
199
+ xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z
200
+ xts[idx+1] = xtm1
201
+
202
+ if not zs is None:
203
+ zs[-1] = torch.zeros_like(zs[-1])
204
+
205
+ return xt, zs, xts
206
+
207
+
208
+ def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=None):
209
+ # 1. get previous step value (=t-1)
210
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
211
+ # 2. compute alphas, betas
212
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
213
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
214
+ beta_prod_t = 1 - alpha_prod_t
215
+ # 3. compute predicted original sample from predicted noise also called
216
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
217
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
218
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
219
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
220
+ # variance = self.scheduler._get_variance(timestep, prev_timestep)
221
+ variance = get_variance(model, timestep) #, prev_timestep)
222
+ std_dev_t = eta * variance ** (0.5)
223
+ # Take care of asymetric reverse process (asyrp)
224
+ model_output_direction = model_output
225
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
226
+ # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
227
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
228
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
229
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
230
+ # 8. Add noice if eta > 0
231
+ if eta > 0:
232
+ if variance_noise is None:
233
+ variance_noise = torch.randn(model_output.shape, device=model.device)
234
+ sigma_z = eta * variance ** (0.5) * variance_noise
235
+ prev_sample = prev_sample + sigma_z
236
+
237
+ return prev_sample
238
+
239
+ def inversion_reverse_process(model,
240
+ xT,
241
+ etas = 0,
242
+ prompts = "",
243
+ cfg_scales = None,
244
+ prog_bar = False,
245
+ zs = None,
246
+ controller=None,
247
+ asyrp = False):
248
+
249
+ batch_size = len(prompts)
250
+
251
+ cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
252
+
253
+ text_embeddings = encode_text(model, prompts)
254
+ uncond_embedding = encode_text(model, [""] * batch_size)
255
+
256
+ if etas is None: etas = 0
257
+ if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
258
+ assert len(etas) == model.scheduler.num_inference_steps
259
+ timesteps = model.scheduler.timesteps.to(model.device)
260
+
261
+ xt = xT.expand(batch_size, -1, -1, -1)
262
+ op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
263
+
264
+ t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
265
+
266
+ for t in op:
267
+ idx = t_to_idx[int(t)]
268
+ ## Unconditional embedding
269
+ with torch.no_grad():
270
+ uncond_out = model.unet.forward(xt, timestep = t,
271
+ encoder_hidden_states = uncond_embedding)
272
+
273
+ ## Conditional embedding
274
+ if prompts:
275
+ with torch.no_grad():
276
+ cond_out = model.unet.forward(xt, timestep = t,
277
+ encoder_hidden_states = text_embeddings)
278
+
279
+
280
+ z = zs[idx] if not zs is None else None
281
+ z = z.expand(batch_size, -1, -1, -1)
282
+ if prompts:
283
+ ## classifier free guidance
284
+ noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
285
+ else:
286
+ noise_pred = uncond_out.sample
287
+ # 2. compute less noisy image and set x_t -> x_t-1
288
+ xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
289
+ if controller is not None:
290
+ xt = controller.step_callback(xt)
291
+ return xt, zs