# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage
from torch import Tensor

from mmdet.registry import MODELS


@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLOX.

    The effective momentum applied at each update is annealed from a value
    close to 1 (so the EMA tracks the source model closely at the start of
    training) down toward ``momentum`` as ``steps`` grows.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            EMA parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param +
            momentum * source_param`. Defaults to 0.0002.
        gamma (int): Use a larger momentum early in training and gradually
            anneal it to a smaller value to update the ema model smoothly.
            The momentum is calculated as
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        # An explicit raise (rather than a bare ``assert``) keeps this check
        # active under ``python -O``. gamma == 0 would divide by zero in
        # ``avg_func`` and gamma < 0 would make the momentum explode. The
        # exception type and message are kept identical to the original
        # assert for backward compatibility.
        if not gamma > 0:
            raise AssertionError(
                f'gamma must be greater than 0, but got {gamma}')
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the
        exponential momentum strategy.

        ``averaged_param`` is updated in place as
        ``averaged_param * (1 - m) + source_param * m``, where
        ``m = (1 - self.momentum) * exp(-(1 + steps) / self.gamma)
        + self.momentum``.

        Args:
            averaged_param (Tensor): The averaged parameters, modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Annealed momentum: close to 1 for small ``steps``, decaying toward
        # ``self.momentum`` as ``steps`` grows.
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)