|
|
|
|
|
|
|
""" |
|
@Author : Peike Li |
|
@Contact : peike.li@yahoo.com |
|
@File : kl_loss.py |
|
@Time : 7/23/19 4:02 PM |
|
@Desc : |
|
@License : This source code is licensed under the license found in the |
|
LICENSE file in the root directory of this source tree. |
|
""" |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
|
|
def flatten_probas(input, target, labels, ignore=255):
    """
    Flatten per-pixel predictions in the batch and drop ignored pixels.

    Args:
        input: 4-D tensor of size (B, C, H, W).
        target: Tensor that is permuted and flattened exactly like ``input``
            (it must hold B * H * W * C elements in a 4-D layout).
        labels: Tensor with B * H * W elements; entries equal to ``ignore``
            mark pixels to discard.
        ignore: Label value to filter out. ``None`` disables filtering, in
            which case ``labels`` is not inspected.

    Returns:
        Tuple ``(vinput, vtarget)``, each of shape (N, C), where N is the
        number of kept pixels (B * H * W when ``ignore`` is ``None``).

    Raises:
        ValueError: If ``input`` is not 4-dimensional (size unpacking fails).
    """
    # Unpacking four sizes keeps the implicit "input must be 4-D" check.
    _, C, _, _ = input.size()

    # Channels-last, then one row per pixel.
    input = input.permute(0, 2, 3, 1).contiguous().view(-1, C)
    target = target.permute(0, 2, 3, 1).contiguous().view(-1, C)

    if ignore is None:
        return input, target

    # reshape (rather than view) also accepts non-contiguous label maps.
    valid = labels.reshape(-1) != ignore

    # Index with the boolean mask directly. The former
    # ``valid.nonzero().squeeze()`` collapsed to a 0-dim index when exactly
    # one pixel was valid, which returned a (C,) tensor instead of (1, C).
    return input[valid], target[valid]
|
|
|
|
|
class KLDivergenceLoss(nn.Module):
    """
    KL-divergence loss between two sets of class scores, scaled by ``T``.

    Both inputs are divided by ``T`` before the softmax over dim 1, pixels
    whose label equals ``ignore_index`` are dropped via ``flatten_probas``,
    and the resulting ``F.kl_div`` value is multiplied by ``T * T``.
    """

    def __init__(self, ignore_index=255, T=1):
        """
        Args:
            ignore_index: Label value whose pixels are excluded from the loss.
            T: Divisor applied to both ``input`` and ``target`` before the
                softmax; the loss is rescaled by ``T * T``.
        """
        super().__init__()
        self.T = T
        self.ignore_index = ignore_index

    def forward(self, input, target, label):
        """
        Compute the scaled KL divergence.

        Args:
            input: Raw scores; log-softmax is taken over dim 1.
            target: Raw scores; softmax is taken over dim 1.
            label: Label tensor used only to filter out ignored pixels.

        Returns:
            The value of ``F.kl_div`` (called with its default reduction)
            multiplied by ``T * T``.
        """
        scale = self.T

        # F.kl_div expects log-probabilities first, probabilities second.
        log_pred = F.log_softmax(input / scale, dim=1)
        soft_target = F.softmax(target / scale, dim=1)

        flat_log_pred, flat_soft_target = flatten_probas(
            log_pred, soft_target, label, ignore=self.ignore_index
        )
        divergence = F.kl_div(flat_log_pred, flat_soft_target)

        return scale * scale * divergence
|
|