from typing import Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseAdversarialLoss:
    def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for generator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for discriminator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate generator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
        :return: total generator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate discriminator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by the generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator when producing fake_batch
        :return: total discriminator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def interpolate_mask(self, mask, shape):
        # Subclasses that call this are expected to define the
        # `allow_scale_mask` and `mask_scale_mode` attributes.
        assert mask is not None
        assert self.allow_scale_mask or shape == mask.shape[-2:]
        if shape != mask.shape[-2:] and self.allow_scale_mask:
            if self.mask_scale_mode == 'maxpool':
                mask = F.adaptive_max_pool2d(mask, shape)
            else:
                mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
        return mask
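
# Minimal sketch of interpolate_mask semantics (the shapes below are illustrative
# assumptions, not fixed by this file):
#
#   adv = NonSaturatingWithR1(allow_scale_mask=True, mask_scale_mode='maxpool')
#   mask = torch.zeros(1, 1, 256, 256)             # (N, C, H, W) inpainting mask
#   small = adv.interpolate_mask(mask, (30, 30))   # -> torch.Size([1, 1, 30, 30])
#
# 'maxpool' keeps a downscaled pixel "masked" if any pixel in its bin was masked;
# any other mode string is passed through to F.interpolate (e.g. 'nearest', 'bilinear').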
def make_r1_gp(discr_real_pred, real_batch):
    if torch.is_grad_enabled():
        grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0
    real_batch.requires_grad = False
    return grad_penalty
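
# R1 regularization (Mescheder et al., 2018) penalizes the squared L2 norm of the
# discriminator's gradient at real samples:
#
#     R1 = E_{x ~ p_data} [ || grad_x D(x) ||_2^2 ]
#
# For the gradient to reach real_batch, `pre_discriminator_step` must have set
# real_batch.requires_grad = True before the discriminator forward pass; the flag
# is reset to False at the end of make_r1_gp once the penalty has been computed.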
class NonSaturatingWithR1(BaseAdversarialLoss):
    def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
                 mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
                 use_unmasked_for_gen=True, use_unmasked_for_discr=True):
        self.gp_coef = gp_coef
        self.weight = weight
        # use_unmasked_for_discr implies use_unmasked_for_gen;
        # otherwise we would teach only the discriminator to pay attention to
        # very small differences in the unmasked region
        assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask_as_fake_target implies use_unmasked_for_discr:
        # if we don't care about unmasked regions at all,
        # then the value of mask_as_fake_target is irrelevant
        assert use_unmasked_for_discr or (not mask_as_fake_target)
        self.use_unmasked_for_gen = use_unmasked_for_gen
        self.use_unmasked_for_discr = use_unmasked_for_discr
        self.mask_as_fake_target = mask_as_fake_target
        self.allow_scale_mask = allow_scale_mask
        self.mask_scale_mode = mask_scale_mode
        self.extra_mask_weight_for_gen = extra_mask_weight_for_gen

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        fake_loss = F.softplus(-discr_fake_pred)
        if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
                not self.use_unmasked_for_gen:  # i.e. the masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            if not self.use_unmasked_for_gen:
                fake_loss = fake_loss * mask
            else:
                pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
                fake_loss = fake_loss * pixel_weights
        return fake_loss.mean() * self.weight, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        real_batch.requires_grad = True

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_loss = F.softplus(-discr_real_pred)
        grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
        fake_loss = F.softplus(discr_fake_pred)

        if not self.use_unmasked_for_discr or self.mask_as_fake_target:
            # i.e. the masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for reals there is no difference between the two regions
            fake_loss = fake_loss * mask
            if self.mask_as_fake_target:
                fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)

        sum_discr_loss = real_loss + grad_penalty + fake_loss
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=grad_penalty)
        return sum_discr_loss.mean(), metrics
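
# Sketch of the expected call order within one training step (`discriminator`,
# `generator`, and `optim_d` are hypothetical stand-ins, not part of this file):
#
#   adv = NonSaturatingWithR1(gp_coef=5, weight=1)
#   adv.pre_discriminator_step(real_batch, fake_batch, generator, discriminator)
#   d_loss, d_metrics = adv.discriminator_loss(real_batch, fake_batch,
#                                              discriminator(real_batch),
#                                              discriminator(fake_batch.detach()),
#                                              mask=mask)
#   d_loss.backward(); optim_d.step()
#
# pre_discriminator_step must run before the discriminator forward pass on
# real_batch; otherwise make_r1_gp has no gradient path back to the input.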
class BCELoss(BaseAdversarialLoss):
    def __init__(self, weight):
        self.weight = weight
        self.bce_loss = nn.BCEWithLogitsLoss()

    # Note: generator_loss and discriminator_loss take a narrower argument set
    # than the base class; callers are expected to pass keyword arguments.
    def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
        fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
        return fake_loss, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        real_batch.requires_grad = True

    def discriminator_loss(self,
                           mask: torch.Tensor,
                           discr_real_pred: torch.Tensor,
                           discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
        sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=0)
        return sum_discr_loss, metrics
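
# Note on the target convention: this discriminator is trained as a per-pixel
# fake-region segmenter. Real images target an all-zeros map, generated images
# target the inpainting mask (1 = inpainted region), and generator_loss rewards
# pushing fake predictions toward the all-zeros ("looks real everywhere") map.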
def make_discrim_loss(kind, **kwargs):
    if kind == 'r1':
        return NonSaturatingWithR1(**kwargs)
    elif kind == 'bce':
        return BCELoss(**kwargs)
    raise ValueError(f'Unknown adversarial loss kind {kind}')
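
if __name__ == '__main__':
    # Smoke test: a minimal, hedged sketch. Random tensors stand in for real data
    # and a 1x1 convolution stands in for a patch discriminator; these are
    # illustrative assumptions, not part of the training pipeline.
    torch.manual_seed(0)
    discriminator = nn.Conv2d(3, 1, kernel_size=1)
    real_batch = torch.rand(2, 3, 8, 8)
    fake_batch = torch.rand(2, 3, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float()

    adv = make_discrim_loss('r1', gp_coef=5, weight=1)
    adv.pre_discriminator_step(real_batch, fake_batch, generator=None, discriminator=discriminator)
    d_loss, d_metrics = adv.discriminator_loss(real_batch, fake_batch,
                                               discriminator(real_batch),
                                               discriminator(fake_batch),
                                               mask=mask)
    g_loss, _ = adv.generator_loss(real_batch, fake_batch,
                                   discriminator(real_batch),
                                   discriminator(fake_batch),
                                   mask=mask)
    print(f'd_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f}')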