Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import typing as tp | |
import flashy | |
import torch | |
from torch import autograd | |
class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training:
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the losses keys
            from the backward method to match the weights keys to assign weight to each of the provided loss.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
            A falsy value (e.g. 0) falls back to 1.
        ema_decay (float): EMA decay for averaging the norms. A falsy value (e.g. 0) falls back to 1.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores the relative ratio between the norm of the gradients
            coming from each loss when calling `backward()`; the stored dictionary is the one
            returned by `metrics`.
    """

    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    def metrics(self):
        """Return the metrics dictionary filled by the most recent `backward()` call.

        The dictionary is empty unless `monitor` is True, in which case it holds one
        `ratio_<loss name>` entry per averaged gradient norm.
        """
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the one that needs to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.

        Returns:
            torch.Tensor: the detached sum of the losses, each multiplied by the scale that was
                applied to its gradient.

        Raises:
            ValueError: if `losses` is empty, or if the weights of the balanced losses
                do not sum to a strictly positive value.
            KeyError: if a loss name has no entry in `self.weights`.
        """
        # Validate before doing any autograd work so that a misconfiguration fails fast
        # with an explicit message instead of a late NameError / bare KeyError.
        if not losses:
            raise ValueError("Balancer.backward() requires at least one loss.")
        unknown = [name for name in losses if name not in self.weights]
        if unknown:
            raise KeyError(f"No weight configured for losses: {unknown}")

        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the loss with respect to the input.
            # The graph is retained because it is traversed once per loss.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        # Each gradient has the same shape as `input`, so the batch size is read from
        # `input` rather than from whichever gradient happened to be computed last.
        count = len(input) if self.per_batch_item else 1

        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        total_weights = sum(self.weights[k] for k in avg_norms)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not total_weights > 0.:
            raise ValueError(f"Sum of the loss weights must be strictly positive, got {total_weights}.")
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()

        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss