import gc
import os
import io
import math
import sys
import tempfile
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
import numpy as np
from math import log2, sqrt
import argparse
import pickle

################################### mask_fusion ######################################
from util.metrics_accumulator import MetricsAccumulator

metrics_accumulator = MetricsAccumulator()

from pathlib import Path
from PIL import Image
################################### mask_fusion ######################################

import clip
import lpips
from torch.nn.functional import mse_loss

################################### CLIPseg ######################################
from torchvision import utils as vutils
import cv2
################################### CLIPseg ######################################


def str2bool(x):
    return x.lower() in ('true',)


USE_CPU = False
device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')


def fetch(url_or_path):
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def do_run(
    arg_seed, arg_text, arg_batch_size, arg_num_batches, arg_negative, arg_cutn,
    arg_edit, arg_height, arg_width, arg_edit_y, arg_edit_x, arg_edit_width,
    arg_edit_height, mask, arg_guidance_scale, arg_background_preservation_loss,
    arg_lpips_sim_lambda, arg_l2_sim_lambda, arg_ddpm, arg_ddim,
    arg_enforce_background, arg_clip_guidance_scale, arg_clip_guidance,
    model_params, model, diffusion, ldm, bert, clip_model
):
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    if arg_seed >= 0:
        torch.manual_seed(arg_seed)

    text_emb = bert.encode([arg_text] * arg_batch_size).to(device).float()
    text_blank = bert.encode([arg_negative] * arg_batch_size).to(device).float()

    text = clip.tokenize([arg_text] * arg_batch_size, truncate=True).to(device)
    text_clip_blank = clip.tokenize([arg_negative] * arg_batch_size, truncate=True).to(device)

    text_emb_clip = clip_model.encode_text(text)
    text_emb_clip_blank = clip_model.encode_text(text_clip_blank)

    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, arg_cutn)
    text_emb_norm = text_emb_clip[0] / text_emb_clip[0].norm(dim=-1, keepdim=True)

    image_embed = None

    if arg_edit:
        w = arg_edit_width if arg_edit_width else arg_width
        h = arg_edit_height if arg_edit_height else arg_height

        arg_edit = arg_edit.convert('RGB')
        input_image_pil = arg_edit
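        # The block below resizes the source image to the output resolution, encodes it
        # with the LDM first-stage autoencoder, pastes the latent into a blank latent
        # canvas at the requested edit offset (arg_edit_x/arg_edit_y), zeroes out the
        # region selected by `mask`, and stacks the result as the image conditioning
        # (`image_embed`) for the inpainting model.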
        init_image_pil = input_image_pil.resize((arg_width, arg_height), Image.Resampling.LANCZOS)
        input_image_pil = ImageOps.fit(input_image_pil, (w, h))

        im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
        init_image = (TF.to_tensor(init_image_pil).to(device).unsqueeze(0).mul(2).sub(1))
        im = 2*im - 1
        im = ldm.encode(im).sample()

        y = arg_edit_y//8
        x = arg_edit_x//8

        input_image = torch.zeros(1, 4, arg_height//8, arg_width//8, device=device)

        ycrop = y + im.shape[2] - input_image.shape[2]
        xcrop = x + im.shape[3] - input_image.shape[3]

        ycrop = ycrop if ycrop > 0 else 0
        xcrop = xcrop if xcrop > 0 else 0

        input_image[0, :, y if y >= 0 else 0:y + im.shape[2], x if x >= 0 else 0:x + im.shape[3]] = im[:, :, 0 if y > 0 else -y:im.shape[2] - ycrop, 0 if x > 0 else -x:im.shape[3] - xcrop]

        input_image_pil = ldm.decode(input_image)
        input_image_pil = TF.to_pil_image(input_image_pil.squeeze(0).add(1).div(2).clamp(0, 1))

        input_image *= 0.18215

        new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_height//8, arg_width//8))
        mask1 = (new_mask > 0.5)
        mask1 = mask1.float()
        input_image *= mask1

        image_embed = torch.cat(arg_batch_size*2*[input_image], dim=0).float()
    elif model_params['image_condition']:
        # using inpaint model but no image is provided
        image_embed = torch.zeros(arg_batch_size*2, 4, arg_height//8, arg_width//8, device=device)

    kwargs = {
        "context": torch.cat([text_emb, text_blank], dim=0).float(),
        "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float()
        if model_params['clip_embed_dim'] else None,
        "image_embed": image_embed
    }

    # Create a classifier-free guidance sampling function
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + arg_guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    cur_t = None

    @torch.no_grad()
    def postprocess_fn(out, t):
        if mask is not None:
            background_stage_t = diffusion.q_sample(init_image, t[0])
            background_stage_t = torch.tile(
                background_stage_t, dims=(arg_batch_size, 1, 1, 1)
            )
            out["sample"] = out["sample"] * mask + background_stage_t * (1 - mask)
        return out

    # if arg_ddpm:
    #     sample_fn = diffusion.p_sample_loop_progressive
    # elif arg_ddim:
    #     sample_fn = diffusion.ddim_sample_loop_progressive
    # else:
    sample_fn = diffusion.plms_sample_loop_progressive

    def save_sample(i, sample):
        out_ims = []
        for k, image in enumerate(sample['pred_xstart'][:arg_batch_size]):
            image /= 0.18215
            im = image.unsqueeze(0)
            out = ldm.decode(im)

        metrics_accumulator.print_average_metric()

        for b in range(arg_batch_size):
            pred_image = sample["pred_xstart"][b]

            if arg_enforce_background:
                new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_height, arg_width))
                pred_image = (
                    init_image[0] * new_mask[0]
                    + out * (1 - new_mask[0])
                )

            pred_image_pil = TF.to_pil_image(pred_image.squeeze(0).add(1).div(2).clamp(0, 1))
            out_ims.append(pred_image_pil)
        return out_ims

    all_saved_ims = []
    for i in range(arg_num_batches):
        cur_t = diffusion.num_timesteps - 1

        samples = sample_fn(
            model_fn,
            (arg_batch_size*2, 4, int(arg_height//8), int(arg_width//8)),
            clip_denoised=False,
            model_kwargs=kwargs,
            cond_fn=None,
            device=device,
            progress=True,
        )

        for j, sample in enumerate(samples):
            cur_t -= 1
            if j % 5 == 0 and j != diffusion.num_timesteps - 1:
                all_saved_ims += save_sample(i, sample)
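
        # Keep the fully denoised prediction from the final timestep of this batch as well.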
        all_saved_ims += save_sample(i, sample)

    return all_saved_ims


def run_model(
    segmodel, model, diffusion, ldm, bert, clip_model, model_params,
    from_text, instruction, negative_prompt, original_img, seed,
    guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
):
    input_image = original_img

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((256, 256)),
    ])
    img = transform(input_image).unsqueeze(0)

    with torch.no_grad():
        preds = segmodel(img.repeat(1, 1, 1, 1), from_text)[0]

    mask = torch.sigmoid(preds[0][0])
    image = (mask.detach().cpu().numpy() * 255).astype(np.uint8)
    # cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    ret, thresh = cv2.threshold(image, 100, 255, cv2.THRESH_TRUNC, image)
    timg = np.array(thresh)

    x, y = timg.shape
    for row in range(x):
        for col in range(y):
            if (timg[row][col]) == 100:
                timg[row][col] = 255
            if (timg[row][col]) < 100:
                timg[row][col] = 0

    fulltensor = torch.full_like(mask, fill_value=255)
    bgtensor = fulltensor - timg
    mask = bgtensor / 255.0

    gc.collect()

    use_ddim = False
    use_ddpm = False

    all_saved_ims = do_run(
        seed, instruction, 1, 1, negative_prompt, cutn, input_image, 256, 256,
        0, 0, 0, 0, mask, guidance_scale, True, 1000, l2_sim_lambda,
        use_ddpm, use_ddim, True, clip_guidance_scale, False,
        model_params, model, diffusion, ldm, bert, clip_model
    )
    return all_saved_ims[-1]
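

# Example usage (a minimal sketch): `segmodel` (CLIPSeg), `model`/`diffusion` (the
# latent diffusion inpainting model), `ldm` (the first-stage autoencoder), `bert`,
# `clip_model` and `model_params` are assumed to be loaded elsewhere by the project's
# own checkpoint-loading code, which is not shown here; file names and the numeric
# settings below are illustrative, not prescribed values.
#
#   from PIL import Image
#
#   original = Image.open('input.png').convert('RGB')   # resized to 256x256 internally
#   edited = run_model(
#       segmodel, model, diffusion, ldm, bert, clip_model, model_params,
#       from_text='a dog',                    # CLIPSeg query that selects the edit region
#       instruction='a corgi wearing a party hat',
#       negative_prompt='',
#       original_img=original,
#       seed=42,
#       guidance_scale=5.0,
#       clip_guidance_scale=150,
#       cutn=16,
#       l2_sim_lambda=10000,
#   )
#   edited.save('edited.png')                 # run_model returns the last saved PIL image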