Spaces:
Running
on
A10G
Running
on
A10G
File size: 6,612 Bytes
5325fcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import flashy
import torch
from torch import autograd
class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training:
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the losses keys
            from the backward method to match the weights keys to assign weight to each of the provided loss.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
            A falsy value (e.g. 0) falls back to 1.
        ema_decay (float): EMA decay for averaging the norms. A falsy value (e.g. 0) falls back to 1.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """

    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        # `or 1.` replaces a falsy value (0, None) with the default of 1.
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        # Reset on every call to `backward()`; only populated when `monitor` is True.
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        """Metrics stored by the last call to `backward()` (empty unless `monitor` is True)."""
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the one that needs to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        The combined gradient is pushed through `input.backward(...)`, so calling this method
        accumulates gradients into whatever `input` depends on.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.

        Returns:
            torch.Tensor: scalar tensor (same device and dtype as `input`) equal to the sum of the
                detached losses, each multiplied by the scale applied to its gradient.

        Raises:
            ValueError: if `losses` is empty, as there is then nothing to back-propagate.
        """
        if not losses:
            # Fail early with a clear message rather than with an unbound variable further down.
            raise ValueError("Balancer.backward() requires at least one loss, got an empty dict.")

        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the loss with respect to the input.
            # The graph is retained as it is needed again for the other losses and the final backward.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension:
                # one norm per batch item, then the mean of those norms.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        # Each gradient has the same shape as `input`, so the batch size is read from `input`
        # directly instead of from whichever gradient was computed last.
        count = len(input) if self.per_batch_item else 1
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        # Only the weights of the losses actually being balanced take part in the normalization.
        total_weights = sum(self.weights[k] for k in avg_norms)
        assert total_weights > 0., "The weights of the provided losses must sum to a positive value."
        desired_ratios = {k: self.weights[k] / total_weights for k in avg_norms}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()

        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss
|