import gc
import os
import io
import math
import sys
import tempfile
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
import numpy as np
from math import log2, sqrt
import argparse
import pickle
################################### mask_fusion ######################################
from util.metrics_accumulator import MetricsAccumulator
metrics_accumulator = MetricsAccumulator()
from pathlib import Path
################################### mask_fusion ######################################
import clip
import lpips
from torch.nn.functional import mse_loss
################################### CLIPseg ######################################
from torchvision import utils as vutils
import cv2
################################### CLIPseg ######################################
def str2bool(x):
    # `x.lower() in ('true')` would do a substring check against the string 'true'
    # (so e.g. 'r' would pass); compare against the full word instead.
    return x.lower() == 'true'
USE_CPU = False
device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')
def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')
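
# Hedged usage note (not part of the original file): fetch() returns a binary file-like
# object for both URLs and local paths, so it can be handed straight to PIL, e.g.
#   img = Image.open(fetch('https://example.com/photo.png')).convert('RGB')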
class MakeCutouts(nn.Module):
    """Extracts `cutn` random square crops of an image and resizes them to `cut_size`
    (the CLIP input resolution) so they can be scored by the CLIP image encoder."""

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)
def spherical_dist_loss(x, y):
    """Squared great-circle distance between L2-normalized embeddings."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
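
# Hedged sketch (not part of the original file): MakeCutouts and spherical_dist_loss are
# the usual CLIP-guidance helpers. If CLIP guidance were enabled, cutouts of a decoded
# image would be normalized, embedded with clip_model.encode_image, and compared against
# the normalized text embedding, roughly as below. The argument names are illustrative.
def _clip_guidance_loss_sketch(decoded_image, text_emb_norm, clip_model, make_cutouts, normalize):
    cutouts = make_cutouts(decoded_image.add(1).div(2))            # map [-1, 1] to [0, 1]
    image_embeds = clip_model.encode_image(normalize(cutouts)).float()
    dists = spherical_dist_loss(image_embeds, text_emb_norm.unsqueeze(0))
    return dists.mean()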
def do_run(
    arg_seed, arg_text, arg_batch_size, arg_num_batches, arg_negative, arg_cutn, arg_edit, arg_height, arg_width,
    arg_edit_y, arg_edit_x, arg_edit_width, arg_edit_height, mask, arg_guidance_scale, arg_background_preservation_loss,
    arg_lpips_sim_lambda, arg_l2_sim_lambda, arg_ddpm, arg_ddim, arg_enforce_background, arg_clip_guidance_scale,
    arg_clip_guidance, model_params, model, diffusion, ldm, bert, clip_model
):
    # CLIP preprocessing statistics
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    if arg_seed >= 0:
        torch.manual_seed(arg_seed)
    # Text conditioning: BERT embeddings for the diffusion model and CLIP embeddings for guidance,
    # each computed for the prompt and for the negative prompt.
    text_emb = bert.encode([arg_text] * arg_batch_size).to(device).float()
    text_blank = bert.encode([arg_negative] * arg_batch_size).to(device).float()
    text = clip.tokenize([arg_text] * arg_batch_size, truncate=True).to(device)
    text_clip_blank = clip.tokenize([arg_negative] * arg_batch_size, truncate=True).to(device)
    text_emb_clip = clip_model.encode_text(text)
    text_emb_clip_blank = clip_model.encode_text(text_clip_blank)
    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, arg_cutn)
    text_emb_norm = text_emb_clip[0] / text_emb_clip[0].norm(dim=-1, keepdim=True)
    image_embed = None
    if arg_edit:
        w = arg_edit_width if arg_edit_width else arg_width
        h = arg_edit_height if arg_edit_height else arg_height
        arg_edit = arg_edit.convert('RGB')
        input_image_pil = arg_edit
        # PIL's resize expects (width, height)
        init_image_pil = input_image_pil.resize((arg_width, arg_height), Image.Resampling.LANCZOS)
        input_image_pil = ImageOps.fit(input_image_pil, (w, h))
        im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
        init_image = TF.to_tensor(init_image_pil).to(device).unsqueeze(0).mul(2).sub(1)
        # Scale to [-1, 1] and encode into the latent autoencoder space
        im = 2 * im - 1
        im = ldm.encode(im).sample()
        y = arg_edit_y // 8
        x = arg_edit_x // 8
        input_image = torch.zeros(1, 4, arg_height // 8, arg_width // 8, device=device)
        ycrop = y + im.shape[2] - input_image.shape[2]
        xcrop = x + im.shape[3] - input_image.shape[3]
        ycrop = ycrop if ycrop > 0 else 0
        xcrop = xcrop if xcrop > 0 else 0
        # Paste the encoded image into the latent canvas, clipping anything that falls outside it
        input_image[0, :, (y if y >= 0 else 0):y + im.shape[2], (x if x >= 0 else 0):x + im.shape[3]] = (
            im[:, :, (0 if y > 0 else -y):im.shape[2] - ycrop, (0 if x > 0 else -x):im.shape[3] - xcrop]
        )
        input_image_pil = ldm.decode(input_image)
        input_image_pil = TF.to_pil_image(input_image_pil.squeeze(0).add(1).div(2).clamp(0, 1))
        input_image *= 0.18215
        # torchvision's resize expects (height, width); binarize the mask and zero out the
        # region to be edited in the conditioning latent
        new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_height // 8, arg_width // 8))
        mask1 = (new_mask > 0.5).float()
        input_image *= mask1
        image_embed = torch.cat(arg_batch_size * 2 * [input_image], dim=0).float()
    elif model_params['image_condition']:
        # using the inpaint model but no image is provided
        image_embed = torch.zeros(arg_batch_size * 2, 4, arg_height // 8, arg_width // 8, device=device)
    # Conditioning for the two halves of each batch: conditional (prompt) and unconditional (negative prompt)
    kwargs = {
        "context": torch.cat([text_emb, text_blank], dim=0).float(),
        "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float() if model_params['clip_embed_dim'] else None,
        "image_embed": image_embed
    }
    # Create a classifier-free guidance sampling function
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + arg_guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
    cur_t = None

    def postprocess_fn(out, t):
        # Blend the noised original background back into the sample at each step
        # (kept from the mask-fusion variant; not passed to the sampler below).
        if mask is not None:
            background_stage_t = diffusion.q_sample(init_image, t[0])
            background_stage_t = torch.tile(
                background_stage_t, dims=(arg_batch_size, 1, 1, 1)
            )
            out["sample"] = out["sample"] * mask + background_stage_t * (1 - mask)
        return out
    # if arg_ddpm:
    #     sample_fn = diffusion.p_sample_loop_progressive
    # elif arg_ddim:
    #     sample_fn = diffusion.ddim_sample_loop_progressive
    # else:
    sample_fn = diffusion.plms_sample_loop_progressive
    def save_sample(i, sample):
        out_ims = []
        for k, image in enumerate(sample['pred_xstart'][:arg_batch_size]):
            # Undo the latent scale factor and decode the predicted x_0 back to pixel space
            image /= 0.18215
            im = image.unsqueeze(0)
            out = ldm.decode(im)
            metrics_accumulator.print_average_metric()
            for b in range(arg_batch_size):
                pred_image = sample["pred_xstart"][b]
                if arg_enforce_background:
                    # torchvision's resize expects (height, width); keep the original image in the
                    # background (mask == 1) and the generated pixels inside the edit region
                    new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_height, arg_width))
                    pred_image = (
                        init_image[0] * new_mask[0] + out * (1 - new_mask[0])
                    )
                pred_image_pil = TF.to_pil_image(pred_image.squeeze(0).add(1).div(2).clamp(0, 1))
                out_ims.append(pred_image_pil)
        return out_ims
    all_saved_ims = []
    for i in range(arg_num_batches):
        cur_t = diffusion.num_timesteps - 1
        samples = sample_fn(
            model_fn,
            (arg_batch_size * 2, 4, int(arg_height // 8), int(arg_width // 8)),
            clip_denoised=False,
            model_kwargs=kwargs,
            cond_fn=None,
            device=device,
            progress=True,
        )
        for j, sample in enumerate(samples):
            cur_t -= 1
            if j % 5 == 0 and j != diffusion.num_timesteps - 1:
                all_saved_ims += save_sample(i, sample)
        # Always save the final denoised sample of this batch
        all_saved_ims += save_sample(i, sample)
    return all_saved_ims
def run_model(
    segmodel, model, diffusion, ldm, bert, clip_model, model_params,
    from_text, instruction, negative_prompt, original_img, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
):
    input_image = original_img
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((256, 256)),
    ])
    img = transform(input_image).unsqueeze(0)
    # Segment the region described by `from_text` with CLIPSeg
    with torch.no_grad():
        preds = segmodel(img.repeat(1, 1, 1, 1), from_text)[0]
    mask = torch.sigmoid(preds[0][0])
    image = (mask.detach().cpu().numpy() * 255).astype(np.uint8)  # cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Truncate at 100, then binarize: pixels at the cap become foreground (255), the rest background (0)
    ret, thresh = cv2.threshold(image, 100, 255, cv2.THRESH_TRUNC, image)
    timg = np.array(thresh)
    timg[timg < 100] = 0
    timg[timg == 100] = 255
    # Invert so the mask is 1.0 on the background and 0.0 on the segmented (edit) region
    fulltensor = torch.full_like(mask, fill_value=255)
    bgtensor = fulltensor - torch.from_numpy(timg).to(mask.device)
    mask = bgtensor / 255.0
    gc.collect()
    use_ddim = False
    use_ddpm = False
    all_saved_ims = do_run(
        seed, instruction, 1, 1, negative_prompt, cutn, input_image, 256, 256,
        0, 0, 0, 0, mask, guidance_scale, True,
        1000, l2_sim_lambda, use_ddpm, use_ddim, True, clip_guidance_scale, False,
        model_params, model, diffusion, ldm, bert, clip_model
    )
    # Return the final (fully denoised) image
    return all_saved_ims[-1]
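
# Hedged usage sketch (not part of the original file): run_model expects already-loaded
# components (a CLIPSeg segmentation model plus a glid-3-xl-style inpainting stack: the
# diffusion model, its latent autoencoder, BERT text encoder and CLIP model). The prompt,
# path and parameter values below are placeholders for illustration only.
def example_edit(segmodel, model, diffusion, ldm, bert, clip_model, model_params):
    source = Image.open('input.png').convert('RGB')
    edited = run_model(
        segmodel, model, diffusion, ldm, bert, clip_model, model_params,
        from_text=['hat'],                     # region for CLIPSeg to segment (placeholder prompt)
        instruction='a bright red top hat',    # text prompt for the diffusion model
        negative_prompt='',
        original_img=source,
        seed=42,
        guidance_scale=5.0,
        clip_guidance_scale=150,
        cutn=16,
        l2_sim_lambda=10000,
    )
    edited.save('edited.png')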