# Spaces:
# Runtime error
# Runtime error
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
'''
Code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py
'''
import warnings
from collections import defaultdict

import torch
from torch.optim import Optimizer
class EMA(Optimizer):
    """Optimizer wrapper that keeps an exponential moving average (EMA) of parameters.

    The wrapped optimizer performs the actual update. After each ``step``, the EMA
    of every parameter that received a gradient is updated as
    ``ema = ema_decay * ema + (1 - ema_decay) * param`` and stored in the wrapped
    optimizer's per-parameter state under the key ``'ema'``.

    ``Optimizer.__init__`` is deliberately not called: ``state`` and
    ``param_groups`` alias the wrapped optimizer's containers. Base-class methods
    that would depend on attributes set up by ``Optimizer.__init__`` (or on real
    ``defaults``) are therefore delegated to the wrapped optimizer.

    Args:
        opt: The optimizer to wrap.
        ema_decay: EMA decay rate. EMA tracking is disabled when it is ``<= 0``.
        memory_efficient: If True, update each EMA tensor in place. Otherwise,
            same-shaped tensors are stacked and updated in one batched operation,
            which allocates a stacked copy per shape.
    """

    def __init__(self, opt, ema_decay, memory_efficient=False):
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        # Alias (not copy) the wrapped optimizer's containers so both views agree.
        self.state = opt.state
        self.param_groups = opt.param_groups
        self.defaults = {}
        self.memory_efficient = memory_efficient

    def step(self, *args, **kwargs):
        """Run the wrapped optimizer's step, then update the parameter EMAs.

        Arguments (e.g. a ``closure``) are forwarded unchanged to the wrapped
        optimizer's ``step``, and its return value is returned.
        """
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        # Bucket parameters from *all* groups by (shape, dtype, device): only
        # tensors agreeing on all three can be stacked for a batched update.
        buckets = defaultdict(list)
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]
                # State initialization: seed the EMA with the just-updated value.
                if 'ema' not in state:
                    state['ema'] = p.data.clone()
                buckets[(p.shape, p.dtype, p.device)].append(p)

        decay = self.ema_decay
        for bucket in buckets.values():
            emas = [self.optimizer.state[p]['ema'] for p in bucket]
            if self.memory_efficient:
                # In-place update of the tensors already held in the state:
                # no stacked copy and no reassignment needed.
                for ema, p in zip(emas, bucket):
                    ema.mul_(decay).add_(p.data, alpha=1. - decay)
            else:
                stacked = torch.stack(emas, dim=0)
                current = torch.stack([p.data for p in bucket], dim=0)
                stacked.mul_(decay).add_(current, alpha=1. - decay)
                # Store row views of the stacked tensor back into the state.
                # ``stacked[idx]`` (not ``stacked[idx, :]``) also works for
                # 0-dim (scalar) parameters.
                for idx, p in enumerate(bucket):
                    self.optimizer.state[p]['ema'] = stacked[idx]

        return retval

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``.

        The base-class implementation may rely on attributes that this wrapper
        never initializes (it skips ``Optimizer.__init__``), so delegate instead.
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def add_param_group(self, param_group):
        """Add a parameter group to the wrapped optimizer.

        The new group becomes visible here through the shared ``param_groups``
        list. Delegating ensures the wrapped optimizer's defaults (learning rate,
        etc.) are filled in; this wrapper's own ``defaults`` are empty.
        """
        self.optimizer.add_param_group(param_group)

    def state_dict(self):
        """Return the wrapped optimizer's state dict, including the ``'ema'`` entries."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer and refresh the aliases.

        Loading may replace the wrapped optimizer's ``state`` and ``param_groups``
        objects, so this wrapper re-points its own references afterwards.
        """
        self.optimizer.load_state_dict(state_dict)
        self.state = self.optimizer.state
        self.param_groups = self.optimizer.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """Replace each parameter's data with its EMA value.

        If ``store_params_in_ema`` is True, the current parameter values become the
        new ``'ema'`` entries, so a second call with True swaps them back. If it is
        False, the current values are discarded.

        Parameters that do not require grad, or that have no EMA yet because they
        never received a gradient in ``step``, are left unchanged.
        """
        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there is no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                # ``.get`` avoids inserting empty entries if ``state`` is a defaultdict.
                state = self.optimizer.state.get(p)
                if state is None or 'ema' not in state:
                    continue
                ema = state['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    state['ema'] = tmp
                else:
                    p.data = ema.detach()