bill-jiang's picture
Init
4409449
raw
history blame
No virus
2.27 kB
import torch
import torch.nn as nn
class BaseLosses(nn.Module):
    '''Accumulate named losses across steps and report their running averages.

    Every name in ``losses`` plus an implicit ``"total"`` entry gets a scalar
    buffer. ``_update_loss`` adds into the per-loss buffers; ``count`` is a
    separate buffer used as the divisor in ``compute``. This class never
    increments ``count`` or ``total`` itself -- presumably subclasses do that
    in their own update step (TODO confirm against subclasses).

    Args:
        cfg: Not used by this class; accepted for constructor compatibility.
        losses: Names of the individual losses. Every name other than
            ``"total"`` must be a key of ``losses_func``. The caller's list is
            not modified.
        params: Mapping from loss name to the weight applied by
            ``_update_loss``.
        losses_func: Mapping from loss name to a loss class, instantiated as
            ``cls(reduction='mean')``.
        num_joints: Stored as ``self.num_joints``.
        **kwargs: Ignored.
    '''

    def __init__(self, cfg, losses, params, losses_func, num_joints, **kwargs):
        super().__init__()
        # Save parameters
        self.num_joints = num_joints
        self._params = params

        # Build a fresh list (the caller's list is left untouched) in which
        # "total" appears exactly once and always last, no matter where or
        # whether the caller listed it.
        losses = [loss for loss in losses if loss != "total"] + ["total"]

        # One scalar accumulator buffer per loss, plus the averaging divisor.
        for loss in losses:
            self.register_buffer(loss, torch.tensor(0.0))
        self.register_buffer("count", torch.tensor(0.0))
        self.losses = losses

        # "total" is an aggregate with no criterion of its own.
        self._losses_func = {
            loss: losses_func[loss](reduction='mean')
            for loss in losses
            if loss != "total"
        }

    def _update_loss(self, loss: str, outputs, inputs):
        '''Accumulate one loss value and return it scaled by its weight.

        Args:
            loss: Loss name; selects both the criterion and the weight.
            outputs: First argument passed to the criterion.
            inputs: Second argument passed to the criterion.

        Returns:
            ``params[loss] * criterion(outputs, inputs)``. The result is not
            detached, so it can be back-propagated.
        '''
        val = self._losses_func[loss](outputs, inputs)
        # Accumulate a detached copy so the running sum does not hold on to
        # the autograd graph across steps.
        getattr(self, loss).add_(val.detach())
        return self._params[loss] * val

    def reset(self):
        '''Zero every loss accumulator and the ``count`` buffer in place.'''
        # zero_() keeps each buffer's device and dtype, and keeps it a normal
        # tensor even when called under torch.inference_mode(). Replacing the
        # buffer with a new tensor there would create an inference tensor
        # that rejects the in-place add_() in later training steps.
        for loss in self.losses:
            getattr(self, loss).zero_()
        self.count.zero_()

    def compute(self, split):
        '''Average each accumulator by ``count``, reset, and return log values.

        Args:
            split: Suffix used in the log keys (see ``loss2logname``).

        Returns:
            Dict mapping log name to the averaged value as a Python float.
            Entries whose average is NaN (e.g. 0/0 when nothing was
            accumulated) are omitted.
        '''
        count = self.count
        log_dict = {}
        for loss in self.losses:
            value = getattr(self, loss) / count
            if not torch.isnan(value):
                # .item() copies the value out before reset() zeroes buffers.
                log_dict[self.loss2logname(loss, split)] = value.item()
        self.reset()
        return log_dict

    def loss2logname(self, loss: str, split: str):
        '''Convert a loss name into a hierarchical log name.

        ``"total"`` (and any name without an underscore) maps to
        ``"{loss}/{split}"``. Other names are split at their FIRST underscore,
        e.g. ``"recons_feature"`` -> ``"recons/feature/{split}"`` and
        ``"recons_feature_v2"`` -> ``"recons/feature_v2/{split}"``.
        '''
        loss_type, sep, name = loss.partition("_")
        if loss == "total" or not sep:
            return f"{loss}/{split}"
        return f"{loss_type}/{name}/{split}"