Spaces:
Runtime error
Runtime error
import torch | |
import copy | |
import tops | |
from tops import logger | |
from .torch_utils import set_requires_grad | |
class EMA: | |
""" | |
Expoenential moving average. | |
See: | |
Yazici, Y. et al.The unusual effectiveness of averaging in GAN training. ICLR 2019 | |
""" | |
def __init__( | |
self, | |
generator: torch.nn.Module, | |
batch_size: int, | |
rampup: float, | |
): | |
self.rampup = rampup | |
self._nimg_half_time = batch_size * 10 / 32 * 1000 | |
self._batch_size = batch_size | |
with torch.no_grad(): | |
self.generator = copy.deepcopy(generator.cpu()).eval() | |
self.generator = tops.to_cuda(self.generator) | |
self.old_ra_beta = 0 | |
set_requires_grad(self.generator, False) | |
def update_beta(self): | |
y = self._nimg_half_time | |
global_step = logger.global_step() | |
if self.rampup != None: | |
y = min(y, global_step*self.rampup) | |
self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8)) | |
if self.ra_beta != self.old_ra_beta: | |
logger.add_scalar("stats/EMA_beta", self.ra_beta) | |
self.old_ra_beta = self.ra_beta | |
def update(self, normal_G): | |
with torch.autograd.profiler.record_function("EMA_update"): | |
for ema_p, p in zip(self.generator.parameters(), | |
normal_G.parameters()): | |
ema_p.copy_(p.lerp(ema_p, self.ra_beta)) | |
for ema_buf, buff in zip(self.generator.buffers(), | |
normal_G.buffers()): | |
ema_buf.copy_(buff) | |
def __call__(self, *args, **kwargs): | |
return self.generator(*args, **kwargs) | |
def __getattr__(self, name: str): | |
if hasattr(self.generator, name): | |
return getattr(self.generator, name) | |
raise AttributeError(f"Generator object has no attribute {name}") | |
def cuda(self, *args, **kwargs): | |
self.generator = self.generator.cuda() | |
return self | |
def state_dict(self, *args, **kwargs): | |
return self.generator.state_dict(*args, **kwargs) | |
def load_state_dict(self, *args, **kwargs): | |
return self.generator.load_state_dict(*args, **kwargs) | |
def eval(self): | |
self.generator.eval() | |
def train(self): | |
self.generator.train() | |
def module(self): | |
return self.generator.module | |
def sample(self, *args, **kwargs): | |
return self.generator.sample(*args, **kwargs) | |