kai-2054 committed
Commit ba94f88 · verified · 1 Parent(s): 6eb7da6

Create run_edit.py

Files changed (1)
  1. run_edit.py +287 -0
run_edit.py ADDED
@@ -0,0 +1,287 @@
+ import gc
+ import os
+ import io
+ import math
+ import sys
+ import tempfile
+
+ from PIL import Image, ImageOps
+ import requests
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm.notebook import tqdm
+
+ import numpy as np
+
+ from math import log2, sqrt
+
+ import argparse
+ import pickle
+
+
+ ################################### mask_fusion ######################################
+ from util.metrics_accumulator import MetricsAccumulator
+ metrics_accumulator = MetricsAccumulator()
+
+ from pathlib import Path
+ from PIL import Image
+ ################################### mask_fusion ######################################
+
+ import clip
+ import lpips
+ from torch.nn.functional import mse_loss
+
+ ################################### CLIPseg ######################################
+ from torchvision import utils as vutils
+ import cv2
+ ################################### CLIPseg ######################################
+
+
+ def str2bool(x):
+     # True only for the literal string 'true' (case-insensitive).
+     return x.lower() in ('true',)
+
+
+ USE_CPU = False
+ device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')
+
+
+
+ def fetch(url_or_path):
+     # Return a readable file object for either a URL or a local path.
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+
+
+ class MakeCutouts(nn.Module):
+     # Random square crops resized to cut_size, for CLIP guidance.
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+
+
+ def spherical_dist_loss(x, y):
+     # Squared great-circle distance between normalized embeddings.
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+ def do_run(
+     arg_seed, arg_text, arg_batch_size, arg_num_batches, arg_negative, arg_cutn, arg_edit, arg_height, arg_width,
+     arg_edit_y, arg_edit_x, arg_edit_width, arg_edit_height, mask, arg_guidance_scale, arg_background_preservation_loss,
+     arg_lpips_sim_lambda, arg_l2_sim_lambda, arg_ddpm, arg_ddim, arg_enforce_background, arg_clip_guidance_scale,
+     arg_clip_guidance, model_params, model, diffusion, ldm, bert, clip_model
+ ):
+     normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
+
+     if arg_seed >= 0:
+         torch.manual_seed(arg_seed)
+
+     # BERT text embeddings for the prompt and the negative prompt (the unconditional branch).
+     text_emb = bert.encode([arg_text] * arg_batch_size).to(device).float()
+     text_blank = bert.encode([arg_negative] * arg_batch_size).to(device).float()
+
+     # CLIP text embeddings for the same prompt pair.
+     text = clip.tokenize([arg_text] * arg_batch_size, truncate=True).to(device)
+     text_clip_blank = clip.tokenize([arg_negative] * arg_batch_size, truncate=True).to(device)
+
+     text_emb_clip = clip_model.encode_text(text)
+     text_emb_clip_blank = clip_model.encode_text(text_clip_blank)
+     make_cutouts = MakeCutouts(clip_model.visual.input_resolution, arg_cutn)
+     text_emb_norm = text_emb_clip[0] / text_emb_clip[0].norm(dim=-1, keepdim=True)
+     image_embed = None
+
+     if arg_edit:
+         w = arg_edit_width if arg_edit_width else arg_width
+         h = arg_edit_height if arg_edit_height else arg_height
+
+         arg_edit = arg_edit.convert('RGB')
+         input_image_pil = arg_edit
+
+         init_image_pil = input_image_pil.resize((arg_height, arg_width), Image.Resampling.LANCZOS)
+
+         input_image_pil = ImageOps.fit(input_image_pil, (w, h))
+
+         im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
+
+         init_image = (TF.to_tensor(init_image_pil).to(device).unsqueeze(0).mul(2).sub(1))
+
+         # Scale to [-1, 1] and encode the edit image into the LDM latent space.
+         im = 2*im-1
+         im = ldm.encode(im).sample()
+
+         # Latent-space coordinates of the edit region (one latent cell per 8 pixels).
+         y = arg_edit_y//8
+         x = arg_edit_x//8
+
+         input_image = torch.zeros(1, 4, arg_height//8, arg_width//8, device=device)
+
+         ycrop = y + im.shape[2] - input_image.shape[2]
+         xcrop = x + im.shape[3] - input_image.shape[3]
+
+         ycrop = ycrop if ycrop > 0 else 0
+         xcrop = xcrop if xcrop > 0 else 0
+
+         # Paste the encoded edit latent into the canvas latent, cropping at the borders.
+         input_image[0,:,y if y >=0 else 0:y+im.shape[2],x if x >=0 else 0:x+im.shape[3]] = im[:,:,0 if y > 0 else -y:im.shape[2]-ycrop,0 if x > 0 else -x:im.shape[3]-xcrop]
+
+         input_image_pil = ldm.decode(input_image)
+         input_image_pil = TF.to_pil_image(input_image_pil.squeeze(0).add(1).div(2).clamp(0, 1))
+
+         input_image *= 0.18215
+
+         # Downsample the segmentation mask to latent resolution and binarize it.
+         new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width//8, arg_height//8))
+
+         mask1 = (new_mask > 0.5)
+         mask1 = mask1.float()
+
+         # Zero the conditioning latent where the binarized mask is 0 (the region to be edited).
+         input_image *= mask1
+
+         image_embed = torch.cat(arg_batch_size*2*[input_image], dim=0).float()
+     elif model_params['image_condition']:
+         # using inpaint model but no image is provided
+         image_embed = torch.zeros(arg_batch_size*2, 4, arg_height//8, arg_width//8, device=device)
+
+     kwargs = {
+         "context": torch.cat([text_emb, text_blank], dim=0).float(),
+         "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float() if model_params['clip_embed_dim'] else None,
+         "image_embed": image_embed
+     }
+
+     # Create a classifier-free guidance sampling function
+     def model_fn(x_t, ts, **kwargs):
+         half = x_t[: len(x_t) // 2]
+         combined = torch.cat([half, half], dim=0)
+         model_out = model(combined, ts, **kwargs)
+         eps, rest = model_out[:, :3], model_out[:, 3:]
+         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+         half_eps = uncond_eps + arg_guidance_scale * (cond_eps - uncond_eps)
+         eps = torch.cat([half_eps, half_eps], dim=0)
+         return torch.cat([eps, rest], dim=1)
+
+     cur_t = None
+
+     # Note: postprocess_fn is defined here but not passed to the sampler below;
+     # background preservation is instead applied inside save_sample().
+     @torch.no_grad()
+     def postprocess_fn(out, t):
+         if mask is not None:
+             background_stage_t = diffusion.q_sample(init_image, t[0])
+             background_stage_t = torch.tile(
+                 background_stage_t, dims=(arg_batch_size, 1, 1, 1)
+             )
+             out["sample"] = out["sample"] * mask + background_stage_t * (1 - mask)
+         return out
+
+     # if arg_ddpm:
+     #     sample_fn = diffusion.p_sample_loop_progressive
+     # elif arg_ddim:
+     #     sample_fn = diffusion.ddim_sample_loop_progressive
+     # else:
+     sample_fn = diffusion.plms_sample_loop_progressive
+
+     def save_sample(i, sample):
+         out_ims = []
+         for k, image in enumerate(sample['pred_xstart'][:arg_batch_size]):
+             # Rescale the predicted latent and decode it back to image space.
+             image /= 0.18215
+             im = image.unsqueeze(0)
+             out = ldm.decode(im)
+             metrics_accumulator.print_average_metric()
+
+             for b in range(arg_batch_size):
+                 pred_image = sample["pred_xstart"][b]
+
+                 if arg_enforce_background:
+                     # Composite: original image where new_mask is 1 (background),
+                     # decoded sample where it is 0 (the edited region).
+                     new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width, arg_height))
+                     pred_image = (
+                         init_image[0] * new_mask[0] + out * (1 - new_mask[0])
+                     )
+
+                 pred_image_pil = TF.to_pil_image(pred_image.squeeze(0).add(1).div(2).clamp(0, 1))
+                 out_ims.append(pred_image_pil)
+         return out_ims
+
+
+     all_saved_ims = []
+     for i in range(arg_num_batches):
+         cur_t = diffusion.num_timesteps - 1
+
+         samples = sample_fn(
+             model_fn,
+             (arg_batch_size*2, 4, int(arg_height//8), int(arg_width//8)),
+             clip_denoised=False,
+             model_kwargs=kwargs,
+             cond_fn=None,
+             device=device,
+             progress=True,
+         )
+
+         # Save an intermediate result every 5 steps, then the final sample after the loop.
+         for j, sample in enumerate(samples):
+             cur_t -= 1
+             if j % 5 == 0 and j != diffusion.num_timesteps - 1:
+                 all_saved_ims += save_sample(i, sample)
+         all_saved_ims += save_sample(i, sample)
+
+     return all_saved_ims
+
+
+ def run_model(
+     segmodel, model, diffusion, ldm, bert, clip_model, model_params,
+     from_text, instruction, negative_prompt, original_img, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+ ):
+     input_image = original_img
+
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         transforms.Resize((256, 256)),
+     ])
+     img = transform(input_image).unsqueeze(0)
+
+     # CLIPseg: predict a segmentation map for the region described by from_text.
+     with torch.no_grad():
+         preds = segmodel(img.repeat(1,1,1,1), from_text)[0]
+
+     mask = torch.sigmoid(preds[0][0])
+     image = (mask.detach().cpu().numpy() * 255).astype(np.uint8)  # cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     # Truncate at 100, then binarize: pixels at or above the threshold become 255, the rest 0.
+     ret, thresh = cv2.threshold(image, 100, 255, cv2.THRESH_TRUNC, image)
+     timg = np.array(thresh)
+     x, y = timg.shape
+     for row in range(x):
+         for col in range(y):
+             if (timg[row][col]) == 100:
+                 timg[row][col] = 255
+             if (timg[row][col]) < 100:
+                 timg[row][col] = 0
+
+     # Invert: the mask passed to do_run is 1 on the background and 0 on the segmented region.
+     fulltensor = torch.full_like(mask, fill_value=255)
+     bgtensor = fulltensor - torch.from_numpy(timg).to(fulltensor.dtype)
+     mask = bgtensor / 255.0
+
+     gc.collect()
+     use_ddim = False
+     use_ddpm = False
+     all_saved_ims = do_run(
+         seed, instruction, 1, 1, negative_prompt, cutn, input_image, 256, 256,
+         0, 0, 0, 0, mask, guidance_scale, True,
+         1000, l2_sim_lambda, use_ddpm, use_ddim, True, clip_guidance_scale, False,
+         model_params, model, diffusion, ldm, bert, clip_model
+     )
+
+     return all_saved_ims[-1]
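For reference, a minimal sketch of how run_model() could be driven from the hosting app. The load_models() helper and the prompt/parameter values below are placeholders, not part of run_edit.py; the app is expected to supply its own CLIPseg model, diffusion model, LDM autoencoder, BERT encoder and CLIP model.

# Hypothetical driver; load_models() stands in for however the app builds
# segmodel (CLIPseg), model/diffusion (the latent diffusion sampler), ldm (the
# autoencoder), bert and clip_model, plus the matching model_params dict.
from PIL import Image

from run_edit import run_model

segmodel, model, diffusion, ldm, bert, clip_model, model_params = load_models()

edited = run_model(
    segmodel, model, diffusion, ldm, bert, clip_model, model_params,
    from_text="a dog",                       # region CLIPseg should segment
    instruction="a dog wearing sunglasses",  # edit prompt
    negative_prompt="",
    original_img=Image.open("input.jpg").convert("RGB"),
    seed=0,
    guidance_scale=5.0,          # illustrative values only
    clip_guidance_scale=150,
    cutn=16,
    l2_sim_lambda=10000,
)
edited.save("edited.png")        # run_model returns the last saved PIL image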