Spaces:
Runtime error
Runtime error
File size: 891 Bytes
2de1f98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import torch.nn as nn
class CoordLoss(nn.Module):
    """Unreduced, masked L1 loss on coordinates.

    Computes ``|coord_out - coord_gt| * valid`` element-wise and returns it
    without any reduction. When ``is_3D`` is supplied, every channel from
    index 2 onward along dim 2 (presumably the z / depth coordinate -- TODO
    confirm) is further multiplied by the per-sample ``is_3D`` flag. Samples
    whose flag is 0 therefore contribute no loss on those channels.
    """

    def forward(self, coord_out, coord_gt, valid, is_3D=None):
        """Return the element-wise masked L1 coordinate loss.

        Args:
            coord_out: predicted coordinates. The slicing below needs at
                least 3 dims when ``is_3D`` is given (presumably
                (batch, num_points, coord_dim) -- TODO confirm).
            coord_gt: target coordinates, broadcast-compatible with
                ``coord_out``.
            valid: mask/weight multiplied into the absolute error.
            is_3D: optional per-sample flag indexed along dim 0 (presumably
                1-D). It is cast to float and broadcast over dims 1 and 2.

        Returns:
            The unreduced loss tensor.
        """
        loss = (coord_out - coord_gt).abs() * valid
        if is_3D is None:
            return loss
        # Expand the per-sample flag so it broadcasts over points and channels.
        z_weight = is_3D[:, None, None].float()
        xy_part = loss[:, :, :2]
        z_part = loss[:, :, 2:] * z_weight
        return torch.cat((xy_part, z_part), 2)
class ParamLoss(nn.Module):
    """Unreduced, masked L1 loss on parameter tensors."""

    def forward(self, param_out, param_gt, valid):
        """Return ``|param_out - param_gt| * valid`` element-wise.

        Args:
            param_out: predicted parameters.
            param_gt: target parameters, broadcast-compatible with
                ``param_out``.
            valid: mask/weight multiplied into the absolute error.

        Returns:
            The unreduced loss tensor. No sum or mean is applied here.
        """
        abs_err = (param_out - param_gt).abs()
        return abs_err * valid
class CELoss(nn.Module):
    """Cross-entropy loss that returns per-element losses (no reduction)."""

    def __init__(self):
        super().__init__()
        # reduction='none' keeps one loss value per element; no reduction is
        # applied inside this module.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, out, gt_index):
        """Return unreduced cross-entropy between logits ``out`` and class
        indices ``gt_index``, as computed by ``nn.CrossEntropyLoss``."""
        return self.ce_loss(out, gt_index)
|