Spaces:
Sleeping
Sleeping
import logging | |
from abc import abstractmethod, ABC | |
import numpy as np | |
import sklearn | |
import sklearn.svm | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from joblib import Parallel, delayed | |
from scipy import linalg | |
from models.ade20k import SegmentationModule, NUM_CLASS, segm_options | |
from .fid.inception import InceptionV3 | |
from .lpips import PerceptualLoss | |
from .ssim import SSIM | |
LOGGER = logging.getLogger(__name__) | |
def get_groupings(groups): | |
""" | |
:param groups: group numbers for respective elements | |
:return: dict of kind {group_idx: indices of the corresponding group elements} | |
""" | |
label_groups, count_groups = np.unique(groups, return_counts=True) | |
indices = np.argsort(groups) | |
grouping = dict() | |
cur_start = 0 | |
for label, count in zip(label_groups, count_groups): | |
cur_end = cur_start + count | |
cur_indices = indices[cur_start:cur_end] | |
grouping[label] = cur_indices | |
cur_start = cur_end | |
return grouping | |
class EvaluatorScore(nn.Module): | |
def forward(self, pred_batch, target_batch, mask): | |
pass | |
def get_value(self, groups=None, states=None): | |
pass | |
def reset(self): | |
pass | |
class PairwiseScore(EvaluatorScore, ABC): | |
def __init__(self): | |
super().__init__() | |
self.individual_values = None | |
def get_value(self, groups=None, states=None): | |
""" | |
:param groups: | |
:return: | |
total_results: dict of kind {'mean': score mean, 'std': score std} | |
group_results: None, if groups is None; | |
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} | |
""" | |
individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \ | |
else self.individual_values | |
total_results = { | |
'mean': individual_values.mean(), | |
'std': individual_values.std() | |
} | |
if groups is None: | |
return total_results, None | |
group_results = dict() | |
grouping = get_groupings(groups) | |
for label, index in grouping.items(): | |
group_scores = individual_values[index] | |
group_results[label] = { | |
'mean': group_scores.mean(), | |
'std': group_scores.std() | |
} | |
return total_results, group_results | |
def reset(self): | |
self.individual_values = [] | |
class SSIMScore(PairwiseScore): | |
def __init__(self, window_size=11): | |
super().__init__() | |
self.score = SSIM(window_size=window_size, size_average=False).eval() | |
self.reset() | |
def forward(self, pred_batch, target_batch, mask=None): | |
batch_values = self.score(pred_batch, target_batch) | |
self.individual_values = np.hstack([ | |
self.individual_values, batch_values.detach().cpu().numpy() | |
]) | |
return batch_values | |
class LPIPSScore(PairwiseScore): | |
def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True): | |
super().__init__() | |
self.score = PerceptualLoss(model=model, net=net, model_path=model_path, | |
use_gpu=use_gpu, spatial=False).eval() | |
self.reset() | |
def forward(self, pred_batch, target_batch, mask=None): | |
batch_values = self.score(pred_batch, target_batch).flatten() | |
self.individual_values = np.hstack([ | |
self.individual_values, batch_values.detach().cpu().numpy() | |
]) | |
return batch_values | |
def fid_calculate_activation_statistics(act): | |
mu = np.mean(act, axis=0) | |
sigma = np.cov(act, rowvar=False) | |
return mu, sigma | |
def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |
mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |
mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |
diff = mu1 - mu2 | |
# Product might be almost singular | |
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |
if not np.isfinite(covmean).all(): | |
msg = ('fid calculation produces singular product; ' | |
'adding %s to diagonal of cov estimates') % eps | |
LOGGER.warning(msg) | |
offset = np.eye(sigma1.shape[0]) * eps | |
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |
# Numerical error might give slight imaginary component | |
if np.iscomplexobj(covmean): | |
# if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |
m = np.max(np.abs(covmean.imag)) | |
raise ValueError('Imaginary component {}'.format(m)) | |
covmean = covmean.real | |
tr_covmean = np.trace(covmean) | |
return (diff.dot(diff) + np.trace(sigma1) + | |
np.trace(sigma2) - 2 * tr_covmean) | |
class FIDScore(EvaluatorScore): | |
def __init__(self, dims=2048, eps=1e-6): | |
LOGGER.info("FIDscore init called") | |
super().__init__() | |
if getattr(FIDScore, '_MODEL', None) is None: | |
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |
FIDScore._MODEL = InceptionV3([block_idx]).eval() | |
self.model = FIDScore._MODEL | |
self.eps = eps | |
self.reset() | |
LOGGER.info("FIDscore init done") | |
def forward(self, pred_batch, target_batch, mask=None): | |
activations_pred = self._get_activations(pred_batch) | |
activations_target = self._get_activations(target_batch) | |
self.activations_pred.append(activations_pred.detach().cpu()) | |
self.activations_target.append(activations_target.detach().cpu()) | |
return activations_pred, activations_target | |
def get_value(self, groups=None, states=None): | |
LOGGER.info("FIDscore get_value called") | |
activations_pred, activations_target = zip(*states) if states is not None \ | |
else (self.activations_pred, self.activations_target) | |
activations_pred = torch.cat(activations_pred).cpu().numpy() | |
activations_target = torch.cat(activations_target).cpu().numpy() | |
total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) | |
total_results = dict(mean=total_distance) | |
if groups is None: | |
group_results = None | |
else: | |
group_results = dict() | |
grouping = get_groupings(groups) | |
for label, index in grouping.items(): | |
if len(index) > 1: | |
group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index], | |
eps=self.eps) | |
group_results[label] = dict(mean=group_distance) | |
else: | |
group_results[label] = dict(mean=float('nan')) | |
self.reset() | |
LOGGER.info("FIDscore get_value done") | |
return total_results, group_results | |
def reset(self): | |
self.activations_pred = [] | |
self.activations_target = [] | |
def _get_activations(self, batch): | |
activations = self.model(batch)[0] | |
if activations.shape[2] != 1 or activations.shape[3] != 1: | |
assert False, \ | |
'We should not have got here, because Inception always scales inputs to 299x299' | |
# activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) | |
activations = activations.squeeze(-1).squeeze(-1) | |
return activations | |
class SegmentationAwareScore(EvaluatorScore): | |
def __init__(self, weights_path): | |
super().__init__() | |
self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval() | |
self.target_class_freq_by_image_total = [] | |
self.target_class_freq_by_image_mask = [] | |
self.pred_class_freq_by_image_mask = [] | |
def forward(self, pred_batch, target_batch, mask): | |
pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() | |
target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() | |
mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy() | |
batch_target_class_freq_total = [] | |
batch_target_class_freq_mask = [] | |
batch_pred_class_freq_mask = [] | |
for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat): | |
cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...] | |
cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...] | |
cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...] | |
self.target_class_freq_by_image_total.append(cur_target_class_freq_total) | |
self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask) | |
self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask) | |
batch_target_class_freq_total.append(cur_target_class_freq_total) | |
batch_target_class_freq_mask.append(cur_target_class_freq_mask) | |
batch_pred_class_freq_mask.append(cur_pred_class_freq_mask) | |
batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0) | |
batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0) | |
batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0) | |
return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask | |
def reset(self): | |
super().reset() | |
self.target_class_freq_by_image_total = [] | |
self.target_class_freq_by_image_mask = [] | |
self.pred_class_freq_by_image_mask = [] | |
def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name): | |
assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0] | |
total_class_freq = target_class_freq_by_image_mask.sum(0) | |
distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0) | |
result = distr_values / (total_class_freq + 1e-3) | |
return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0} | |
def get_segmentation_idx2name(): | |
return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()} | |
class SegmentationAwarePairwiseScore(SegmentationAwareScore): | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.individual_values = [] | |
self.segm_idx2name = get_segmentation_idx2name() | |
def forward(self, pred_batch, target_batch, mask): | |
cur_class_stats = super().forward(pred_batch, target_batch, mask) | |
score_values = self.calc_score(pred_batch, target_batch, mask) | |
self.individual_values.append(score_values) | |
return cur_class_stats + (score_values,) | |
def calc_score(self, pred_batch, target_batch, mask): | |
raise NotImplementedError() | |
def get_value(self, groups=None, states=None): | |
""" | |
:param groups: | |
:return: | |
total_results: dict of kind {'mean': score mean, 'std': score std} | |
group_results: None, if groups is None; | |
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} | |
""" | |
if states is not None: | |
(target_class_freq_by_image_total, | |
target_class_freq_by_image_mask, | |
pred_class_freq_by_image_mask, | |
individual_values) = states | |
else: | |
target_class_freq_by_image_total = self.target_class_freq_by_image_total | |
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask | |
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask | |
individual_values = self.individual_values | |
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) | |
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) | |
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) | |
individual_values = np.concatenate(individual_values, axis=0) | |
total_results = { | |
'mean': individual_values.mean(), | |
'std': individual_values.std(), | |
**distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name) | |
} | |
if groups is None: | |
return total_results, None | |
group_results = dict() | |
grouping = get_groupings(groups) | |
for label, index in grouping.items(): | |
group_class_freq = target_class_freq_by_image_mask[index] | |
group_scores = individual_values[index] | |
group_results[label] = { | |
'mean': group_scores.mean(), | |
'std': group_scores.std(), | |
** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name) | |
} | |
return total_results, group_results | |
def reset(self): | |
super().reset() | |
self.individual_values = [] | |
class SegmentationClassStats(SegmentationAwarePairwiseScore): | |
def calc_score(self, pred_batch, target_batch, mask): | |
return 0 | |
def get_value(self, groups=None, states=None): | |
""" | |
:param groups: | |
:return: | |
total_results: dict of kind {'mean': score mean, 'std': score std} | |
group_results: None, if groups is None; | |
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} | |
""" | |
if states is not None: | |
(target_class_freq_by_image_total, | |
target_class_freq_by_image_mask, | |
pred_class_freq_by_image_mask, | |
_) = states | |
else: | |
target_class_freq_by_image_total = self.target_class_freq_by_image_total | |
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask | |
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask | |
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) | |
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) | |
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) | |
target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32') | |
target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum() | |
target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32') | |
target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum() | |
pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3) | |
total_results = dict() | |
total_results.update({f'total_freq/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(target_class_freq_by_image_total_marginal) | |
if v > 0}) | |
total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(target_class_freq_by_image_mask_marginal) | |
if v > 0}) | |
total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(pred_class_freq_diff) | |
if target_class_freq_by_image_total_marginal[i] > 0}) | |
if groups is None: | |
return total_results, None | |
group_results = dict() | |
grouping = get_groupings(groups) | |
for label, index in grouping.items(): | |
group_target_class_freq_by_image_total = target_class_freq_by_image_total[index] | |
group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index] | |
group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index] | |
group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32') | |
group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum() | |
group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32') | |
group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum() | |
group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / ( | |
group_target_class_freq_by_image_mask.sum(0) + 1e-3) | |
cur_group_results = dict() | |
cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(group_target_class_freq_by_image_total_marginal) | |
if v > 0}) | |
cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(group_target_class_freq_by_image_mask_marginal) | |
if v > 0}) | |
cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v | |
for i, v in enumerate(group_pred_class_freq_diff) | |
if group_target_class_freq_by_image_total_marginal[i] > 0}) | |
group_results[label] = cur_group_results | |
return total_results, group_results | |
class SegmentationAwareSSIM(SegmentationAwarePairwiseScore): | |
def __init__(self, *args, window_size=11, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.score_impl = SSIM(window_size=window_size, size_average=False).eval() | |
def calc_score(self, pred_batch, target_batch, mask): | |
return self.score_impl(pred_batch, target_batch).detach().cpu().numpy() | |
class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore): | |
def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path, | |
use_gpu=use_gpu, spatial=False).eval() | |
def calc_score(self, pred_batch, target_batch, mask): | |
return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy() | |
def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6): | |
activations_pred = activations_pred.copy() | |
activations_pred[img_i] = activations_target[img_i] | |
return calculate_frechet_distance(activations_pred, activations_target, eps=eps) | |
class SegmentationAwareFID(SegmentationAwarePairwiseScore): | |
def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs): | |
super().__init__(*args, **kwargs) | |
if getattr(FIDScore, '_MODEL', None) is None: | |
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] | |
FIDScore._MODEL = InceptionV3([block_idx]).eval() | |
self.model = FIDScore._MODEL | |
self.eps = eps | |
self.n_jobs = n_jobs | |
def calc_score(self, pred_batch, target_batch, mask): | |
activations_pred = self._get_activations(pred_batch) | |
activations_target = self._get_activations(target_batch) | |
return activations_pred, activations_target | |
def get_value(self, groups=None, states=None): | |
""" | |
:param groups: | |
:return: | |
total_results: dict of kind {'mean': score mean, 'std': score std} | |
group_results: None, if groups is None; | |
else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} | |
""" | |
if states is not None: | |
(target_class_freq_by_image_total, | |
target_class_freq_by_image_mask, | |
pred_class_freq_by_image_mask, | |
activation_pairs) = states | |
else: | |
target_class_freq_by_image_total = self.target_class_freq_by_image_total | |
target_class_freq_by_image_mask = self.target_class_freq_by_image_mask | |
pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask | |
activation_pairs = self.individual_values | |
target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) | |
target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) | |
pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) | |
activations_pred, activations_target = zip(*activation_pairs) | |
activations_pred = np.concatenate(activations_pred, axis=0) | |
activations_target = np.concatenate(activations_target, axis=0) | |
total_results = { | |
'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps), | |
'std': 0, | |
**self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target) | |
} | |
if groups is None: | |
return total_results, None | |
group_results = dict() | |
grouping = get_groupings(groups) | |
for label, index in grouping.items(): | |
if len(index) > 1: | |
group_activations_pred = activations_pred[index] | |
group_activations_target = activations_target[index] | |
group_class_freq = target_class_freq_by_image_mask[index] | |
group_results[label] = { | |
'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps), | |
'std': 0, | |
**self.distribute_fid_to_classes(group_class_freq, | |
group_activations_pred, | |
group_activations_target) | |
} | |
else: | |
group_results[label] = dict(mean=float('nan'), std=0) | |
return total_results, group_results | |
def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target): | |
real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) | |
fid_no_images = Parallel(n_jobs=self.n_jobs)( | |
delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps) | |
for img_i in range(activations_pred.shape[0]) | |
) | |
errors = real_fid - fid_no_images | |
return distribute_values_to_classes(class_freq, errors, self.segm_idx2name) | |
def _get_activations(self, batch): | |
activations = self.model(batch)[0] | |
if activations.shape[2] != 1 or activations.shape[3] != 1: | |
activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) | |
activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy() | |
return activations | |