|
import math |
|
import mmcv |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from ..builder import LOSSES |
|
|
|
|
|
@LOSSES.register_module()
class RLELoss_poseur_old(nn.Module):
    ''' RLE regression loss in which the ``nf_loss`` term is masked by only
    the first channel of the target weight.

    Args:
        OUTPUT_3D: Accepted by the constructor but never read in this class.
        use_target_weight: Accepted by the constructor but never read in
            this class.
        size_average: When true (and the target weights sum to a positive
            value) the summed loss is divided by ``len(loss)``.
    '''

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True):
        super(RLELoss_poseur_old, self).__init__()
        # 1 / sqrt(2 * pi); logQ divides sigma by this constant inside the log.
        self.amp = 1 / math.sqrt(2 * math.pi)
        self.size_average = size_average

    def logQ(self, gt_uv, pred_jts, sigma):
        '''Return ``log(sigma / amp) + |gt_uv - pred_jts| / (sqrt(2) * sigma + 1e-9)``.'''
        log_term = torch.log(sigma / self.amp)
        abs_error = torch.abs(gt_uv - pred_jts)
        # The 1e-9 only guards the division; it is not applied inside the log.
        scaled_error = abs_error / (math.sqrt(2) * sigma + 1e-9)
        return log_term + scaled_error

    def forward(self, output, target_uv, target_uv_weight):
        '''Compute the loss from a model output and its targets.

        Args:
            output: Object exposing ``pred_jts``, ``sigma`` and ``nf_loss``
                attributes.
            target_uv: Targets; reshaped to ``output.pred_jts.shape``.
            target_uv_weight: Per-element weights; reshaped to
                ``output.pred_jts.shape``.

        Returns:
            The summed loss, divided by ``len(loss)`` when ``size_average``
            is set and the weights sum to a positive value.
        '''
        pred_jts = output.pred_jts
        target_shape = pred_jts.shape
        gt_uv = target_uv.reshape(target_shape)
        gt_uv_weight = target_uv_weight.reshape(target_shape)

        # Only the first weight channel masks the nf_loss term here; the
        # slice keeps the trailing dimension so it broadcasts against nf_loss.
        nf_term = output.nf_loss * gt_uv_weight[:, :, :1]
        q_term = self.logQ(gt_uv, pred_jts, output.sigma) * gt_uv_weight
        loss = nf_term + q_term

        normalize = self.size_average and gt_uv_weight.sum() > 0
        total = loss.sum()
        if normalize:
            # len(loss) is the size of the first dimension
            # (presumably the batch size -- confirm against the caller).
            return total / len(loss)
        return total
|
|
|
@LOSSES.register_module()
class RLELoss_poseur(nn.Module):
    ''' RLE regression loss: a weighted sum of the ``nf_loss`` carried by the
    model output and the ``logQ`` term computed from prediction and target.

    Args:
        OUTPUT_3D: Accepted by the constructor but never read in this class.
        use_target_weight: Accepted by the constructor but never read in
            this class.
        size_average: When true (and the target weights sum to a positive
            value) the summed loss is divided by ``len(loss)``.
    '''

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True):
        super(RLELoss_poseur, self).__init__()
        # Constant 1 / sqrt(2 * pi) used inside the log of logQ.
        self.amp = 1 / math.sqrt(2 * math.pi)
        self.size_average = size_average

    def logQ(self, gt_uv, pred_jts, sigma):
        '''Return ``log(sigma / amp) + |gt_uv - pred_jts| / (sqrt(2) * sigma + 1e-9)``.'''
        deviation = torch.abs(gt_uv - pred_jts)
        # The epsilon protects the division only; the log sees sigma as-is.
        denominator = math.sqrt(2) * sigma + 1e-9
        return torch.log(sigma / self.amp) + deviation / denominator

    def forward(self, output, target_uvd, target_uvd_weight):
        '''Compute the loss from a model output and its targets.

        Args:
            output: Object exposing ``pred_jts``, ``sigma`` and ``nf_loss``
                attributes.
            target_uvd: Targets; reshaped to ``output.pred_jts.shape``.
            target_uvd_weight: Per-element weights; reshaped to
                ``output.pred_jts.shape``.

        Returns:
            The summed loss, divided by ``len(loss)`` when ``size_average``
            is set and the weights sum to a positive value.
        '''
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = target_uvd.reshape(pred_jts.shape)
        gt_uv_weight = target_uvd_weight.reshape(pred_jts.shape)

        # Both terms are masked by the full weight tensor before being added.
        weighted_nf = output.nf_loss * gt_uv_weight
        weighted_q = self.logQ(gt_uv, pred_jts, sigma) * gt_uv_weight
        loss = weighted_nf + weighted_q

        total = loss.sum()
        if self.size_average and gt_uv_weight.sum() > 0:
            # len(loss) is the size of the first dimension
            # (presumably the batch size -- confirm against the caller).
            total = total / len(loss)
        return total
|
|
|
@LOSSES.register_module()
class RLEOHKMLoss(nn.Module):
    ''' RLE regression loss with an optional hard-example (top-k) term.

    ``forward`` returns ``ori_weight * ori + ohkm_weight * ohkm`` where
    ``ori`` is the weighted loss over every entry and ``ohkm`` is the
    weighted loss over the ``topk`` largest entries along dim 1.

    Args:
        OUTPUT_3D: Accepted by the constructor but never read in this class.
        use_target_weight: Accepted by the constructor but never read in
            this class.
        size_average: Stored on the instance; ``forward`` does not read it.
        topk: Number of largest-loss entries kept along dim 1 by ``ohkm``.
            Clamped to the size of dim 1 when that dimension is smaller.
        ori_weight: Multiplier of the all-entries term.
        ohkm_weight: Multiplier of the top-k term. When it is 0 the top-k
            term is not computed at all.
    '''

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True, topk=8,
                 ori_weight=1.0, ohkm_weight=0.0):
        super(RLEOHKMLoss, self).__init__()
        self.size_average = size_average
        # Constant 1 / sqrt(2 * pi) used inside the log of logQ.
        self.amp = 1 / math.sqrt(2 * math.pi)
        self.topk = topk
        self.ori_weight = ori_weight
        self.ohkm_weight = ohkm_weight
        # Sentinel written over zero-weight entries so top-k ranks them last.
        self.neg_inf = -float("Inf")

    def logQ(self, gt_uv, pred_jts, sigma):
        '''Return ``log(sigma / amp) + |gt_uv - pred_jts| / (sqrt(2) * sigma + 1e-9)``.'''
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def ohkm(self, loss, weight):
        '''Weighted loss restricted to the ``topk`` largest entries along dim 1.

        Args:
            loss: Per-element loss tensor.
            weight: Tensor of the same shape; entries equal to 0 are ranked
                last when choosing the top-k.

        Returns:
            Mean over dim 0 of the per-sample sums of the selected,
            weight-multiplied losses.
        '''
        # Rank on a detached copy so the -inf sentinel never reaches the
        # graph; the values that are summed are gathered from `loss` itself.
        ranking = loss.clone().detach()
        ranking[weight == 0] = self.neg_inf

        # torch.topk raises when k exceeds the size of the selected dim, so
        # fall back to "all entries" when there are fewer than `topk`.
        k = min(self.topk, ranking.size(1))
        _, topk_idx = torch.topk(ranking, k=k, dim=1, sorted=False)

        hard_loss = torch.gather(loss, 1, topk_idx)
        hard_weight = torch.gather(weight, 1, topk_idx)

        # Multiplying by the gathered weight zeroes any zero-weight entry
        # that was still selected because too few weighted entries existed.
        hard_loss = hard_loss * hard_weight
        return hard_loss.flatten(start_dim=1).sum(dim=1).mean()

    def ori(self, loss, weight):
        '''Mean over dim 0 of the per-sample sums of ``loss * weight``.'''
        weighted = loss * weight
        return weighted.flatten(start_dim=1).sum(dim=1).mean()

    def forward(self, output, target_uv, target_uv_weight):
        '''Compute the combined loss from a model output and its targets.

        Args:
            output: Object exposing ``pred_jts``, ``sigma`` and ``nf_loss``
                attributes.
            target_uv: Targets; reshaped to ``output.pred_jts.shape``.
            target_uv_weight: Per-element weights; reshaped to
                ``output.pred_jts.shape``.

        Returns:
            ``ori_weight * ori + ohkm_weight * ohkm`` (the second term is
            omitted when ``ohkm_weight`` is 0).
        '''
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = target_uv.reshape(pred_jts.shape)
        gt_uv_weight = target_uv_weight.reshape(pred_jts.shape)

        per_element = output.nf_loss + self.logQ(gt_uv, pred_jts, sigma)

        loss = self.ori_weight * self.ori(per_element, gt_uv_weight)
        # With the default ohkm_weight of 0 the top-k term contributes
        # nothing; skipping it avoids needless work and keeps a disabled
        # term from raising or injecting NaN (0 * inf) into the result.
        if self.ohkm_weight != 0:
            loss = loss + self.ohkm_weight * self.ohkm(per_element, gt_uv_weight)
        return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@LOSSES.register_module()
class RLELoss3D(nn.Module):
    ''' RLE regression loss (3D) whose targets are read from a ``labels``
    mapping.

    Args:
        OUTPUT_3D: Accepted by the constructor but never read in this class.
        size_average: When true (and the target weights sum to a positive
            value) the summed loss is divided by ``len(loss)``.
    '''

    def __init__(self, OUTPUT_3D=False, size_average=True):
        super(RLELoss3D, self).__init__()
        # Constant 1 / sqrt(2 * pi) used inside the log of logQ.
        self.amp = 1 / math.sqrt(2 * math.pi)
        self.size_average = size_average

    def logQ(self, gt_uv, pred_jts, sigma):
        '''Return ``log(sigma / amp) + |gt_uv - pred_jts| / (sqrt(2) * sigma + 1e-9)``.'''
        log_part = torch.log(sigma / self.amp)
        # The epsilon protects the division only; the log sees sigma as-is.
        residual_part = torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)
        return log_part + residual_part

    def forward(self, output, labels):
        '''Compute the loss from a model output and a labels mapping.

        Args:
            output: Object exposing ``nf_loss``, ``pred_jts`` and ``sigma``
                attributes.
            labels: Mapping with ``'target_uvd'`` and ``'target_uvd_weight'``
                entries; both are reshaped to ``output.pred_jts.shape``.

        Returns:
            The summed loss, divided by ``len(loss)`` when ``size_average``
            is set and the weights sum to a positive value.
        '''
        pred_jts = output.pred_jts
        sigma = output.sigma
        shape = pred_jts.shape

        gt_uv = labels['target_uvd'].reshape(shape)
        gt_uv_weight = labels['target_uvd_weight'].reshape(shape)

        # Both terms are masked by the same weight tensor before being added.
        nf_term = output.nf_loss * gt_uv_weight
        q_term = self.logQ(gt_uv, pred_jts, sigma) * gt_uv_weight
        loss = nf_term + q_term

        summed = loss.sum()
        if self.size_average and gt_uv_weight.sum() > 0:
            # len(loss) is the size of the first dimension
            # (presumably the batch size -- confirm against the caller).
            return summed / len(loss)
        return summed