Spaces:
Runtime error
Runtime error
# -------------------------------------------------------- | |
# Based on the timm code base | |
# https://github.com/rwightman/pytorch-image-models/tree/master/timm | |
# -------------------------------------------------------- | |
""" Cross Entropy w/ smoothing or soft targets | |
Hacked together by / Copyright 2021 Ross Wightman | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    The loss is a weighted sum of the negative log-likelihood of the target
    class (weight ``1 - smoothing``) and the cross entropy against a uniform
    distribution over all classes (weight ``smoothing``).

    Args:
        smoothing: Weight given to the uniform component; must be below 1.0.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Kept as an assert so an invalid value still raises AssertionError.
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean label-smoothed loss.

        Args:
            x: Logits with the class dimension last, shape ``(..., C)``.
            target: Class indices with the shape of ``x`` minus its last
                dimension (an index tensor accepted by ``torch.gather``).

        Returns:
            Scalar tensor: the loss averaged over every leading position.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Index along the trailing class dimension so that any number of
        # leading dimensions works, not just (N, C) logits; for 2-D input
        # this is the same as unsqueeze(1) / squeeze(1).
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        # Mean of -log p over classes == cross entropy vs. a uniform target.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
class SoftTargetCrossEntropy(nn.Module):
    """Cross entropy between logits and a soft (per-class weighted) target."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean over leading positions of ``sum(-target * log_softmax(x))``.

        Both the softmax and the sum run over the last dimension.
        """
        log_probs = F.log_softmax(x, dim=-1)
        weighted = -target * log_probs
        per_sample = weighted.sum(dim=-1)
        return per_sample.mean()