# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

''' Code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py '''

import warnings

import torch
from torch.optim import Optimizer


class EMA(Optimizer):
    """Wraps an existing optimizer and maintains an exponential moving average
    (EMA) of the model parameters alongside the regular updates."""

    def __init__(self, opt, ema_decay, memory_efficient=False):
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        # Share state and param_groups with the wrapped optimizer so that
        # state_dict()/load_state_dict() see a single, consistent view.
        self.state = opt.state
        self.param_groups = opt.param_groups
        self.defaults = {}
        self.memory_efficient = memory_efficient

    def step(self, *args, **kwargs):
        # Note: when resuming from a checkpoint saved with an older PyTorch
        # version, newer param-group keys (e.g. 'amsgrad', 'maximize', 'foreach',
        # 'capturable', 'differentiable', 'fused') may need to be filled in with
        # group.setdefault(...) before stepping.
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        ema, params = {}, {}
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]

                # State initialization: start the EMA at the current parameter value.
                if 'ema' not in state:
                    state['ema'] = p.data.clone()

                # Group parameters by shape so they can be updated in batched form.
                if p.shape not in params:
                    params[p.shape] = {'idx': 0, 'data': []}
                    ema[p.shape] = []

                params[p.shape]['data'].append(p.data)
                ema[p.shape].append(state['ema'])

            for shape in params:
                if self.memory_efficient:
                    # Update each EMA tensor in place, one at a time.
                    for j in range(len(params[shape]['data'])):
                        ema[shape][j].mul_(self.ema_decay).add_(params[shape]['data'][j], alpha=1. - self.ema_decay)
                    ema[shape] = torch.stack(ema[shape], dim=0)
                else:
                    # Stack same-shape tensors and update them with a single fused op.
                    params[shape]['data'] = torch.stack(params[shape]['data'], dim=0)
                    ema[shape] = torch.stack(ema[shape], dim=0)
                    ema[shape].mul_(self.ema_decay).add_(params[shape]['data'], alpha=1. - self.ema_decay)

            # Write the updated EMA slices back into the optimizer state.
            for p in group['params']:
                if p.grad is None:
                    continue
                idx = params[p.shape]['idx']
                self.optimizer.state[p]['ema'] = ema[p.shape][idx, :]
                params[p.shape]['idx'] += 1

        return retval

    def load_state_dict(self, state_dict):
        super(EMA, self).load_state_dict(state_dict)
        # load_state_dict loads the data into self.state and self.param_groups.
        # We need to pass this data to the underlying optimizer too.
        self.optimizer.state = self.state
        self.optimizer.param_groups = self.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """Swaps parameters with their EMA values. If store_params_in_ema is
        true, the original parameters are stored in the EMA slots, so calling
        this method a second time restores them."""

        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there are no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                ema = self.optimizer.state[p]['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    self.optimizer.state[p]['ema'] = tmp
                else:
                    p.data = ema.detach()
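

# ---------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It illustrates the
# intended flow: wrap an existing optimizer with EMA, train as usual, then
# temporarily swap in the EMA weights for evaluation. The model, data, and
# hyperparameters below are placeholders chosen for illustration only.
# ---------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(16, 4)                                   # placeholder model
    base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)   # any torch optimizer works
    opt = EMA(base_opt, ema_decay=0.999)                       # ema_decay=0 disables EMA

    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()                          # dummy loss
        base_opt.zero_grad()
        loss.backward()
        opt.step()                                             # regular update + EMA update

    # Swap EMA weights in for evaluation (the raw weights are parked in the
    # EMA slots), then swap back to resume training.
    opt.swap_parameters_with_ema(store_params_in_ema=True)
    with torch.no_grad():
        eval_out = model(torch.randn(8, 16))                   # evaluate with EMA weights
    opt.swap_parameters_with_ema(store_params_in_ema=True)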