import pickle
import time
from pathlib import Path

import numpy as np
import torch
import tops
from dp2 import utils
from torch_fidelity.defaults import DEFAULTS as trf_defaults
from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
from torch_fidelity.utils import create_feature_extractor

from .lpips import SampleSimilarityLPIPS

# Lazily-initialized module-level models, shared across metric calls so the
# LPIPS and FID feature extractors are only loaded onto the GPU once.
lpips_model = None
fid_model = None


def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    # Mean squared error per image, averaged over all channels and pixels.
    se = (images1 - images2) ** 2
    se = se.view(images1.shape[0], -1).mean(dim=1)
    return se


def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    # Peak signal-to-noise ratio in dB. Using 1 as the peak value assumes
    # inputs are scaled to [0, 1].
    mse_ = mse(images1, images2)
    psnr = 10 * torch.log10(1 / mse_)
    return psnr
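

# A quick sanity-check sketch for the two metrics above (tensor names and
# values are illustrative, and inputs are assumed to lie in [0, 1]):
#
#     x = torch.rand(4, 3, 64, 64)
#     y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)
#     mse(x, y)   # shape (4,): one value per image
#     psnr(x, y)  # shape (4,): 10 * log10(1 / mse), higher = closer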


def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    return _lpips_w_grad(images1, images2)


def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    global lpips_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    # The LPIPS network expects inputs scaled to [0, 255].
    images1 = images1.mul(255)
    images2 = images2.mul(255)
    with torch.cuda.amp.autocast(tops.AMP()):
        dists = lpips_model(images1, images2)[0].view(-1)
    return dists
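

# Usage sketch (an illustrative call, not part of the original API surface;
# it assumes a CUDA device, since the model is moved with tops.to_cuda, and
# image batches in [0, 1]):
#
#     dists = lpips(real_images, fake_images)  # shape (N,), lower = more similar
#
# As the name _lpips_w_grad suggests, gradients flow through the distance,
# so it can also serve as a training loss.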


def compute_metrics_iteratively(
    dataloader, generator,
    cache_directory,
    data_len=None,
    truncation_value: float = None,
) -> dict:
    """
    Computes FID, LPIPS and LPIPS diversity over the dataloader, drawing two
    samples from the generator per batch.

    Args:
        cache_directory: Directory where FID statistics of the real data are
            cached between runs.
        data_len (int): Number of images to preallocate FID feature buffers for.
            Defaults to len(dataloader) * dataloader.batch_size.
        truncation_value (float): Optional truncation passed to generator.sample.
    """
    global lpips_model, fid_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    if fid_model is None:
        fid_model = create_feature_extractor(
            trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
        fid_model = tops.to_cuda(fid_model)
    cache_directory = Path(cache_directory)
    start_time = time.time()
    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
    diversity_total = torch.zeros_like(lpips_total)
    fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
    has_fid_cache = fid_cache_path.is_file()
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    # Real Inception features (2048-dim) are only extracted when no cached
    # statistics exist; fake features are always extracted.
    if not has_fid_cache:
        fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
    fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
    eidx = 0
    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            # Sample the generator twice per batch: real vs. fakes1 gives
            # LPIPS, fakes1 vs. fakes2 gives LPIPS diversity.
            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes1 = utils.denormalize_img(fakes1).mul(255)
            fakes2 = utils.denormalize_img(fakes2).mul(255)
            real_data = utils.denormalize_img(batch["img"]).mul(255)
            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
            fake2_lpips_feats = lpips_model.get_feats(fakes2)
            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
            if not has_fid_cache:
                fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
            fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
    fid_features_fake = fid_features_fake[:n_samples_seen]

    if has_fid_cache:
        if tops.rank() == 0:
            with open(fid_cache_path, "rb") as fp:
                fid_stat_real = pickle.load(fp)
    else:
        fid_features_real = fid_features_real[:n_samples_seen]
        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
        if tops.rank() == 0:
            # Compute and cache real-data statistics for future runs.
            fid_stat_real = fid_features_to_statistics(fid_features_real)
            cache_directory.mkdir(exist_ok=True, parents=True)
            with open(fid_cache_path, "wb") as fp:
                pickle.dump(fid_stat_real, fp)

    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
    if tops.rank() == 0:
        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]

    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
    lpips_total = lpips_total / n_samples_seen
    diversity_total = diversity_total / n_samples_seen
    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
    to_return = {k: v.cpu().item() for k, v in to_return.items()}
    if tops.rank() == 0:
        to_return["fid"] = fid_
    else:
        to_return["fid"] = -1
    to_return["validation_time_s"] = time.time() - start_time
    return to_return
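

# Sketch of a typical call (argument values are illustrative; `generator.sample`
# must accept the batch keys plus `truncation_value`, as used above):
#
#     metrics = compute_metrics_iteratively(
#         dataloader, generator,
#         cache_directory="metric_cache/my_dataset",  # hypothetical path
#     )
#     # -> {"lpips": ..., "lpips_diversity": ..., "fid": ..., "validation_time_s": ...}
#
# Real-image FID statistics are pickled to `cache_directory`, so repeated
# validation runs over the same dataset skip re-extracting real features.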


def compute_lpips(
    dataloader, generator,
    truncation_value: float = None,
    data_len=None,
) -> dict:
    """
    Computes LPIPS and LPIPS diversity over the dataloader, drawing two
    samples from the generator per batch. FID is skipped.

    Args:
        truncation_value (float): Optional truncation passed to generator.sample.
        data_len (int): Unused; kept for API parity with
            compute_metrics_iteratively.
    """
    global lpips_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    start_time = time.time()
    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
    diversity_total = torch.zeros_like(lpips_total)
    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
    for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            real_data = batch["img"]
            fakes1 = utils.denormalize_img(fakes1).mul(255)
            fakes2 = utils.denormalize_img(fakes2).mul(255)
            real_data = utils.denormalize_img(real_data).mul(255)
            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
            fake2_lpips_feats = lpips_model.get_feats(fakes2)
            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()

    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
    lpips_total = lpips_total / n_samples_seen
    diversity_total = diversity_total / n_samples_seen
    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
    to_return = {k: v.cpu().item() for k, v in to_return.items()}
    to_return["validation_time_s"] = time.time() - start_time
    return to_return
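

# compute_lpips is the lighter-weight variant of compute_metrics_iteratively:
# the same LPIPS/diversity bookkeeping, but no FID features or cache. A sketch
# (illustrative call, same assumptions as above):
#
#     metrics = compute_lpips(dataloader, generator, truncation_value=None)
#     # -> {"lpips": ..., "lpips_diversity": ..., "validation_time_s": ...}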