|
from typing import Optional |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
def get_breadstick_probabilities(logits):
    """Convert stick-breaking logits into class probabilities.

    Each logit is passed through a sigmoid and interpreted as the fraction of
    the *remaining* probability mass (the "stick") that is broken off for the
    corresponding class. The final class receives whatever is left after all
    breaks, so every row of the result sums to 1:

        p_0   = s_0
        p_i   = s_i * prod_{j < i} (1 - s_j)
        p_K   = prod_{j < K} (1 - s_j)

    with ``s = sigmoid(logits)`` and ``K = logits.shape[1]``.

    Args:
        logits: 2-D tensor of shape ``(batch_size, num_classes - 1)``.

    Returns:
        Tensor of shape ``(batch_size, num_classes)`` holding the class
        probabilities, on the same device as ``logits``.
    """
    take = torch.sigmoid(logits)
    # Unpacking enforces the 2-D (batch, projections) layout.
    batch_size, _ = take.shape

    # sigmoid(-x) == 1 - sigmoid(x), without the cancellation error of the
    # explicit subtraction for large positive logits.
    keep = torch.sigmoid(-logits)

    whole = take.new_ones(batch_size, 1)
    # remaining[:, i] is the stick length left *before* break i; the extra
    # last column is the mass left after every break.
    remaining = torch.cumprod(torch.cat([whole, keep], dim=-1), dim=-1)
    # The last class takes the entire leftover stick (fraction 1).
    fractions = torch.cat([take, whole], dim=-1)
    return fractions * remaining
|
|
|
def ordinal_regression_loss(logits, targets, tao=1, eta=0.15, class_weights=None):
    """Ordinal regression loss with unimodal-uniform soft targets.

    Based on:
    Liu, Xiaofeng, et al. "Unimodal regularized neuron stick-breaking for
    ordinal classification." Neurocomputing 388 (2020): 34-44.

    It is important to have N-1 logits for N classes: the logits are turned
    into N class probabilities by ``get_breadstick_probabilities``.

    Args:
        logits: tensor passed to ``get_breadstick_probabilities``.
        targets: one target per sample; reshaped to ``(batch_size, 1)``
            (presumably integer class indices -- confirm against callers).
        tao: divisor applied to the absolute distance between each class
            index and the target before exponentiation.
        eta: weight of the uniform component mixed into the soft targets.
        class_weights: optional per-class multipliers; defaults to ones.

    Returns:
        Scalar tensor: the batch mean of the weighted cross-entropy between
        the soft targets and the class probabilities.
    """
    class_probs = get_breadstick_probabilities(logits)
    n_samples, n_classes = class_probs.shape

    if class_weights is None:
        class_weights = torch.ones(n_classes, device=logits.device)

    # Absolute distance of every class index from each sample's target.
    class_index = torch.arange(n_classes, device=targets.device).repeat(n_samples, 1)
    distance = torch.abs(class_index - targets.reshape(n_samples, 1))

    # Unimodal part: the softmax is applied on top of the exponential.
    unimodal = torch.softmax(torch.exp(-distance / tao), dim=-1)

    # Blend with a uniform distribution over the classes.
    soft_targets = (1 - eta) * unimodal + eta / n_classes

    per_sample = torch.sum(class_weights * soft_targets * -torch.log(class_probs), dim=-1)
    return per_sample.mean()
|
|
|
|
|
def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: Optional[float] = 2.0) -> torch.Tensor:
    r"""Compute the focal loss, averaged over all elements.

    According to Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection".
    Proceedings of the IEEE international conference on computer vision (2017):
    2980-2988, the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    Two input layouts are handled:

    * 1-D ``input`` (binary case): ``input`` holds one logit per sample and
      ``target`` holds the matching 0/1 labels. Positive terms are weighted
      by ``alpha`` and negative terms by ``1 - alpha``.
    * ``input`` with two or more dimensions (multi-class case): dimension 1
      of ``input`` indexes the classes and ``target`` holds class indices
      with shape ``input.shape[:1] + input.shape[2:]``. ``alpha`` scales
      every class equally, and ``1e-6`` is added to every entry of the
      one-hot target, so non-target classes contribute a tiny amount.

    Args:
        input: logits (not probabilities).
        target: labels; must be usable as a ``scatter_`` index in the
            multi-class case.
        alpha: weighting factor.
        gamma: focusing exponent applied to ``1 - p_t``. NOTE(review): the
            annotation allows ``None`` but the value is passed straight to
            ``torch.pow`` -- presumably ``None`` is not supported; confirm.

    Returns:
        Scalar tensor with the mean loss.
    """
    assert input.size(0) == target.size(0), "input and target batch sizes differ"

    if input.dim() == 1:
        # Binary case: logsigmoid keeps log(p) and log(1 - p) stable.
        probs_pos = torch.sigmoid(input)
        probs_neg = torch.sigmoid(-input)
        positive_term = alpha * torch.pow(probs_neg, gamma) * target * F.logsigmoid(input)
        negative_term = (1 - alpha) * torch.pow(probs_pos, gamma) * (1.0 - target) * F.logsigmoid(-input)
        loss_tmp = -positive_term - negative_term
    else:
        assert input.dim() >= 2
        assert target.size()[1:] == input.size()[2:], "target shape must be input shape without the class dimension"

        input_soft = F.softmax(input, dim=1)
        log_input_soft = F.log_softmax(input, dim=1)

        # One-hot encode the targets along the class dimension (dim 1).
        one_hot = torch.zeros(
            (target.shape[0], input.shape[1]) + target.shape[1:],
            device=input.device,
            dtype=input.dtype,
        )
        target_one_hot = one_hot.scatter_(1, target.unsqueeze(1), 1.0) + 1e-6

        weight = torch.pow(1.0 - input_soft, gamma)
        focal = -alpha * weight * log_input_soft

        # Reduce over the class dimension, keeping batch and trailing dims.
        loss_tmp = (target_one_hot * focal).sum(dim=1)

    return loss_tmp.mean()