File size: 458 Bytes
482ab8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torch.nn as nn
def get_entropy_loss(opt):
    """Factory that builds an ``EntropyLoss`` module.

    Args:
        opt: Options object. It is not read here; presumably accepted so this
            factory matches a shared loss-factory signature (confirm with callers).

    Returns:
        A new ``EntropyLoss`` instance constructed with its default settings.
    """
    criterion = EntropyLoss()
    return criterion
class EntropyLoss(nn.Module):
    """Mean binary (Bernoulli) entropy of a tensor of probabilities.

    Each element ``p`` of the input is clamped into ``[eps, 1 - eps]`` and
    contributes ``-p * log(p) - (1 - p) * log(1 - p)``; the loss is the mean
    over all elements.

    Args:
        eps: Clamping margin that keeps both logarithms finite. Must satisfy
            ``0 < eps < 0.5`` so the clamp interval is non-empty. Defaults to
            ``1e-7`` (the previously hard-coded value).

    Raises:
        ValueError: If ``eps`` is outside ``(0, 0.5)``.
    """

    # Reduced-precision dtypes whose spacing near 1.0 is coarser than the
    # default eps, so ``1 - eps`` would round back to exactly 1.0 there.
    _UPCAST_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, eps=1e-7):
        super().__init__()
        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if not 0 < eps < 0.5:
            raise ValueError(f"eps must be in (0, 0.5), got {eps!r}")
        # Name kept as ``exp`` (it is an epsilon) for code that reads it.
        self.exp = eps

    def forward(self, item):
        """Return ``{"loss": mean binary entropy of item}``.

        Args:
            item: Tensor of values treated as Bernoulli probabilities; values
                outside ``[eps, 1 - eps]`` are clamped. Presumably sigmoid
                outputs -- confirm with callers.

        Returns:
            Dict with a single key ``"loss"`` holding a scalar tensor. For
            float16/bfloat16 input the loss is returned in that same dtype.
        """
        orig_dtype = item.dtype
        upcast = orig_dtype in self._UPCAST_DTYPES
        if upcast:
            # In fp16/bf16, 1 - 1e-7 rounds to 1.0, so log(1 - p) = -inf and
            # 0 * -inf = NaN; do the arithmetic in float32 instead.
            item = item.float()
        item = item.clamp(min=self.exp, max=1 - self.exp)
        entropy = -item * torch.log(item) - (1 - item) * torch.log(1 - item)
        loss = entropy.mean()
        if upcast:
            loss = loss.to(orig_dtype)
        return {"loss": loss}
|