from typing import Optional
import torch
import torch.nn.functional as F

def get_stick_breaking_probabilities(logits):
    # map N-1 sigmoid "stick-breaking" fractions to N class probabilities
    projections = torch.sigmoid(logits)
    batch_size, num_projections = projections.shape
    num_classes = num_projections + 1
    probs = torch.zeros(batch_size, num_classes, dtype=projections.dtype, device=projections.device)
    remaining = torch.ones(batch_size, dtype=projections.dtype, device=projections.device)
    for i in range(num_projections):
        # class i takes a sigmoid-sized fraction of the stick that is still left
        probs[:, i] = projections[:, i] * remaining
        remaining = remaining * (1 - projections[:, i])
    # the last class receives whatever remains of the stick
    probs[:, -1] = remaining
    return probs
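
# Illustrative sanity check (not part of the original module): the stick-breaking
# transform maps N-1 logits to N class probabilities that sum to 1.
def _example_stick_breaking():
    logits = torch.randn(2, 3)  # batch of 2, 4 classes -> 3 logits
    probs = get_stick_breaking_probabilities(logits)
    assert probs.shape == (2, 4)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))
    return probs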

def ordinal_regression_loss(logits, targets, tau=1.0, eta=0.15, class_weights=None):
    '''
    Ordinal regression loss based on:
    Liu, Xiaofeng, et al. "Unimodal regularized neuron stick-breaking for ordinal classification."
    Neurocomputing 388 (2020): 34-44.

    Note: with N classes the model must produce N-1 logits per sample.
    '''
    probs = get_stick_breaking_probabilities(logits)
    batch_size, num_classes = probs.shape
    class_weights = class_weights if class_weights is not None else torch.ones(num_classes, device=logits.device)

    # cross entropy against a unimodal soft target distribution:
    # q_i = softmax(-|i - target| / tau), a normalized exponential peaked at
    # the target class and decaying with ordinal distance
    q = torch.softmax(
        -torch.abs(
            torch.arange(num_classes, device=targets.device).repeat(batch_size, 1) - targets.reshape(batch_size, 1)
        ) / tau,
        dim=-1,
    )
    # label smoothing with eta
    q = (1 - eta) * q + eta / num_classes
    loss = torch.sum(class_weights * q * -torch.log(probs.clamp_min(1e-8)), dim=-1).mean()

    return loss
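
# Illustrative usage sketch, assuming 5 ordinal classes (so 4 logits per sample)
# and integer targets in [0, 4]; not part of the original module.
def _example_ordinal_regression_loss():
    logits = torch.randn(8, 4)           # batch of 8, 5 classes -> 4 logits
    targets = torch.randint(0, 5, (8,))  # ordinal labels 0..4
    loss = ordinal_regression_loss(logits, targets)  # default tau/eta, no class weights
    return loss  # scalar tensor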

# Adapted from https://github.com/kornia/kornia/blob/c2273bbfe152c86a48923e473a37c05e28f7fe43/kornia/losses/focal.py
def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: Optional[float] = 2.0) -> torch.Tensor:
    """Criterion that computes Focal loss.
    According to Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection". Proceedings of the IEEE international conference on computer vision (2017): 2980-2988,
    the Focal loss is computed as follows:
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    """
    assert input.size(0) == target.size(0)

    if len(input.shape) == 1:
        # binary focal loss
        probs_pos = torch.sigmoid(input)
        probs_neg = torch.sigmoid(-input)
        loss_tmp = (
            -alpha * torch.pow(probs_neg, gamma) * target * F.logsigmoid(input)
            - (1 - alpha) * torch.pow(probs_pos, gamma) * (1.0 - target) * F.logsigmoid(-input)
        )

    else:
        assert len(input.shape) >= 2

        assert target.size()[1:] == input.size()[2:]

        # compute softmax over the classes axis
        input_soft: torch.Tensor = F.softmax(input, dim=1)
        log_input_soft: torch.Tensor = F.log_softmax(input, dim=1)

        # create the labels one hot tensor
        one_hot = torch.zeros((target.shape[0], input.shape[1]) + target.shape[1:], device=input.device, dtype=input.dtype)
        target_one_hot: torch.Tensor = one_hot.scatter_(1, target.unsqueeze(1), 1.0) + 1e-6

        # compute the actual focal loss
        weight = torch.pow(-input_soft + 1.0, gamma)

        focal = -alpha * weight * log_input_soft
        loss_tmp = torch.einsum('bc...,bc...->b...', target_one_hot, focal)

    return loss_tmp.mean()
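
# Illustrative usage sketch (not part of the original module): the binary branch
# expects raw logits of shape (B,) with float targets in {0, 1}; the multi-class
# branch expects logits of shape (B, C, ...) with integer class-index targets.
def _example_focal_loss():
    # binary: one logit per sample
    bin_logits = torch.randn(8)
    bin_targets = torch.randint(0, 2, (8,)).float()
    bin_loss = focal_loss(bin_logits, bin_targets)

    # multi-class: e.g. segmentation logits (B, C, H, W) with labels (B, H, W)
    seg_logits = torch.randn(2, 5, 16, 16)
    seg_targets = torch.randint(0, 5, (2, 16, 16))
    seg_loss = focal_loss(seg_logits, seg_targets)
    return bin_loss, seg_loss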