"""Evaluation driver: sample images from a checkpoint and save real/fake PNG pairs."""
import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
import random
from dataset.concat_dataset import ConCatDataset  # , collate_fn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import os
from tqdm import tqdm
from distributed import get_rank, synchronize, get_world_size
from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap_loader  # , get_padded_boxes
from PIL import Image
import math
import json


def draw_masks_from_boxes(boxes, size):
    """Build per-sample masks that are 0 inside every box and 1 elsewhere.

    Args:
        boxes: batch of box lists; each box is read as (x0, y0, x1, y1) and
            multiplied by the mask width/height, so coordinates are presumably
            normalized to [0, 1] -- TODO confirm against the dataset.
        size: mask side length as a scalar (square mask), or a
            (height, width) pair.

    Returns:
        Float tensor of shape (batch, 1, height, width).
    """
    # The evaluator passes ``self.model.image_size`` (a scalar); indexing it
    # as ``size[0]`` / ``size[1]`` used to raise TypeError.
    try:
        height, width = size
    except TypeError:
        height = width = size

    image_masks = []
    for box in boxes:  # one entry per sample in the batch
        image_mask = torch.ones(height, width)
        for bx in box:
            # x indexes columns (width) and y indexes rows (height), matching
            # the (rows, cols) layout of ``image_mask``.
            x0, x1 = bx[0] * width, bx[2] * width
            y0, y1 = bx[1] * height, bx[3] * height
            image_mask[int(y0):int(y1), int(x0):int(x1)] = 0
        image_masks.append(image_mask)
    return torch.stack(image_masks).unsqueeze(1)


def set_alpha_scale(model, alpha_scale):
    """Assign ``alpha_scale`` to ``scale`` on every gated attention module in ``model``.

    Args:
        model: module tree walked with ``model.modules()``.
        alpha_scale: value written to each gated module's ``scale`` attribute.
    """
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense

    gated_types = (GatedCrossAttentionDense, GatedSelfAttentionDense)
    for module in model.modules():
        if isinstance(module, gated_types):
            module.scale = alpha_scale


def save_images(samples, image_ids, folder, to256):
    """Write each sample to ``folder`` as ``<int(image_id)>.png``.

    Args:
        samples: iterable of image tensors laid out channels-first (C, H, W)
            with values expected in [-1, 1]; out-of-range values are clamped.
        image_ids: ids paired with ``samples``; each is passed through
            ``int()`` to build the file name.
        folder: output directory (must already exist).
        to256: if True, resize to 256x256 with bicubic filtering before saving.
    """
    for sample, image_id in zip(samples, image_ids):
        # [-1, 1] -> [0, 1] -> [0, 255], channels moved last for PIL.
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
        img_name = str(int(image_id)) + '.png'
        img = Image.fromarray(sample.astype(np.uint8))
        if to256:
            img = img.resize((256, 256), Image.BICUBIC)
        img.save(os.path.join(folder, img_name))


def ckpt_to_folder_name(basename):
    """Turn a checkpoint file name into a short folder name like ``"1.50k"``.

    All decimal digits in ``basename`` are concatenated and read as an
    integer N; the result is ``round(N / 1000, 1)`` right-padded with '0' to
    at least 4 characters, followed by 'k'. A name with no digits (e.g. the
    ``"real"`` mode used by ``Evaluator``) is returned unchanged instead of
    failing on ``int("")``.
    """
    digits = "".join(ch for ch in basename if ch.isdecimal())
    if not digits:
        return basename
    seen = round(int(digits) / 1000, 1)
    return str(seen).ljust(4, '0') + 'k'


class Evaluator:
    """Generate images for an evaluation dataset and save them next to the real ones.

    Config attributes read here include: ``ckpt`` (path, or ``"real"`` to only
    dump real images), ``model``/``autoencoder``/``text_encoder``/``diffusion``
    (instantiation configs), ``DATA_ROOT``, ``official_ckpt_name``,
    ``which_state``, ``alpha_scale``, ``val_dataset_names``, ``which_embedder``,
    ``distributed``, ``batch_size``, ``workers``, ``OUTPUT_ROOT``, ``to256``,
    ``inpaint``, ``no_plms`` and ``guidance_scale``.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda")

        # = = = = = create model / diffusion and load our ckpt = = = = = #
        if self.config.ckpt == "real":
            print("Saving all real images...")
            self.just_save_real = True
        else:
            self.model = instantiate_from_config(config.model).to(self.device)
            self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
            self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
            self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

            # self.model is not loaded from the official ckpt: its weights
            # come from our own ckpt just below.
            state_dict = read_official_ckpt(os.path.join(config.DATA_ROOT, config.official_ckpt_name))
            self.autoencoder.load_state_dict(state_dict["autoencoder"])
            self.text_encoder.load_state_dict(state_dict["text_encoder"])
            self.diffusion.load_state_dict(state_dict["diffusion"])

            # torch.load unpickles the file: only point config.ckpt at trusted checkpoints.
            checkpoint = torch.load(self.config.ckpt, map_location="cpu")
            # Prefer EMA weights when present unless the config names a key explicitly.
            which_state = 'ema' if 'ema' in checkpoint else "model"
            if config.which_state is not None:
                which_state = config.which_state
            self.model.load_state_dict(checkpoint[which_state])
            print("ckpt is loaded")
            self.just_save_real = False
            set_alpha_scale(self.model, self.config.alpha_scale)

            self.autoencoder.eval()
            self.model.eval()
            self.text_encoder.eval()

        # = = = = = create data = = = = = #
        self.dataset_eval = ConCatDataset(config.val_dataset_names, config.DATA_ROOT, config.which_embedder, train=False)
        print("total eval images: ", len(self.dataset_eval))
        # NOTE(review): DistributedSampler may pad so every rank gets the same
        # count; duplicated samples would just overwrite the same <id>.png.
        sampler = DistributedSampler(self.dataset_eval, shuffle=False) if config.distributed else None
        self.loader_eval = DataLoader(
            self.dataset_eval,
            batch_size=config.batch_size,
            num_workers=config.workers,
            pin_memory=True,
            sampler=sampler,
            drop_last=False,
        )  # shuffle default is False

        # = = = = = create output folder = = = = = #
        folder_name = ckpt_to_folder_name(os.path.basename(config.ckpt))
        self.outdir = os.path.join(config.OUTPUT_ROOT, folder_name)
        self.outdir_real = os.path.join(self.outdir, 'real')
        self.outdir_fake = os.path.join(self.outdir, 'fake')
        if config.to256:
            self.outdir_real256 = os.path.join(self.outdir, 'real256')
            self.outdir_fake256 = os.path.join(self.outdir, 'fake256')

        synchronize()  # if rank0 is faster, it may mkdir before the other rank call os.listdir()
        if get_rank() == 0:
            os.makedirs(self.outdir, exist_ok=True)
            os.makedirs(self.outdir_real, exist_ok=True)
            os.makedirs(self.outdir_fake, exist_ok=True)
            if config.to256:
                os.makedirs(self.outdir_real256, exist_ok=True)
                os.makedirs(self.outdir_fake256, exist_ok=True)
        # Only rank 0 creates the folders; wait for it so no rank starts
        # writing images into directories that do not exist yet.
        synchronize()
        print(self.outdir)  # double check

        # A score.txt in the output folder marks this checkpoint as already evaluated.
        self.evaluation_finished = os.path.exists(os.path.join(self.outdir, 'score.txt'))

    def alread_saved_this_batch(self, batch):
        """Return True if every id in ``batch["id"]`` has both a real and a fake PNG on disk.

        The (misspelled) name is kept for backward compatibility.
        """
        existing_real_files = set(os.listdir(self.outdir_real))
        existing_fake_files = set(os.listdir(self.outdir_fake))
        return all(
            name in existing_real_files and name in existing_fake_files
            for name in (str(int(image_id)) + '.png' for image_id in batch["id"])
        )

    @torch.no_grad()
    def start_evaluating(self):
        """Iterate the eval loader, sample fake images (unless in "real" mode) and save both sets."""
        iterator = tqdm(self.loader_eval, desc='Evaluating progress')
        for batch in iterator:
            # Resume support (skip batches already on disk) is disabled; to
            # enable it: `if self.alread_saved_this_batch(batch): continue`.
            batch_to_device(batch, self.device)
            batch_size = batch["image"].shape[0]
            samples_real = batch["image"]

            if self.just_save_real:
                samples_fake = None
            else:
                # Empty-caption embedding is the unconditional branch for guidance.
                uc = self.text_encoder.encode(batch_size * [""])
                context = self.text_encoder.encode(batch["caption"])

                image_mask = x0 = None
                if self.config.inpaint:
                    image_mask = draw_masks_from_boxes(batch['boxes'], self.model.image_size).to(self.device)
                    x0 = self.autoencoder.encode(batch["image"])

                shape = (batch_size, self.model.in_channels, self.model.image_size, self.model.image_size)
                if self.config.no_plms:
                    sampler = DDIMSampler(self.diffusion, self.model)
                    steps = 250
                else:
                    sampler = PLMSSampler(self.diffusion, self.model)
                    steps = 50

                model_input = dict(
                    x=None,
                    timesteps=None,
                    context=context,
                    boxes=batch['boxes'],
                    masks=batch['masks'],
                    positive_embeddings=batch["positive_embeddings"],
                )
                samples_fake = sampler.sample(
                    S=steps,
                    shape=shape,
                    input=model_input,
                    uc=uc,
                    guidance_scale=self.config.guidance_scale,
                    mask=image_mask,
                    x0=x0,
                )
                samples_fake = self.autoencoder.decode(samples_fake)

            save_images(samples_real, batch['id'], self.outdir_real, to256=False)
            if self.config.to256:
                save_images(samples_real, batch['id'], self.outdir_real256, to256=True)

            if samples_fake is not None:
                save_images(samples_fake, batch['id'], self.outdir_fake, to256=False)
                if self.config.to256:
                    save_images(samples_fake, batch['id'], self.outdir_fake256, to256=True)

    def fire_fid(self):
        """Return the ``[real_dir, fake_dir]`` pair to compare.

        The 256x256 copies are used when ``config.to256`` is set.
        NOTE(review): this method only selects the directories; no score is
        computed here.
        """
        paths = [self.outdir_real, self.outdir_fake]
        if self.config.to256:
            paths = [self.outdir_real256, self.outdir_fake256]
        return paths