from __future__ import annotations

import math
import random
import sys
from argparse import ArgumentParser

import einops
import k_diffusion as K
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import autocast
import json
import matplotlib.pyplot as plt
import seaborn
from pathlib import Path

sys.path.append("./")
from clip_similarity import ClipSimilarity
from edit_dataset import EditDatasetEval

sys.path.append("./stable_diffusion")
from ldm.util import instantiate_from_config


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Run the text+image conditional, image-only conditional, and fully unconditional
        # branches as a single batch of three.
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        # Classifier-free guidance with separate text and image scales.
        return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
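
# A sketch of the guidance combination above (the eps()/s_T/s_I notation is mine, not from the
# source): writing eps(z, c_T, c_I) for the model output under text conditioning c_T and image
# conditioning c_I, with Ø the corresponding null conditioning, the returned prediction is
#
#   eps(z, Ø, Ø) + s_T * (eps(z, c_T, c_I) - eps(z, Ø, c_I)) + s_I * (eps(z, Ø, c_I) - eps(z, Ø, Ø))
#
# where s_T = text_cfg_scale and s_I = image_cfg_scale. This matches the three-way batch built in
# forward(): (text+image cond, image-only cond, uncond).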


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        # Optionally swap in first-stage (VAE) weights from a separate checkpoint.
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
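
# Example (a minimal sketch; the paths are the defaults used by main() below):
#
#   config = OmegaConf.load("configs/generate.yaml")
#   model = load_model_from_config(config, "checkpoints/instruct-pix2pix-00-22000.ckpt", verbose=True)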


class ImageEditor(nn.Module):
    def __init__(self, config, ckpt, vae_ckpt=None):
        super().__init__()
        config = OmegaConf.load(config)
        self.model = load_model_from_config(config, ckpt, vae_ckpt)
        self.model.eval().cuda()
        self.model_wrap = K.external.CompVisDenoiser(self.model)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.null_token = self.model.get_learned_conditioning([""])

    def forward(
        self,
        image: torch.Tensor,
        edit: str,
        scale_txt: float = 7.5,
        scale_img: float = 1.0,
        steps: int = 100,
    ) -> torch.Tensor:
        assert image.dim() == 3
        assert image.size(1) % 64 == 0
        assert image.size(2) % 64 == 0
        with torch.no_grad(), autocast("cuda"), self.model.ema_scope():
            # Conditioning: the edit instruction plus the encoded input image. The unconditional
            # branch uses an empty prompt and a zero image latent.
            cond = {
                "c_crossattn": [self.model.get_learned_conditioning([edit])],
                "c_concat": [self.model.encode_first_stage(image[None]).mode()],
            }
            uncond = {
                "c_crossattn": [self.model.get_learned_conditioning([""])],
                "c_concat": [torch.zeros_like(cond["c_concat"][0])],
            }
            extra_args = {
                "uncond": uncond,
                "cond": cond,
                "image_cfg_scale": scale_img,
                "text_cfg_scale": scale_txt,
            }
            # Sample in latent space with Euler-ancestral, then decode back to pixels.
            sigmas = self.model_wrap.get_sigmas(steps)
            x = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args)
            x = self.model.decode_first_stage(x)[0]
            return x
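
# Example usage (a minimal sketch; `image` and the edit string are placeholders, not from the
# source — forward() only requires a 3-D image tensor whose height and width are multiples of 64,
# prepared the same way EditDatasetEval prepares "image_0"):
#
#   editor = ImageEditor("configs/generate.yaml", "checkpoints/instruct-pix2pix-00-22000.ckpt").cuda()
#   edited = editor(image.cuda(), "make it look like a watercolor", scale_txt=7.5, scale_img=1.5, steps=50)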


def compute_metrics(
    config,
    model_path,
    vae_ckpt,
    data_path,
    output_path,
    scales_img,
    scales_txt,
    num_samples=5000,
    split="test",
    steps=50,
    res=512,
    seed=0,
):
    editor = ImageEditor(config, model_path, vae_ckpt).cuda()
    clip_similarity = ClipSimilarity().cuda()
    outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl")
    Path(output_path).mkdir(parents=True, exist_ok=True)
    # Average the CLIP similarity metrics over num_samples examples for each guidance-scale pair.
    for scale_txt in scales_txt:
        for scale_img in scales_img:
            dataset = EditDatasetEval(
                path=data_path,
                split=split,
                res=res,
            )
            assert num_samples <= len(dataset)
            print(f"Processing t={scale_txt}, i={scale_img}")
            torch.manual_seed(seed)
            perm = torch.randperm(len(dataset))
            i = 0
            sim_0_avg = 0
            sim_1_avg = 0
            sim_direction_avg = 0
            sim_image_avg = 0
            count = 0
            pbar = tqdm(total=num_samples)
            while count < num_samples:
                idx = perm[i].item()
                sample = dataset[idx]
                i += 1
                gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps)
                sim_0, sim_1, sim_direction, sim_image = clip_similarity(
                    sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]]
                )
                sim_0_avg += sim_0.item()
                sim_1_avg += sim_1.item()
                sim_direction_avg += sim_direction.item()
                sim_image_avg += sim_image.item()
                count += 1
                pbar.update(1)  # advance by one processed sample
            pbar.close()
            sim_0_avg /= count
            sim_1_avg /= count
            sim_direction_avg /= count
            sim_image_avg /= count
            with open(outpath, "a") as f:
                f.write(f"{json.dumps(dict(sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg, sim_image=sim_image_avg, num_samples=num_samples, split=split, scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed))}\n")
    return outpath
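
# Each (scale_txt, scale_img) pair appends one JSON line to `outpath` with the averaged metrics
# (keys: sim_0, sim_1, sim_direction, sim_image, num_samples, split, scale_txt, scale_img, steps,
# res, seed). plot_metrics() below consumes this file directly.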


def plot_metrics(metrics_file, output_path):
    with open(metrics_file, "r") as f:
        data = [json.loads(line) for line in f]

    plt.rcParams.update({"font.size": 11.5})
    seaborn.set_style("darkgrid")
    plt.figure(figsize=(20.5 * 0.7, 10.8 * 0.7), dpi=200)

    # Plot the trade-off between directional CLIP similarity and image consistency,
    # one point per metrics record (i.e., per guidance-scale pair, in file order).
    x = [d["sim_direction"] for d in data]
    y = [d["sim_image"] for d in data]
    plt.plot(x, y, marker="o", linewidth=2, markersize=4)

    plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
    plt.ylabel("CLIP Image Similarity", labelpad=10)
    plt.savefig(Path(output_path) / Path("plot.pdf"), bbox_inches="tight")
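
# Example (sketch): re-plot an existing metrics file without re-running the evaluation. The file
# name follows the pattern compute_metrics() uses; the exact values depend on the arguments given.
#
#   plot_metrics("analysis/n=5000_p=test_s=100_r=512_e=0.jsonl", "analysis/")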


def main():
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--output_path", default="analysis/", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    args = parser.parse_args()

    scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
    scales_txt = [7.5]

    metrics_file = compute_metrics(
        args.config,
        args.ckpt,
        args.vae_ckpt,
        args.dataset,
        args.output_path,
        scales_img,
        scales_txt,
        steps=args.steps,
    )

    plot_metrics(metrics_file, args.output_path)


if __name__ == "__main__":
    main()