Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import functools | |
from typing import Callable, Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    The reduction name is resolved through PyTorch's own (private)
    ``F._Reduction.get_enum`` helper, so the accepted names are the same
    ones PyTorch's built-in losses accept; unknown names are rejected by
    that helper.

    Args:
        loss (Tensor): Element-wise loss tensor.
        reduction (str): One of ``"none"``, ``"mean"`` or ``"sum"``.

    Returns:
        Tensor: ``loss`` itself for ``"none"``, its mean for ``"mean"``,
        or its sum for ``"sum"``.
    """
    # PyTorch's enum codes: 0 -> none, 1 -> (elementwise_)mean, 2 -> sum.
    mode = F._Reduction.get_enum(reduction)
    if mode == 2:
        return loss.sum()
    if mode == 1:
        return loss.mean()
    if mode == 0:
        return loss
def weight_reduce_loss(loss: Tensor,
                       weight: Optional[Tensor] = None,
                       reduction: str = 'mean',
                       avg_factor: Optional[float] = None) -> Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Optional[Tensor], optional): Element-wise weights, multiplied
            into ``loss`` before any reduction. Defaults to None.
        reduction (str, optional): Same as built-in losses of PyTorch.
            Defaults to 'mean'.
        avg_factor (Optional[float], optional): Average factor when
            computing the mean of losses. When given with ``'mean'``, the
            result is ``loss.sum() / avg_factor`` (plus a tiny epsilon)
            instead of ``loss.mean()``; with ``'none'`` the element-wise
            loss is returned unchanged. Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than ``'mean'`` or ``'none'``.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # Add eps so an avg_factor of 0.0 (e.g. every label of an image
        # belongs to the ignore index) does not divide by zero.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)

    if reduction != 'none':
        # Echo the reduction actually passed, so invalid values other than
        # 'sum' are reported accurately.
        raise ValueError(
            f'avg_factor can not be used with reduction="{reduction}"')

    # reduction == 'none': keep the element-wise (weighted) loss as is.
    return loss
def weighted_loss(loss_func: Callable) -> Callable:
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    ``loss_func(pred, target, **kwargs)``. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``.

    The returned wrapper is built with ``functools.wraps``, so it keeps the
    wrapped function's ``__name__``, ``__doc__`` and ``__wrapped__``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()
    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])
    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                reduction: str = 'mean',
                avg_factor: Optional[float] = None,
                **kwargs) -> Tensor:
        """
        Args:
            pred (Tensor): The prediction.
            target (Tensor): The learning target of the prediction.
            weight (Optional[Tensor], optional): The weight of loss for each
                prediction. Defaults to None.
            reduction (str, optional): Options are "none", "mean" and "sum".
                Defaults to 'mean'.
            avg_factor (Optional[float], optional): Average factor that is
                used to average the loss. Defaults to None.
        Returns:
            Tensor: Loss tensor.
        """
        # get element-wise loss; extra kwargs go to the wrapped function
        loss = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(loss, weight, reduction, avg_factor)

    return wrapper