# NOTE(review): the lines below are export/scrape metadata (pagination header,
# "Runtime error" banners, file size, commit hash 99e984c, and a line-number
# gutter) that leaked into the source; commented out so the module parses.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor
class ContentLoss(nn.Module):
    """Constructs a content (perceptual) loss based on the VGG19 network.

    Using high-level feature mapping layers from the latter layers will focus
    more on the texture content of the image.

    Paper reference list:
    -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
    -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
    -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
    """

    def __init__(self) -> None:
        super().__init__()
        # Load the VGG19 model trained on ImageNet (weights are downloaded
        # on first use, so construction performs network I/O).
        vgg19 = models.vgg19(pretrained=True).eval()
        # Keep the first 36 layers of `vgg19.features` as the loss network,
        # i.e. the feature maps used for the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze the loss network in one call instead of a per-parameter loop:
        # its weights must never receive gradient updates.
        self.feature_extractor.requires_grad_(False)
        # ImageNet normalization statistics used by torchvision's VGG models.
        # Registered as buffers so they move with the module across devices
        # and are saved in the state dict. `torch.tensor` replaces the legacy
        # `torch.Tensor(...)` constructor (same float32 result).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return the L1 distance between VGG19 feature maps of `sr` and `hr`.

        Args:
            sr: Super-resolved image batch of shape (N, 3, H, W);
                presumably values in [0, 1] — TODO confirm against caller.
            hr: Ground-truth high-resolution batch, same shape as `sr`.

        Returns:
            Scalar tensor: mean absolute difference of the two feature maps.
        """
        # Standardize with ImageNet statistics (out-of-place; inputs untouched).
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)
        # Feature-space L1 difference between the two images.
        return F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
class GenGaussLoss(nn.Module):
    """Negative log-likelihood-style loss for a generalized Gaussian model.

    The predictor emits a per-element mean, inverse scale (``1/alpha``) and
    shape (``beta``); the loss penalizes the residual between ``mean`` and
    ``target`` under that parameterization (up to additive constants).

    Args:
        reduction: ``'mean'`` or ``'sum'`` over all elements; any other value
            makes ``forward`` print a message and return ``None``.
        alpha_eps: Stabilizer added to ``one_over_alpha`` before log/product.
        beta_eps: Stabilizer added to ``beta`` before log/lgamma.
        resi_min: Lower clamp for the residual term.
        resi_max: Upper clamp for the residual term.
    """

    def __init__(
        self,
        reduction: str = 'mean',
        alpha_eps: float = 1e-4,
        beta_eps: float = 1e-4,
        resi_min: float = 1e-4,
        resi_max: float = 1e3,
    ) -> None:
        super().__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor,
        one_over_alpha: Tensor,
        beta: Tensor,
        target: Tensor,
    ):
        """Compute the loss; prints and returns ``None`` on NaNs in the
        residual or on an unsupported reduction (original best-effort
        contract, preserved for callers)."""
        # Keep the scale/shape strictly positive before logs and powers.
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        resi = torch.abs(mean - target)
        # NOTE(review): the exact generalized-Gaussian residual would be
        # torch.pow(resi * one_over_alpha1, beta1); the linearized product
        # below is the author's deliberate variant — behavior kept as-is.
        resi = (resi * one_over_alpha1 * beta1).clamp(min=self.resi_min, max=self.resi_max)
        # Idiomatic NaN check (same semantics as torch.sum(x != x) > 0).
        if torch.isnan(resi).any():
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        # log Gamma(1/beta): normalization term of the generalized Gaussian.
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        # Diagnostic-only NaN warnings (original behavior: warn, keep going).
        if torch.isnan(log_one_over_alpha).any():
            print('log_one_over_alpha has nan')
        if torch.isnan(lgamma_beta).any():
            print('lgamma_beta has nan')
        if torch.isnan(log_beta).any():
            print('log_beta has nan')

        # Per-element loss; `nll` replaces the ambiguous name `l` (E741).
        nll = resi - log_one_over_alpha + lgamma_beta - log_beta
        if self.reduction == 'mean':
            return nll.mean()
        elif self.reduction == 'sum':
            return nll.sum()
        else:
            print('Reduction not supported')
            return None
class TempCombLoss(nn.Module):
    """Temperature-weighted combination of an L1 loss and ``GenGaussLoss``.

    ``forward`` returns ``T1 * L1(mean, target) + T2 * GenGauss(mean,
    one_over_alpha, beta, target)``, where both sub-losses share the same
    reduction and stabilizer configuration.

    Args:
        reduction: Reduction passed to both sub-losses ('mean' or 'sum').
        alpha_eps: Stabilizer for ``one_over_alpha`` (see ``GenGaussLoss``).
        beta_eps: Stabilizer for ``beta`` (see ``GenGaussLoss``).
        resi_min: Lower clamp for the residual term (see ``GenGaussLoss``).
        resi_max: Upper clamp for the residual term (see ``GenGaussLoss``).
    """

    def __init__(
        self,
        reduction: str = 'mean',
        alpha_eps: float = 1e-4,
        beta_eps: float = 1e-4,
        resi_min: float = 1e-4,
        resi_max: float = 1e3,
    ) -> None:
        super().__init__()
        # Record the configuration (kept as public attributes, as before).
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max
        # Both sub-losses are wired from the same configuration.
        self.L_GenGauss = GenGaussLoss(
            reduction=reduction,
            alpha_eps=alpha_eps,
            beta_eps=beta_eps,
            resi_min=resi_min,
            resi_max=resi_max,
        )
        self.L_l1 = nn.L1Loss(reduction=reduction)

    def forward(
        self,
        mean: Tensor,
        one_over_alpha: Tensor,
        beta: Tensor,
        target: Tensor,
        T1: float,
        T2: float,
    ):
        """Return ``T1 * l1 + T2 * gen_gauss`` for the given predictions."""
        fidelity_term = self.L_l1(mean, target)
        uncertainty_term = self.L_GenGauss(mean, one_over_alpha, beta, target)
        return T1 * fidelity_term + T2 * uncertainty_term
# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2)) |