File size: 3,350 Bytes
7f49ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import argparse
import pickle
import torch
from torch import nn
import numpy as np
from scipy import linalg
from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from calc_inception import load_patched_inception_v3
import os
@torch.no_grad()
def extract_features(loader, inception, device):
    """Collect the flattened outputs of ``inception`` for every batch in ``loader``.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        loader: Iterable yielding two-element batches; only the first element
            (the image tensor) is used, the second is discarded.
        inception: Callable applied to each image batch. Element ``[0]`` of
            its result is taken as the feature tensor.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        A CPU tensor holding all per-batch features, each reshaped to
        ``(batch_size, -1)`` and concatenated along dimension 0.
    """

    def _batch_features(images):
        # Move the batch to the target device, run the model, flatten each
        # sample to one row, and bring the result back to host memory.
        images = images.to(device)
        return inception(images)[0].view(images.shape[0], -1).to('cpu')

    # tqdm only wraps the loader to display progress while iterating.
    per_batch = [_batch_features(images) for images, _ in tqdm(loader)]
    return torch.cat(per_batch, 0)
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Compute the Frechet distance between two Gaussians.

    The value returned is
    ``|mu_s - mu_r|^2 + Tr(C_s) + Tr(C_r) - 2 * Tr(sqrt(C_s @ C_r))``.

    Args:
        sample_mean: 1-D mean vector of the sample features.
        sample_cov: Square covariance matrix of the sample features.
        real_mean: 1-D mean vector of the reference features.
        real_cov: Square covariance matrix of the reference features.
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product contains non-finite entries.

    Returns:
        The Frechet distance as a scalar.

    Raises:
        ValueError: If the matrix square root has a diagonal whose imaginary
            part is not close to zero (absolute tolerance 1e-3).
    """
    # Call sqrtm without ``disp``: that argument (and the ``(sqrtm, errest)``
    # tuple it selects) is deprecated in SciPy 1.16 and slated for removal.
    # The plain call returns just the matrix on every SciPy version and
    # matches the fallback call below.
    cov_sqrt = linalg.sqrtm(sample_cov @ real_cov)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        # Retry with a small diagonal offset on both covariances.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # Small imaginary parts are treated as numerical noise and dropped;
        # a noticeably complex diagonal is reported as an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff  # squared Euclidean distance of the means

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    return mean_norm + trace
def _folder_statistics(root, transform, batch_size, inception, device):
    """Return ``(mean, covariance)`` of the features of the images under ``root``.

    Builds an ``ImageFolder`` over ``root``, runs it through
    ``extract_features`` and prints how many feature rows were extracted.
    """
    dataset = ImageFolder(root, transform)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=4)
    features = extract_features(loader, inception, device).numpy()
    print(f'extracted {features.shape[0]} features')
    # rowvar=False: each row is one image, each column one feature.
    return np.mean(features, 0), np.cov(features, rowvar=False)


def main():
    """Print the FID between ``--path_a`` and each ``eval_<n>`` folder in ``--path_b``.

    Statistics of ``--path_a`` are computed once and used as the reference.
    For every index ``i`` in ``[--iter, --end]`` the folder
    ``<path_b>/eval_<i * 10000>`` is evaluated if it exists; missing folders
    are skipped silently.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=64,
                        help='batch size used for feature extraction')
    parser.add_argument('--size', type=int, default=256,
                        help='images are resized to size x size')
    # Both paths are required: without them the script used to fail late
    # with an opaque TypeError instead of a usage message.
    parser.add_argument('--path_a', type=str, required=True,
                        help='image folder providing the reference statistics')
    parser.add_argument('--path_b', type=str, required=True,
                        help='directory containing the eval_<n> image folders')
    parser.add_argument('--iter', type=int, default=3,
                        help='first index; folder name is eval_<index * 10000>')
    parser.add_argument('--end', type=int, default=13,
                        help='last index (inclusive)')
    args = parser.parse_args()

    # Fall back to the CPU so the script still runs on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    inception = load_patched_inception_v3().eval().to(device)

    transform = transforms.Compose(
        [
            transforms.Resize((args.size, args.size)),
            transforms.ToTensor(),
            # Normalize every channel with mean 0.5 and std 0.5.
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    real_mean, real_cov = _folder_statistics(
        args.path_a, transform, args.batch, inception, device
    )

    for index in range(args.iter, args.end + 1):
        folder = f'eval_{index * 10000}'
        folder_path = os.path.join(args.path_b, folder)
        if not os.path.isdir(folder_path):
            continue

        print(folder)
        sample_mean, sample_cov = _folder_statistics(
            folder_path, transform, args.batch, inception, device
        )
        fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
        print(folder, ' fid:', fid)


if __name__ == '__main__':
    main()
|