# kirch - "Duplicate from PAIR/Text2Video-Zero" (commit 508927a)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ..utils import ext_loader
# Load the four focal-loss ops (sigmoid/softmax, forward/backward) from the
# ``_ext`` extension module; the autograd Functions below call them.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'sigmoid_focal_loss_forward',
        'sigmoid_focal_loss_backward',
        'softmax_focal_loss_forward',
        'softmax_focal_loss_backward',
    ],
)
class SigmoidFocalLossFunction(Function):
    """Autograd function for the sigmoid focal loss.

    The per-element loss and its gradient are computed by the
    ``sigmoid_focal_loss_forward`` / ``sigmoid_focal_loss_backward`` ops of
    ``ext_module``; this class validates the arguments, applies the requested
    reduction and connects the result to autograd.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # ONNX export hook: emit the custom ``mmcv::MMCVSigmoidFocalLoss`` op
        # with the hyper-parameters attached as node attributes.
        return g.op(
            'mmcv::MMCVSigmoidFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the sigmoid focal loss.

        Args:
            input (Tensor): 2-D tensor whose row count equals
                ``target.size(0)``.
            target (Tensor): 1-D tensor of dtype ``torch.long``, one label
                per row of ``input``.
            gamma (float): Focusing parameter, forwarded to the kernel.
            alpha (float): Balancing parameter, forwarded to the kernel.
            weight (Tensor, optional): 1-D tensor with ``input.size(1)``
                entries. When omitted an empty tensor is passed instead.
            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            Tensor: For ``'none'`` a tensor with the shape of ``input``;
            otherwise a scalar. ``'mean'`` divides the total by
            ``input.size(0)`` (the number of rows), not by the element count.

        Raises:
            AssertionError: If any of the shape/dtype/reduction checks fail.
        """
        # Check the dtype rather than the legacy ``torch.LongTensor`` /
        # ``torch.cuda.LongTensor`` classes: those only match CPU and CUDA
        # tensors and would reject ``long`` targets on any other device.
        assert target.dtype == torch.long
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # The kernel always receives a weight tensor; an empty one is
            # passed when the caller supplies none.
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict
        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # ``output`` has the shape of ``input`` and is written by the kernel.
        output = input.new_zeros(input.size())
        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()

        ctx.save_for_backward(input, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient for ``input``; every other argument gets None."""
        input, target, weight = ctx.saved_tensors
        grad_input = input.new_zeros(input.size())
        ext_module.sigmoid_focal_loss_backward(
            input,
            target,
            weight,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)
        # For 'none' grad_output has the shape of ``input``; for 'mean'/'sum'
        # the forward returned a 0-dim sum, so plain broadcasting covers both.
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None


sigmoid_focal_loss = SigmoidFocalLossFunction.apply
class SigmoidFocalLoss(nn.Module):
    """``nn.Module`` front-end for :func:`sigmoid_focal_loss`.

    Args:
        gamma (float): Focusing parameter passed to the loss.
        alpha (float): Balancing parameter passed to the loss.
        weight (Tensor, optional): Optional weight tensor, registered as a
            buffer (not a trainable parameter) and passed to the loss.
        reduction (str): Reduction mode passed to the loss.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        """Evaluate the loss with the hyper-parameters stored on the module."""
        loss_args = (self.gamma, self.alpha, self.weight, self.reduction)
        return sigmoid_focal_loss(input, target, *loss_args)

    def __repr__(self):
        return (f'{type(self).__name__}(gamma={self.gamma}, '
                f'alpha={self.alpha}, reduction={self.reduction})')
class SoftmaxFocalLossFunction(Function):
    """Autograd function for the softmax focal loss.

    ``input`` is turned into row-wise softmax probabilities here; the
    per-sample loss and the gradient are then computed by the
    ``softmax_focal_loss_forward`` / ``softmax_focal_loss_backward`` ops of
    ``ext_module``.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        # ONNX export hook: emit the custom ``mmcv::MMCVSoftmaxFocalLoss`` op
        # with the hyper-parameters attached as node attributes.
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the softmax focal loss.

        Args:
            input (Tensor): 2-D tensor of shape (N, C); softmax is taken
                over dim 1.
            target (Tensor): 1-D tensor of dtype ``torch.long`` with N
                entries.
            gamma (float): Focusing parameter, forwarded to the kernel.
            alpha (float): Balancing parameter, forwarded to the kernel.
            weight (Tensor, optional): 1-D tensor with C entries. When
                omitted an empty tensor is passed instead.
            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            Tensor: For ``'none'`` a per-sample tensor of shape (N,);
            otherwise a scalar (``'mean'`` divides the total by N).

        Raises:
            AssertionError: If any of the shape/dtype/reduction checks fail.
        """
        # Check the dtype rather than the legacy ``torch.LongTensor`` /
        # ``torch.cuda.LongTensor`` classes, which only match CPU and CUDA
        # tensors and would reject ``long`` targets on any other device.
        assert target.dtype == torch.long
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            # The kernel always receives a weight tensor; an empty one is
            # passed when the caller supplies none.
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict
        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Numerically stable row-wise softmax (replaces the hand-rolled
        # subtract-max / exp / normalise sequence).
        input_softmax = torch.softmax(input, dim=1)
        # One loss value per row, written by the kernel.
        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()

        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient for ``input``; every other argument gets None."""
        input_softmax, target, weight = ctx.saved_tensors
        # Per-row scratch tensor handed to the kernel; not read afterwards.
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())
        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['none']:
            # The forward returned shape (N,), so grad_output is (N,) while
            # grad_input is (N, C). Without the extra axis, broadcasting would
            # align N with the class dimension: it errors when N != C and
            # silently scales the wrong entries when N == C.
            grad_output = grad_output.reshape(-1, 1)
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


softmax_focal_loss = SoftmaxFocalLossFunction.apply
class SoftmaxFocalLoss(nn.Module):
    """``nn.Module`` front-end for :func:`softmax_focal_loss`.

    Args:
        gamma (float): Focusing parameter passed to the loss.
        alpha (float): Balancing parameter passed to the loss.
        weight (Tensor, optional): Optional weight tensor, registered as a
            buffer (not a trainable parameter) and passed to the loss.
        reduction (str): Reduction mode passed to the loss.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        """Evaluate the loss with the hyper-parameters stored on the module."""
        loss_args = (self.gamma, self.alpha, self.weight, self.reduction)
        return softmax_focal_loss(input, target, *loss_args)

    def __repr__(self):
        return (f'{type(self).__name__}(gamma={self.gamma}, '
                f'alpha={self.alpha}, reduction={self.reduction})')