Linoy Tsaban
committed on
Commit
·
17db690
1
Parent(s):
3fcb5ce
Update inversion_utils.py
Browse files- inversion_utils.py +6 -22
inversion_utils.py
CHANGED
@@ -29,27 +29,11 @@ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
|
|
29 |
image = image[offset:offset + w]
|
30 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
31 |
image = torch.from_numpy(image).float() / 127.5 - 1
|
32 |
-
image = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
33 |
|
34 |
return image
|
35 |
|
36 |
|
37 |
-
def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
|
38 |
-
from PIL import Image
|
39 |
-
from glob import glob
|
40 |
-
if img_name is not None:
|
41 |
-
path = os.path.join(folder, img_name)
|
42 |
-
else:
|
43 |
-
path = glob(folder + "*")[idx]
|
44 |
-
|
45 |
-
img = Image.open(path).resize((img_size,
|
46 |
-
img_size))
|
47 |
-
|
48 |
-
img = pil_to_tensor(img).to(device)
|
49 |
-
|
50 |
-
if img.shape[1]== 4:
|
51 |
-
img = img[:,:3,:,:]
|
52 |
-
return img
|
53 |
|
54 |
def mu_tilde(model, xt,x0, timestep):
|
55 |
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
@@ -77,10 +61,10 @@ def sample_xts_from_x0(model, x0, num_inference_steps=50):
|
|
77 |
|
78 |
timesteps = model.scheduler.timesteps.to(model.device)
|
79 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
80 |
-
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
81 |
for t in reversed(timesteps):
|
82 |
idx = t_to_idx[int(t)]
|
83 |
-
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
84 |
xts = torch.cat([xts, x0 ],dim = 0)
|
85 |
|
86 |
return xts
|
@@ -151,7 +135,7 @@ def inversion_forward_process(model, x0,
|
|
151 |
if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
|
152 |
xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
|
153 |
alpha_bar = model.scheduler.alphas_cumprod
|
154 |
-
zs = torch.zeros(size=variance_noise_shape, device=model.device)
|
155 |
|
156 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
157 |
xt = x0
|
@@ -230,7 +214,7 @@ def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=
|
|
230 |
# 8. Add noise if eta > 0
|
231 |
if eta > 0:
|
232 |
if variance_noise is None:
|
233 |
-
variance_noise = torch.randn(model_output.shape, device=model.device)
|
234 |
sigma_z = eta * variance ** (0.5) * variance_noise
|
235 |
prev_sample = prev_sample + sigma_z
|
236 |
|
@@ -248,7 +232,7 @@ def inversion_reverse_process(model,
|
|
248 |
|
249 |
batch_size = len(prompts)
|
250 |
|
251 |
-
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
|
252 |
|
253 |
text_embeddings = encode_text(model, prompts)
|
254 |
uncond_embedding = encode_text(model, [""] * batch_size)
|
|
|
29 |
image = image[offset:offset + w]
|
30 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
31 |
image = torch.from_numpy(image).float() / 127.5 - 1
|
32 |
+
image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype =torch.float16)
|
33 |
|
34 |
return image
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def mu_tilde(model, xt,x0, timestep):
|
39 |
"mu_tilde(x_t, x_0) DDPM paper eq. 7"
|
|
|
61 |
|
62 |
timesteps = model.scheduler.timesteps.to(model.device)
|
63 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
64 |
+
xts = torch.zeros(variance_noise_shape).to(x0.device, dtype =torch.float16)
|
65 |
for t in reversed(timesteps):
|
66 |
idx = t_to_idx[int(t)]
|
67 |
+
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0, dtype =torch.float16) * sqrt_one_minus_alpha_bar[t]
|
68 |
xts = torch.cat([xts, x0 ],dim = 0)
|
69 |
|
70 |
return xts
|
|
|
135 |
if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
|
136 |
xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
|
137 |
alpha_bar = model.scheduler.alphas_cumprod
|
138 |
+
zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype =torch.float16)
|
139 |
|
140 |
t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
|
141 |
xt = x0
|
|
|
214 |
# 8. Add noise if eta > 0
|
215 |
if eta > 0:
|
216 |
if variance_noise is None:
|
217 |
+
variance_noise = torch.randn(model_output.shape, device=model.device, dtype =torch.float16)
|
218 |
sigma_z = eta * variance ** (0.5) * variance_noise
|
219 |
prev_sample = prev_sample + sigma_z
|
220 |
|
|
|
232 |
|
233 |
batch_size = len(prompts)
|
234 |
|
235 |
+
cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device, dtype=torch.float16)
|
236 |
|
237 |
text_embeddings = encode_text(model, prompts)
|
238 |
uncond_embedding = encode_text(model, [""] * batch_size)
|