# Author: deepkyu (initial commit 1ba3df3)
import torch
import torch.nn as nn
class GANHingeLoss(nn.Module):
    """Hinge adversarial loss for GAN training.

    Discriminator terms:
        real samples: mean(relu(1 - pred))
        fake samples: mean(relu(1 + pred))
    Generator term:
        -mean(pred)

    NOTE(review): ``pred`` is presumably the raw (un-squashed) discriminator
    output -- confirm against the caller.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, pred, is_real, for_discriminator):
        """Compute the hinge loss for ``pred``.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` still runs any registered hooks before
        dispatching here; callers keep using ``loss(pred, ...)``.

        Args:
            pred: Discriminator output tensor; all elements are averaged.
            is_real: For the discriminator loss, whether ``pred`` came from
                real samples. For the generator loss it must be truthy.
            for_discriminator: If truthy, compute the discriminator loss;
                otherwise compute the generator loss.

        Returns:
            A 0-dim tensor holding the mean loss.

        Raises:
            ValueError: If the generator loss is requested with a falsy
                ``is_real``.
        """
        if for_discriminator:
            # Penalises real scores below +1 and fake scores above -1.
            margin = 1 - pred if is_real else 1 + pred
            return self.relu(margin).mean()
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if not is_real:
            raise ValueError("The generator's hinge loss must be aiming for real")
        return -pred.mean()