File size: 498 Bytes
1ba3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class GANHingeLoss(nn.Module):
    """Hinge adversarial loss for GAN discriminators and generators.

    Discriminator terms (``for_discriminator=True``):

    * real samples (``is_real=True``): ``mean(relu(1 - pred))``
    * fake samples (``is_real=False``): ``mean(relu(1 + pred))``

    Generator term (``for_discriminator=False``): ``-mean(pred)``. Only the
    "aim for real" case (``is_real=True``) is defined for the generator.

    The mean is taken over every element of ``pred``.
    """

    def __init__(self):
        super().__init__()
        # ReLU applied to the hinge margins in the discriminator branch.
        self.relu = nn.ReLU()

    def forward(
        self,
        pred: torch.Tensor,
        is_real: bool,
        for_discriminator: bool,
    ) -> torch.Tensor:
        """Compute the hinge loss for one batch of discriminator outputs.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module goes through ``nn.Module.__call__`` and registered hooks run.

        Args:
            pred: Discriminator output tensor; reduced with ``.mean()`` over
                all elements.
            is_real: Whether the target is "real". Must be ``True`` for the
                generator loss.
            for_discriminator: ``True`` for the discriminator loss, ``False``
                for the generator loss.

        Returns:
            A scalar tensor holding the loss.

        Raises:
            ValueError: If ``for_discriminator`` is false and ``is_real`` is
                false (the generator hinge loss only targets "real").
        """
        if for_discriminator:
            # Penalise real scores below +1 and fake scores above -1.
            margin = 1 - pred if is_real else 1 + pred
            return self.relu(margin).mean()

        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if not is_real:
            raise ValueError("The generator's hinge loss must be aiming for real")
        return -pred.mean()