Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class WeightedLoss(nn.Module):
    """Base class for element-wise losses with optional weighting and reductions.

    Subclasses override the ``func`` property to return a functional loss with the
    signature ``func(inputs, targets, reduction=...)`` (e.g. ``F.mse_loss``). It is
    always called with ``reduction='none'`` so the weighting is applied before any
    reduction happens.
    """

    @property
    def func(self):
        """Functional loss applied element-wise; subclasses must provide it."""
        raise NotImplementedError

    def forward(self, inputs, targets, weight=None, reduction='mean'):
        """Compute the (optionally weighted) loss between ``inputs`` and ``targets``.

        Args:
            inputs: predictions, passed straight to ``func``.
            targets: ground truth, passed straight to ``func``.
            weight: optional weight tensor. Trailing singleton dims are appended until
                it has as many dims as ``inputs``; it then multiplies the unreduced loss.
            reduction: ``'none'`` (weighted, unreduced loss), ``'sum'``, ``'mean'``
                (mean over all elements, zero-weight ones included) or
                ``'valid_mean'`` (weighted sum divided by the sum of ``weight``).

        Raises:
            ValueError: if ``reduction='valid_mean'`` is requested without ``weight``.
        """
        assert reduction in ['none', 'sum', 'mean', 'valid_mean']
        if reduction == 'valid_mean' and weight is None:
            # The denominator is the weight sum; fail clearly instead of on None.float().
            raise ValueError("reduction='valid_mean' requires a weight tensor")
        loss = self.func(inputs, targets, reduction='none')
        if weight is not None:
            # Append trailing singleton dims so e.g. a weight of shape (N,) broadcasts
            # against a loss of shape (N, C).
            while weight.ndim < inputs.ndim:
                weight = weight[..., None]
            loss *= weight.float()
        if reduction == 'none':
            return loss
        elif reduction == 'sum':
            return loss.sum()
        elif reduction == 'mean':
            return loss.mean()
        elif reduction == 'valid_mean':
            # NOTE(review): an all-zero weight makes this 0/0 -> NaN.
            return loss.sum() / weight.float().sum()


class MSELoss(WeightedLoss):
    """Weighted mean-squared-error loss (element-wise ``F.mse_loss``)."""

    @property
    def func(self):
        return F.mse_loss


class L1Loss(WeightedLoss):
    """Weighted L1 (absolute error) loss (element-wise ``F.l1_loss``)."""

    @property
    def func(self):
        return F.l1_loss
class PSNR(nn.Module):
    """Peak signal-to-noise ratio computed as ``-10 * log10(MSE)``.

    No peak value is applied, which corresponds to a peak signal of 1.
    NOTE(review): presumably inputs are in [0, 1]; confirm with callers.
    """

    def forward(self, inputs, targets, valid_mask=None, reduction='mean'):
        """Return the PSNR between ``inputs`` and ``targets``.

        Args:
            inputs: predictions.
            targets: ground truth, same shape as ``inputs``.
            valid_mask: optional boolean index applied to the squared error
                (``sq_err[valid_mask]``) before averaging.
            reduction: ``'mean'`` for one PSNR over every kept element, or
                ``'none'`` for one PSNR per entry of the first dimension.
        """
        assert reduction in ['mean', 'none']
        sq_err = (inputs - targets) ** 2
        if valid_mask is not None:
            sq_err = sq_err[valid_mask]
        if reduction == 'none':
            # Average over every dim except the first one.
            per_item_dims = tuple(range(1, sq_err.ndim))
            return -10 * torch.log10(sq_err.mean(dim=per_item_dims))
        if reduction == 'mean':
            return -10 * torch.log10(sq_err.mean())
class SSIM:
    """Structural similarity (SSIM) between two batches of images.

    Instances are callables: ``ssim(output, target, reduction='mean')`` with 4-D
    ``B x C x H x W`` tensors of equal shape and dtype.

    Args:
        data_range: ``(min, max)`` of the pixel values; only ``max - min`` is used,
            to scale the stabilising constants ``c1`` and ``c2``.
        kernel_size: ``(height, width)`` of the sliding window; both odd and positive.
        sigma: ``(height, width)`` standard deviations of the Gaussian window; both
            must be positive (checked even when ``gaussian=False``).
        k1, k2: coefficients giving ``c1 = (k1 * range)**2`` and ``c2 = (k2 * range)**2``.
        gaussian: use a Gaussian window if True, otherwise the box window of ``_uniform``.

    Raises:
        ValueError: for an even/non-positive kernel size or a non-positive sigma.
    """

    def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True):
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.gaussian = gaussian
        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.")
        if any(y <= 0 for y in self.sigma):
            raise ValueError(f"Expected sigma to have positive number. Got {sigma}.")
        data_scale = data_range[1] - data_range[0]
        self.c1 = (k1 * data_scale) ** 2
        self.c2 = (k2 * data_scale) ** 2
        # Reflect padding that keeps the filtered maps at the input's spatial size.
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        # 2-D (kernel_size[0], kernel_size[1]) window. It is never mutated: every call
        # expands it to the input's channel count and device/dtype, so one instance
        # works for any C, device or dtype.
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

    def _uniform(self, kernel_size):
        """Return a (1, kernel_size) box window that sums to 1.

        Taps whose offset from the centre lies in [-2.5, 2.5] (at most 5 taps) share
        equal weight; all other taps are 0. For kernel_size >= 5 each active tap is 1/5.
        Normalizing by the active-tap count keeps smaller windows summing to 1.
        """
        lo, hi = -2.5, 2.5
        ksize_half = (kernel_size - 1) * 0.5
        offsets = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        inside = (offsets >= lo) & (offsets <= hi)
        # The centre tap (offset 0) is always inside, so the divisor is >= 1.
        kernel = inside.to(offsets.dtype) / inside.sum()
        return kernel.unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian(self, kernel_size, sigma):
        """Return a (1, kernel_size) normalized 1-D Gaussian window with std ``sigma``."""
        ksize_half = (kernel_size - 1) * 0.5
        offsets = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        gauss = torch.exp(-0.5 * (offsets / sigma).pow(2))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
        """Return the separable 2-D window of shape (kernel_size[0], kernel_size[1])."""
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])
        return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    def __call__(self, output, target, reduction='mean'):
        """Return the SSIM of ``output`` against ``target``.

        Args:
            output: 4-D ``B x C x H x W`` tensor.
            target: tensor with the same shape and dtype as ``output``.
            reduction: ``'none'`` gives per-image SSIM of shape ``(B,)`` (mean over
                C, H, W); ``'sum'`` / ``'mean'`` reduce those per-image values.

        Raises:
            TypeError: if the dtypes differ.
            ValueError: if the shapes differ or are not 4-D.
        """
        if output.dtype != target.dtype:
            raise TypeError(
                f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}."
            )
        if output.shape != target.shape:
            raise ValueError(
                f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}."
            )
        if len(output.shape) != 4 or len(target.shape) != 4:
            raise ValueError(
                f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}."
            )
        assert reduction in ['mean', 'sum', 'none']
        batch, channel = output.shape[:2]
        # Depthwise window for this call's channel count, on the input's device/dtype.
        kernel = self._kernel.expand(channel, 1, -1, -1).to(device=output.device, dtype=output.dtype)
        padding = [self.pad_w, self.pad_w, self.pad_h, self.pad_h]
        output = F.pad(output, padding, mode="reflect")
        target = F.pad(target, padding, mode="reflect")
        # One grouped conv yields the local E[x], E[y], E[x^2], E[y^2], E[xy] for the
        # whole batch; split them back into five B-sized chunks.
        stacked = torch.cat([output, target, output * output, target * target, output * target])
        mu_pred, mu_target, e_pred_sq, e_target_sq, e_pred_target = F.conv2d(
            stacked, kernel, groups=channel
        ).split(batch)
        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target
        sigma_pred_sq = e_pred_sq - mu_pred_sq
        sigma_target_sq = e_target_sq - mu_target_sq
        sigma_pred_target = e_pred_target - mu_pred_target
        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2
        ssim_idx = (a1 * a2) / (b1 * b2)
        _ssim = torch.mean(ssim_idx, (1, 2, 3))
        if reduction == 'none':
            return _ssim
        elif reduction == 'sum':
            return _ssim.sum()
        elif reduction == 'mean':
            return _ssim.mean()
def binary_cross_entropy(input, target, reduction='mean'):
    """
    F.binary_cross_entropy is not numerically stable in mixed-precision training.

    The loss ``-(t * log(p) + (1 - t) * log(1 - p))`` is therefore written out by
    hand. As in ``F.binary_cross_entropy``, each log term is clamped to be >= -100,
    so probabilities of exactly 0 or 1 give a finite forward value instead of
    inf/NaN (e.g. ``0 * log(0)``). Only the forward value is guarded this way.

    Args:
        input: predicted probabilities, expected in [0, 1].
        target: targets, broadcastable against ``input``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Returns:
        The reduced loss, or the element-wise loss for ``reduction='none'``.

    Raises:
        ValueError: for an unsupported ``reduction``.
    """
    log_p = torch.log(input).clamp(min=-100)
    log_not_p = torch.log(1 - input).clamp(min=-100)
    loss = -(target * log_p + (1 - target) * log_not_p)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"Unsupported reduction: {reduction!r}")