import logging
import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import tqdm.auto  # import the submodule explicitly; plain `import tqdm` does not guarantee tqdm.auto
from torch.utils.data import DataLoader

from saicinpainting.evaluation.utils import move_to_device

LOGGER = logging.getLogger(__name__)
class InpaintingEvaluator:
    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, compute scores for groups of samples
            defined by the share of the area occluded by the mask
        :param bins: number of groups; the partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch size for the dataloader
        :param device: device to use
        :param integral_func: optional callable that maps the results dict to a single scalar summary
            (e.g. ssim_fid100_f1 below); its value is stored under the (integral_title, 'total') key
        :param integral_title: name under which the value of integral_func is reported
        :param clamp_image_range: optional (min, max) pair; if given, images are clamped to this range
        """
        self.scores = scores
        self.dataset = dataset
        self.area_grouping = area_grouping
        self.bins = bins
        self.device = torch.device(device)
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range
    def _get_bin_edges(self):
        bin_edges = np.linspace(0, 1, self.bins + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for idx_bin in range(self.bins):
            start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
                                         round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names
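
    # Binning sketch (illustrative numbers, not from any dataset): with bins=10,
    # bin_edges == [0.0, 0.1, ..., 1.0]. For mask areas [0.05, 0.10, 1.00],
    # np.searchsorted(bin_edges, areas, side='right') - 1 yields [0, 1, 10];
    # the corner-case fix above then maps 10 -> 9, so a fully masked image
    # falls into the last interval, '90-100%'.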
    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
        :return: dict with (score_name, group_type) as keys, where group_type is either 'total' or
            the name of a particular group arranged by mask area (e.g. '10-20%'),
            and the score statistics for that group as values
        """
        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so precomputed inpainting results are expected at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)

                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
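
# A minimal usage sketch (the dataset and score constructors below are assumptions
# for illustration, not part of this module; substitute the concrete classes from your setup):
#
#     evaluator = InpaintingEvaluator(my_dataset, scores={'ssim': MySSIMScore()},
#                                     bins=10, batch_size=32, device='cuda')
#     results = evaluator.evaluate(model=my_inpainting_model)  # or model=None with
#                                                              # precomputed 'inpainted' batches
#     print(results[('ssim', 'total')]['mean'])
#     print(results[('ssim', '90-100%')])  # per-mask-area group, if area_grouping=True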

def ssim_fid100_f1(metrics, fid_scale=100):
    ssim = metrics[('ssim', 'total')]['mean']
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
    return f1
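
# Worked example (arithmetic only): with ssim = 0.9 and fid = 30,
# fid_rel = (100 - 30) / 100 = 0.7 and
# f1 = 2 * 0.9 * 0.7 / (0.9 + 0.7 + 1e-3) ≈ 0.787.
# The 1e-3 term guards against division by zero when both inputs are 0.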

def lpips_fid100_f1(metrics, fid_scale=100):
    neg_lpips = 1 - metrics[('lpips', 'total')]['mean']  # invert, so bigger is better
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
    return f1
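
# Same F1-style combination as ssim_fid100_f1, but with inverted LPIPS: e.g. lpips = 0.2
# gives neg_lpips = 0.8, so with fid = 30,
# f1 = 2 * 0.8 * 0.7 / (0.8 + 0.7 + 1e-3) ≈ 0.746.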

class InpaintingEvaluatorOnline(nn.Module):
    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups; the partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: key under which the ground-truth images are stored in the batch dict
        :param inpainted_key: key under which the inpainted images are stored in the batch dict
        :param integral_func: optional callable that maps the results dict to a single scalar summary
        :param integral_title: name under which the value of integral_func is reported
        :param clamp_image_range: optional (min, max) pair; if given, images are clamped to this range
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for idx_bin in range(self.bins_num):
            start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
                                         round(100 * self.bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')
    def _get_bins(self, mask_batch):
        batch_size = mask_batch.shape[0]
        area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
        return bin_indices
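
    # Note: unlike InpaintingEvaluator._get_bin_edges, this uses searchsorted's default
    # side='left' and clips instead of remapping the last bin. For example, an area of
    # exactly 0.1 lands in bin 0 ('0-10%') here but in bin 1 ('10-20%') in the offline
    # evaluator; areas strictly between bin edges are binned identically.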
    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for a batch. To finalize the evaluation and obtain
        the final metrics, call evaluation_end.

        :param batch: batch dict with mandatory fields mask, image, inpainted
            (the last two keys can be overridden by self.image_key and self.inpainted_key)
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], \
                batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result
    def process_batch(self, batch: Dict[str, torch.Tensor]):
        return self(batch)
    def evaluation_end(self, states=None):
        """
        :return: dict with (score_name, group_type) as keys, where group_type is either 'total' or
            the name of a particular group arranged by mask area (e.g. '10-20%'),
            and the score statistics for that group as values
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        self.groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results
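
# A minimal lifecycle sketch for the online evaluator (the batch contents and the score
# class are assumptions for illustration; any EvaluatorScore-style nn.Module would do):
#
#     evaluator = InpaintingEvaluatorOnline(scores={'ssim': MySSIMScore()}, bins=10)
#     for batch in dataloader:              # each batch holds 'image', 'mask', 'inpainted'
#         evaluator(batch)                  # accumulates per-batch scores and mask-area bins
#     results = evaluator.evaluation_end()  # also resets internal state for the next run
#     print(results[('ssim', 'total')])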