Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Utility module to handle adversarial losses without cluttering the main training loop.
""" | |
import typing as tp | |
import flashy | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Identifiers of the adversarial loss variants accepted by the criterion getters in this module.
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']

# A logit-level loss: an ``nn.Module`` or a callable mapping one tensor to a loss tensor.
AdvLossType = tp.Union[
    nn.Module,
    tp.Callable[[torch.Tensor], torch.Tensor],
]

# A feature-matching loss: an ``nn.Module`` or a callable taking two arguments (fake, real).
FeatLossType = tp.Union[
    nn.Module,
    tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
]
class AdversarialLoss(nn.Module):
    """Adversary training wrapper.

    Bundles an adversary (discriminator), its optimizer and the loss functions
    so that the adversary update and the generator-side loss can each be
    obtained with a single call.

    Args:
        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
            where the first item is a list of logits and the second item is a list of feature maps.
        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
        loss (AdvLossType): Loss function for generator training.
        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
        loss_feat (FeatLossType): Feature matching loss function for generator training.
        normalize (bool): Whether to normalize by number of sub-discriminators.

    Example of usage:
        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss, _ = adv_loss(fake, real)
            loss.backward()
    """
    def __init__(self,
                 adversary: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss: AdvLossType,
                 loss_real: AdvLossType,
                 loss_fake: AdvLossType,
                 loss_feat: tp.Optional[FeatLossType] = None,
                 normalize: bool = True):
        super().__init__()
        self.adversary: nn.Module = adversary
        # NOTE(review): presumably makes every distributed worker start from the
        # same adversary weights -- confirm against the flashy documentation.
        flashy.distrib.broadcast_model(self.adversary)
        self.optimizer = optimizer
        self.loss = loss
        self.loss_real = loss_real
        self.loss_fake = loss_fake
        self.loss_feat = loss_feat
        self.normalize = normalize

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save the module state and add the optimizer state under ``prefix + 'optimizer'``."""
        # Add the optimizer state dict inside our own.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Restore the optimizer state, then let ``nn.Module`` load the remaining entries.

        The optimizer entry is popped from ``state_dict`` before delegating to the
        parent implementation; a missing entry raises ``KeyError``.
        """
        # Load optimizer state.
        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_adversary_pred(self, x):
        """Run adversary model, validating expected output format.

        Returns:
            The ``(logits, fmaps)`` pair produced by the adversary, where ``logits`` is a
            list of tensors and ``fmaps`` is a list of lists of tensors.
        """
        logits, fmaps = self.adversary(x)
        assert isinstance(logits, list) and all(isinstance(t, torch.Tensor) for t in logits), \
            f'Expecting a list of tensors as logits but {type(logits)} found.'
        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
        for fmap in fmaps:
            assert isinstance(fmap, list) and all(isinstance(f, torch.Tensor) for f in fmap), \
                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
        return logits, fmaps

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """Train the adversary with the given fake and real example.

        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
        The first item being the logits and second item being a list of feature maps for each sub-discriminator.

        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
        and call the optimizer.

        Returns:
            The adversary loss that was back-propagated (sum of ``loss_fake`` and ``loss_real``
            over the sub-adversaries, divided by their count when ``normalize`` is set).
        """
        loss = torch.tensor(0., device=fake.device)
        # Both inputs are detached so this update only produces gradients for the adversary.
        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
        n_sub_adversaries = len(all_logits_fake_is_fake)
        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)

        if self.normalize:
            loss /= n_sub_adversaries

        self.optimizer.zero_grad()
        with flashy.distrib.eager_sync_model(self.adversary):
            loss.backward()
        self.optimizer.step()

        return loss

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return the loss for the generator, i.e. trying to fool the adversary,
        and feature matching loss if provided.

        Returns:
            A ``(adv, feat)`` pair. ``feat`` stays at zero when no ``loss_feat`` was given.
            Both are divided by the number of sub-adversaries when ``normalize`` is set.
        """
        adv = torch.tensor(0., device=fake.device)
        feat = torch.tensor(0., device=fake.device)
        # NOTE(review): `readonly` presumably keeps the adversary's parameters from being
        # updated by the generator's backward pass -- confirm against flashy.
        with flashy.utils.readonly(self.adversary):
            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
            n_sub_adversaries = len(all_logits_fake_is_fake)
            for logit_fake_is_fake in all_logits_fake_is_fake:
                adv += self.loss(logit_fake_is_fake)
            # Compare against None explicitly: a truthiness test would silently skip a
            # loss object that defines ``__len__``/``__bool__`` and evaluates as falsy.
            if self.loss_feat is not None:
                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
                    feat += self.loss_feat(fmap_fake, fmap_real)

        if self.normalize:
            adv /= n_sub_adversaries
            feat /= n_sub_adversaries

        return adv, feat
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the generator-side adversarial loss function for ``loss_type``.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        ``mse_loss``, ``hinge_loss`` or ``hinge2_loss`` depending on ``loss_type``.
    Raises:
        AssertionError: If ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: Fallback when no known name matches (only reachable if the
            assertion does not fire).
    """
    assert loss_type in ADVERSARIAL_LOSSES
    known_criteria = (
        ('mse', mse_loss),
        ('hinge', hinge_loss),
        ('hinge2', hinge2_loss),
    )
    for name, criterion in known_criteria:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss applied to logits of fake samples.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        ``mse_fake_loss`` for ``'mse'``; ``hinge_fake_loss`` for ``'hinge'`` and ``'hinge2'``.
    Raises:
        AssertionError: If ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: Fallback when no known name matches (only reachable if the
            assertion does not fire).
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_fake_loss
    # Both hinge variants share the same adversary-side loss.
    if loss_type == 'hinge' or loss_type == 'hinge2':
        return hinge_fake_loss
    raise ValueError('Unsupported loss')
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss applied to logits of real samples.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        ``mse_real_loss`` for ``'mse'``; ``hinge_real_loss`` for ``'hinge'`` and ``'hinge2'``.
    Raises:
        AssertionError: If ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: Fallback when no known name matches (only reachable if the
            assertion does not fire).
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_real_loss
    # Both hinge variants share the same adversary-side loss.
    if loss_type == 'hinge' or loss_type == 'hinge2':
        return hinge_real_loss
    raise ValueError('Unsupported loss')
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and an all-ones target of the same shape."""
    ones_target = torch.tensor(1., device=x.device).expand_as(x)
    return F.mse_loss(x, ones_target)
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and an all-zeros target of the same shape."""
    zeros_target = torch.tensor(0., device=x.device).expand_as(x)
    return F.mse_loss(x, zeros_target)
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for real-sample logits: ``-mean(min(x - 1, 0))``.

    Elements with ``x >= 1`` contribute zero; smaller values are penalised linearly.
    """
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(x - 1, zero)
    return -clipped.mean()
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for fake-sample logits: ``-mean(min(-x - 1, 0))``.

    Elements with ``x <= -1`` contribute zero; larger values are penalised linearly.
    """
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(-x - 1, zero)
    return -clipped.mean()
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side MSE loss: mean squared error between ``x`` and an all-ones target.

    An empty ``x`` yields a one-element zero tensor on ``x``'s device instead.
    """
    if x.numel() > 0:
        ones_target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, ones_target)
    return torch.zeros(1, device=x.device)
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side hinge loss: the negated mean of ``x``.

    An empty ``x`` yields a one-element zero tensor on ``x``'s device instead.
    """
    if x.numel() > 0:
        return torch.neg(torch.mean(x))
    return torch.zeros(1, device=x.device)
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side clipped hinge loss: ``-mean(min(x - 1, 0))``.

    Elements with ``x >= 1`` contribute zero. An empty ``x`` yields a one-element
    zero tensor instead.
    """
    if x.numel() == 0:
        # Create the placeholder on x's device so it can be combined with other
        # losses without a device mismatch.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Sums ``loss(feat_fake, feat_real)`` over paired feature maps and optionally
    averages the sum over the number of pairs.

    Args:
        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
        normalize (bool): Whether to normalize the loss by the number of feature maps.
    """
    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
        super().__init__()
        self.loss = loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Compute the feature matching loss between two lists of feature maps.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps obtained from the fake sample.
            fmap_real (list of torch.Tensor): Feature maps obtained from the real sample;
                must have the same length as ``fmap_fake`` and matching shapes pairwise.
        Returns:
            torch.Tensor: Scalar loss on the device of ``fmap_fake[0]``.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        for feat_fake, feat_real in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            feat_loss /= n_fmaps

        return feat_loss