from __future__ import annotations

import json
import sys
from argparse import ArgumentParser
from pathlib import Path

import einops
import k_diffusion as K
import matplotlib.pyplot as plt
import seaborn
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch import autocast
from tqdm.auto import tqdm

sys.path.append("./")
from clip_similarity import ClipSimilarity
from edit_dataset import EditDatasetEval

sys.path.append("./stable_diffusion")
from ldm.util import instantiate_from_config


class CFGDenoiser(nn.Module):
    """Wraps a denoiser to apply classifier-free guidance over both the text
    and the image conditioning, with independent guidance scales."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Batch three denoiser passes into a single forward call:
        # fully conditioned, image-only conditioned, and unconditioned.
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        # Combine the three estimates using separate text and image guidance scales.
        return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        # Optionally swap in weights from a separately fine-tuned VAE.
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model


class ImageEditor(nn.Module):
    def __init__(self, config, ckpt, vae_ckpt=None):
        super().__init__()
        config = OmegaConf.load(config)
        self.model = load_model_from_config(config, ckpt, vae_ckpt)
        self.model.eval().cuda()
        self.model_wrap = K.external.CompVisDenoiser(self.model)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.null_token = self.model.get_learned_conditioning([""])

    def forward(
        self,
        image: torch.Tensor,
        edit: str,
        scale_txt: float = 7.5,
        scale_img: float = 1.0,
        steps: int = 100,
    ) -> torch.Tensor:
        assert image.dim() == 3
        assert image.size(1) % 64 == 0
        assert image.size(2) % 64 == 0
        with torch.no_grad(), autocast("cuda"), self.model.ema_scope():
            cond = {
                "c_crossattn": [self.model.get_learned_conditioning([edit])],
                "c_concat": [self.model.encode_first_stage(image[None]).mode()],
            }
            uncond = {
                "c_crossattn": [self.model.get_learned_conditioning([""])],
                "c_concat": [torch.zeros_like(cond["c_concat"][0])],
            }
            extra_args = {
                "uncond": uncond,
                "cond": cond,
                "image_cfg_scale": scale_img,
                "text_cfg_scale": scale_txt,
            }
            # Sample from pure noise at the largest sigma down to the image.
            sigmas = self.model_wrap.get_sigmas(steps)
            x = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args)
            x = self.model.decode_first_stage(x)[0]
        return x
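# A minimal usage sketch showing how ImageEditor might be invoked directly on a
# single image, outside the evaluation loop below. The file names and edit
# instruction are illustrative; the [-1, 1] pixel range and (C, H, W) layout
# are assumptions inferred from the asserts in ImageEditor.forward and the
# conventions used elsewhere in this repo. Not called anywhere in this script.
def _example_single_edit(config="configs/generate.yaml",
                         ckpt="checkpoints/instruct-pix2pix-00-22000.ckpt"):
    import numpy as np
    from PIL import Image

    editor = ImageEditor(config, ckpt).cuda()
    # Resize to a multiple of 64, as required by forward()'s asserts.
    pil = Image.open("example.jpg").convert("RGB").resize((512, 512))
    # Scale pixels to [-1, 1] and move channels first: (H, W, C) -> (C, H, W).
    image = (2 * torch.tensor(np.array(pil)).float() / 255 - 1).permute(2, 0, 1).cuda()
    edited = editor(image, "make it a watercolor painting", scale_txt=7.5, scale_img=1.5)
    # Map the output back to uint8 pixels and save.
    out = ((edited.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8).permute(1, 2, 0).cpu().numpy()
    Image.fromarray(out).save("edited.jpg")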
def compute_metrics(config, model_path, vae_ckpt, data_path, output_path,
                    scales_img, scales_txt, num_samples=5000, split="test",
                    steps=50, res=512, seed=0):
    editor = ImageEditor(config, model_path, vae_ckpt).cuda()
    clip_similarity = ClipSimilarity().cuda()
    outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl")
    Path(output_path).mkdir(parents=True, exist_ok=True)

    # Sweep the grid of text/image guidance scales, appending one JSON line of
    # averaged metrics per (scale_txt, scale_img) pair.
    for scale_txt in scales_txt:
        for scale_img in scales_img:
            dataset = EditDatasetEval(path=data_path, split=split, res=res)
            assert num_samples <= len(dataset)
            print(f"Processing t={scale_txt}, i={scale_img}")
            torch.manual_seed(seed)
            perm = torch.randperm(len(dataset))

            count = 0
            i = 0
            sim_0_avg = 0
            sim_1_avg = 0
            sim_direction_avg = 0
            sim_image_avg = 0
            pbar = tqdm(total=num_samples)
            while count < num_samples:
                idx = perm[i].item()
                sample = dataset[idx]
                i += 1
                gen = editor(sample["image_0"].cuda(), sample["edit"],
                             scale_txt=scale_txt, scale_img=scale_img, steps=steps)
                sim_0, sim_1, sim_direction, sim_image = clip_similarity(
                    sample["image_0"][None].cuda(), gen[None].cuda(),
                    [sample["input_prompt"]], [sample["output_prompt"]]
                )
                sim_0_avg += sim_0.item()
                sim_1_avg += sim_1.item()
                sim_direction_avg += sim_direction.item()
                sim_image_avg += sim_image.item()
                count += 1
                pbar.update(1)  # advance by one sample, not by the running count
            pbar.close()

            sim_0_avg /= count
            sim_1_avg /= count
            sim_direction_avg /= count
            sim_image_avg /= count

            with open(outpath, "a") as f:
                f.write(json.dumps(dict(
                    sim_0=sim_0_avg,
                    sim_1=sim_1_avg,
                    sim_direction=sim_direction_avg,
                    sim_image=sim_image_avg,
                    num_samples=num_samples,
                    split=split,
                    scale_txt=scale_txt,
                    scale_img=scale_img,
                    steps=steps,
                    res=res,
                    seed=seed,
                )) + "\n")

    return outpath


def plot_metrics(metrics_file, output_path):
    with open(metrics_file, "r") as f:
        data = [json.loads(line) for line in f]

    plt.rcParams.update({"font.size": 11.5})
    seaborn.set_style("darkgrid")
    plt.figure(figsize=(20.5 * 0.7, 10.8 * 0.7), dpi=200)

    # Trade-off curve: directional CLIP similarity (how well the edit follows
    # the instruction) vs. CLIP image similarity (consistency with the input).
    x = [d["sim_direction"] for d in data]
    y = [d["sim_image"] for d in data]
    plt.plot(x, y, marker="o", linewidth=2, markersize=4)

    plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
    plt.ylabel("CLIP Image Similarity", labelpad=10)
    plt.savefig(Path(output_path) / "plot.pdf", bbox_inches="tight")


def main():
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--output_path", default="analysis/", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    args = parser.parse_args()

    # Hold the text guidance scale fixed and sweep the image guidance scale.
    scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
    scales_txt = [7.5]

    metrics_file = compute_metrics(
        args.config,
        args.ckpt,
        args.vae_ckpt,
        args.dataset,
        args.output_path,
        scales_img,
        scales_txt,
        steps=args.steps,
        res=args.resolution,
    )

    plot_metrics(metrics_file, args.output_path)


if __name__ == "__main__":
    main()
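# Example invocation (paths and the script name are illustrative, assuming the
# repo's default layout; only the flags defined in main() are used):
#
#     python compute_metrics.py \
#         --ckpt checkpoints/instruct-pix2pix-00-22000.ckpt \
#         --dataset data/clip-filtered-dataset/ \
#         --output_path analysis/ \
#         --steps 100
#
# This appends one .jsonl line of averaged CLIP metrics per image guidance
# scale and saves the resulting trade-off plot to analysis/plot.pdf.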