import torch
import torch.nn as nn


class GANHingeLoss(nn.Module):
    """Hinge loss for GAN training.

    Discriminator terms:
        real samples:  mean(relu(1 - pred))
        fake samples:  mean(relu(1 + pred))
    Generator term:
        mean over -pred, i.e. -mean(pred)

    ``pred`` is reduced with a full ``.mean()``, so any tensor shape is
    accepted and a scalar tensor is returned.
    """

    def __init__(self) -> None:
        super().__init__()
        self.relu = nn.ReLU()

    def forward(
        self,
        pred: torch.Tensor,
        is_real: bool,
        for_discriminator: bool,
    ) -> torch.Tensor:
        """Compute the hinge loss for one batch of discriminator outputs.

        Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``'s
        call machinery, including forward hooks, runs as usual. Callers still
        invoke the module as ``loss(pred, is_real, for_discriminator)``.

        Args:
            pred: Discriminator outputs (presumably raw, un-squashed scores —
                confirm against the discriminator).
            is_real: Whether ``pred`` should be scored as real samples.
            for_discriminator: True to compute the discriminator's loss,
                False to compute the generator's loss.

        Returns:
            Scalar tensor with the mean loss over all elements of ``pred``.

        Raises:
            ValueError: If the generator loss is requested with
                ``is_real=False``. The generator always tries to make its
                samples look real.
        """
        if for_discriminator:
            # Penalize real scores below +1 and fake scores above -1.
            margin = 1 - pred if is_real else 1 + pred
            return self.relu(margin).mean()
        # Explicit check instead of `assert`, so it is not stripped under -O.
        if not is_real:
            raise ValueError("The generator's hinge loss must be aiming for real")
        return -pred.mean()