Linoy Tsaban commited on
Commit
17db690
·
1 Parent(s): 3fcb5ce

Update inversion_utils.py

Browse files
Files changed (1) hide show
  1. inversion_utils.py +6 -22
inversion_utils.py CHANGED
@@ -29,27 +29,11 @@ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
29
  image = image[offset:offset + w]
30
  image = np.array(Image.fromarray(image).resize((512, 512)))
31
  image = torch.from_numpy(image).float() / 127.5 - 1
32
- image = image.permute(2, 0, 1).unsqueeze(0).to(device)
33
 
34
  return image
35
 
36
 
37
- def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
38
- from PIL import Image
39
- from glob import glob
40
- if img_name is not None:
41
- path = os.path.join(folder, img_name)
42
- else:
43
- path = glob(folder + "*")[idx]
44
-
45
- img = Image.open(path).resize((img_size,
46
- img_size))
47
-
48
- img = pil_to_tensor(img).to(device)
49
-
50
- if img.shape[1]== 4:
51
- img = img[:,:3,:,:]
52
- return img
53
 
54
  def mu_tilde(model, xt,x0, timestep):
55
  "mu_tilde(x_t, x_0) DDPM paper eq. 7"
@@ -77,10 +61,10 @@ def sample_xts_from_x0(model, x0, num_inference_steps=50):
77
 
78
  timesteps = model.scheduler.timesteps.to(model.device)
79
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
80
- xts = torch.zeros(variance_noise_shape).to(x0.device)
81
  for t in reversed(timesteps):
82
  idx = t_to_idx[int(t)]
83
- xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
84
  xts = torch.cat([xts, x0 ],dim = 0)
85
 
86
  return xts
@@ -151,7 +135,7 @@ def inversion_forward_process(model, x0,
151
  if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
152
  xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
153
  alpha_bar = model.scheduler.alphas_cumprod
154
- zs = torch.zeros(size=variance_noise_shape, device=model.device)
155
 
156
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
157
  xt = x0
@@ -230,7 +214,7 @@ def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=
230
  # 8. Add noice if eta > 0
231
  if eta > 0:
232
  if variance_noise is None:
233
- variance_noise = torch.randn(model_output.shape, device=model.device)
234
  sigma_z = eta * variance ** (0.5) * variance_noise
235
  prev_sample = prev_sample + sigma_z
236
 
@@ -248,7 +232,7 @@ def inversion_reverse_process(model,
248
 
249
  batch_size = len(prompts)
250
 
251
- cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
252
 
253
  text_embeddings = encode_text(model, prompts)
254
  uncond_embedding = encode_text(model, [""] * batch_size)
 
29
  image = image[offset:offset + w]
30
  image = np.array(Image.fromarray(image).resize((512, 512)))
31
  image = torch.from_numpy(image).float() / 127.5 - 1
32
+ image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype =torch.float16)
33
 
34
  return image
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def mu_tilde(model, xt,x0, timestep):
39
  "mu_tilde(x_t, x_0) DDPM paper eq. 7"
 
61
 
62
  timesteps = model.scheduler.timesteps.to(model.device)
63
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
64
+ xts = torch.zeros(variance_noise_shape).to(x0.device, dtype =torch.float16)
65
  for t in reversed(timesteps):
66
  idx = t_to_idx[int(t)]
67
+ xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0, dtype =torch.float16) * sqrt_one_minus_alpha_bar[t]
68
  xts = torch.cat([xts, x0 ],dim = 0)
69
 
70
  return xts
 
135
  if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
136
  xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
137
  alpha_bar = model.scheduler.alphas_cumprod
138
+ zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype =torch.float16)
139
 
140
  t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
141
  xt = x0
 
214
  # 8. Add noice if eta > 0
215
  if eta > 0:
216
  if variance_noise is None:
217
+ variance_noise = torch.randn(model_output.shape, device=model.device, dtype =torch.float16)
218
  sigma_z = eta * variance ** (0.5) * variance_noise
219
  prev_sample = prev_sample + sigma_z
220
 
 
232
 
233
  batch_size = len(prompts)
234
 
235
+ cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device, dtype=torch.float16)
236
 
237
  text_embeddings = encode_text(model, prompts)
238
  uncond_embedding = encode_text(model, [""] * batch_size)