"""Validation metrics (MSE, PSNR, LPIPS, LPIPS-diversity, FID) for dp2 generators."""
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import tops
from torch_fidelity.defaults import DEFAULTS as trf_defaults
from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
from torch_fidelity.utils import create_feature_extractor

from dp2 import utils
from .lpips import SampleSimilarityLPIPS

# Lazily-initialized model singletons, kept alive (and on GPU) between calls.
lpips_model = None
fid_model = None


@torch.no_grad()
def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    """Per-image mean squared error; returns a tensor of shape (N,)."""
    se = (images1 - images2) ** 2
    return se.view(images1.shape[0], -1).mean(dim=1)


@torch.no_grad()
def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    """Per-image PSNR. Assumes pixel values in [0, 1] (peak signal = 1)."""
    return 10 * torch.log10(1 / mse(images1, images2))


@torch.no_grad()
def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    """Gradient-free LPIPS distance between two image batches in [0, 1]."""
    return _lpips_w_grad(images1, images2)


def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
    """LPIPS distance (gradients enabled). Lazily builds the shared LPIPS model.

    Args:
        images1: Image batch with pixel values in [0, 1].
        images2: Image batch with pixel values in [0, 1].
    Returns:
        1-D tensor of per-image LPIPS distances.
    """
    global lpips_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    # SampleSimilarityLPIPS expects pixel values in [0, 255].
    images1 = images1.mul(255)
    images2 = images2.mul(255)
    with torch.cuda.amp.autocast(tops.AMP()):
        dists = lpips_model(images1, images2)[0].view(-1)
    return dists


@torch.no_grad()
def compute_metrics_iteratively(
    dataloader,
    generator,
    cache_directory,
    data_len=None,
    truncation_value: float = None,
) -> dict:
    """Compute FID, LPIPS and LPIPS diversity over a dataloader.

    Real-image FID statistics are cached under ``cache_directory`` so repeated
    validations only re-extract features for generated images.

    Args:
        dataloader: Yields dicts with at least an "img" key (normalized images).
        generator: Model with a ``sample(**batch, truncation_value=...)`` method
            returning a dict with an "img" key.
        cache_directory: Directory for the pickled real-image FID statistics.
        data_len: Number of samples; defaults to ``len(dataloader) * batch_size``.
        truncation_value: Optional truncation passed to ``generator.sample``.
    Returns:
        dict with keys "lpips", "lpips_diversity", "fid" (only meaningful on
        rank 0; -1 elsewhere) and "validation_time_s".
    """
    global lpips_model, fid_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    if fid_model is None:
        fid_model = create_feature_extractor(
            trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
        fid_model = tops.to_cuda(fid_model)
    cache_directory = Path(cache_directory)
    start_time = time.time()
    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
    diversity_total = torch.zeros_like(lpips_total)
    fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
    has_fid_cache = fid_cache_path.is_file()
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    # Real features are only needed when the cache is cold.
    if not has_fid_cache:
        fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
    # BUGFIX: fake features are *always* required for FID (only the real-image
    # statistics are cached), so allocate and fill them unconditionally.
    # Previously the allocation sat under `if not has_fid_cache`, raising
    # NameError on any run with a warm cache.
    fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
    eidx = 0
    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            # Two independent samples per input: LPIPS is averaged over both,
            # and their mutual distance measures sample diversity.
            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes1 = utils.denormalize_img(fakes1).mul(255)
            fakes2 = utils.denormalize_img(fakes2).mul(255)
            real_data = utils.denormalize_img(batch["img"]).mul(255)
            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
            fake2_lpips_feats = lpips_model.get_feats(fakes2)
            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
            if not has_fid_cache:
                fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
            fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
    # Trim to the samples actually seen (last batch may be partial).
    fid_features_fake = fid_features_fake[:n_samples_seen]
    if has_fid_cache:
        if tops.rank() == 0:
            # NOTE(review): cached stats were produced by pickle from this code
            # path; only load caches you control.
            with open(fid_cache_path, "rb") as fp:
                fid_stat_real = pickle.load(fp)
    else:
        fid_features_real = fid_features_real[:n_samples_seen]
        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
        if tops.rank() == 0:
            fid_stat_real = fid_features_to_statistics(fid_features_real)
            cache_directory.mkdir(exist_ok=True, parents=True)
            with open(fid_cache_path, "wb") as fp:
                pickle.dump(fid_stat_real, fp)
    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
    if tops.rank() == 0:
        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
    lpips_total = lpips_total / n_samples_seen
    diversity_total = diversity_total / n_samples_seen
    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
    if tops.rank() == 0:
        to_return["fid"] = fid_
    else:
        to_return["fid"] = -1  # FID is only computed on rank 0.
    to_return["validation_time_s"] = time.time() - start_time
    return to_return


@torch.no_grad()
def compute_lpips(
    dataloader,
    generator,
    truncation_value: float = None,
    data_len=None,
) -> dict:
    """Compute LPIPS and LPIPS diversity (no FID) over a dataloader.

    Args:
        dataloader: Yields dicts with at least an "img" key (normalized images).
        generator: Model with a ``sample(**batch, truncation_value=...)`` method
            returning a dict with an "img" key.
        truncation_value: Optional truncation passed to ``generator.sample``.
        data_len: Unused placeholder kept for interface compatibility.
    Returns:
        dict with float keys "lpips", "lpips_diversity" and "validation_time_s".
    """
    global lpips_model, fid_model
    if lpips_model is None:
        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
    start_time = time.time()
    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
    diversity_total = torch.zeros_like(lpips_total)
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    eidx = 0
    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
    for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
            real_data = batch["img"]
            fakes1 = utils.denormalize_img(fakes1).mul(255)
            fakes2 = utils.denormalize_img(fakes2).mul(255)
            real_data = utils.denormalize_img(real_data).mul(255)
            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
            fake2_lpips_feats = lpips_model.get_feats(fakes2)
            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
    lpips_total = lpips_total / n_samples_seen
    diversity_total = diversity_total / n_samples_seen
    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
    to_return = {k: v.cpu().item() for k, v in to_return.items()}
    to_return["validation_time_s"] = time.time() - start_time
    return to_return