Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
class CoordLoss(nn.Module):
    """Element-wise, masked L1 loss between predicted and target coordinates.

    No reduction is applied; callers receive the per-element loss tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, coord_out, coord_gt, valid, is_3D=None):
        """Return ``|coord_out - coord_gt| * valid`` without reduction.

        Args:
            coord_out: Predicted coordinates. The ``is_3D`` branch indexes
                the loss as a rank-3 tensor; presumably
                (batch, num_points, dims) — TODO confirm against callers.
            coord_gt: Target coordinates; must broadcast with ``coord_out``.
            valid: Mask/weight multiplied into the absolute error; must
                broadcast with the error tensor.
            is_3D: Optional per-sample flag, indexed as a 1-D tensor with one
                entry per leading-dimension item. When given, channels from
                index 2 onward are multiplied by the flag, so they are
                zeroed for samples whose flag is 0/False.

        Returns:
            The unreduced loss, with the dtype of
            ``torch.abs(coord_out - coord_gt) * valid``.
        """
        loss = torch.abs(coord_out - coord_gt) * valid
        if is_3D is not None:
            # Cast the flag to the loss dtype rather than always to float32,
            # so half/bfloat16 losses are not silently upcast and torch.cat
            # below always concatenates tensors of one dtype.
            flag = is_3D[:, None, None].to(loss.dtype)
            loss_z = loss[:, :, 2:] * flag
            loss = torch.cat((loss[:, :, :2], loss_z), 2)
        return loss
class ParamLoss(nn.Module):
    """Element-wise, masked L1 loss between predicted and target parameters.

    The result is not reduced; callers receive one loss value per element.
    """

    def __init__(self):
        super().__init__()

    def forward(self, param_out, param_gt, valid):
        """Return ``|param_out - param_gt| * valid`` without reduction.

        Args:
            param_out: Predicted parameter tensor.
            param_gt: Target parameter tensor; must broadcast with
                ``param_out``.
            valid: Mask/weight multiplied into the absolute error; must
                broadcast with it.
        """
        abs_err = (param_out - param_gt).abs()
        return abs_err * valid
class CELoss(nn.Module):
    """Unreduced cross-entropy loss.

    Thin wrapper around ``nn.CrossEntropyLoss(reduction='none')``, so one
    loss value is returned per target entry rather than a scalar.
    """

    def __init__(self):
        super().__init__()
        # reduction='none' keeps per-sample losses; any reduction is left to
        # the caller.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, out, gt_index):
        """Return per-sample cross-entropy of logits ``out`` vs ``gt_index``.

        Args:
            out: Unnormalized class scores, as expected by
                ``nn.CrossEntropyLoss``.
            gt_index: Target class indices, as expected by
                ``nn.CrossEntropyLoss``.
        """
        return self.ce_loss(out, gt_index)