rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
2.77 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional
import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage
from torch import Tensor
from mmdet.registry import MODELS
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLOX.

    The momentum actually applied at each update is not constant. It starts
    close to 1 (the averaged parameters track the source almost exactly) and
    decays exponentially towards ``momentum`` as the step count grows:

        ``(1 - momentum) * exp(-(1 + steps) / gamma) + momentum``

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameter are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        gamma (int): Use a larger momentum early in training and gradually
            annealing to a smaller value to update the ema model smoothly.
            The momentum is calculated as
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Validate first so an invalid gamma is rejected before the base
        # class constructor does any work. gamma is used as a divisor in
        # ``avg_func``, so it must be strictly positive.
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the exponential
        momentum strategy.

        ``averaged_param`` is modified in place; nothing is returned.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Warm-up term: equals exp(-1 / gamma) at steps == 0 and tends to 0
        # as steps grows, so the effective momentum anneals from ~1 down to
        # ``self.momentum``.
        # NOTE(review): float() also lets this work if the caller passes a
        # 0-dim tensor step counter rather than a Python int - confirm
        # against the base class.
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        # In-place blend:
        # averaged = (1 - momentum) * averaged + momentum * source.
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)