Pinwheel's picture
HF Demo
128757a
raw
history blame
1.69 kB
from copy import deepcopy
from collections import OrderedDict
import torch
class ModelEma:
    """Maintain an exponential moving average (EMA) copy of a model's state.

    A deep copy of ``model`` is stored in ``self.ema``, switched to eval mode
    and has gradients disabled on all of its parameters.  Each call to
    :meth:`update` blends the current state of a model into that copy:

        ema = decay * ema + (1 - decay) * current

    Args:
        model: the module to copy; it must support ``deepcopy``, ``eval()``,
            ``to()``, ``parameters()`` and ``state_dict()``.
        decay: weight given to the previous EMA value on every update.
        device: if truthy, the EMA copy is moved to this device and incoming
            values are moved there before being blended in.
    """

    # Key prefix that appears in the state dict of a model exposing ``.module``.
    _DP_PREFIX = 'module.'

    def __init__(self, model, decay=0.9999, device=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # A ``module`` attribute is taken as the sign of a wrapped model whose
        # state-dict keys carry the 'module.' prefix.
        self.ema_is_dp = hasattr(self.ema, 'module')
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def load_checkpoint(self, checkpoint):
        """Load EMA weights from the ``'model_ema'`` entry of a checkpoint.

        Args:
            checkpoint: either a dict, or a ``str`` path that is read with
                ``torch.load``.

        Keys are renamed so that their leading ``'module.'`` prefix matches
        the wrapped/unwrapped state of ``self.ema``.  If the checkpoint has no
        ``'model_ema'`` entry nothing is loaded.

        Raises:
            AssertionError: if the checkpoint is not a dict.
        """
        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles data; only load trusted files.
            # map_location='cpu' lets a checkpoint saved on a GPU be read on a
            # host without one; load_state_dict copies the values onto
            # whatever device the EMA tensors already live on.
            checkpoint = torch.load(checkpoint, map_location='cpu')
        if not isinstance(checkpoint, dict):
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(
                'checkpoint must be a dict, got ' + type(checkpoint).__name__)
        if 'model_ema' not in checkpoint:
            return

        prefix = self._DP_PREFIX
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            has_prefix = k.startswith(prefix)
            if self.ema_is_dp:
                name = k if has_prefix else prefix + k
            else:
                # Strip only the leading prefix; a later 'module.' substring
                # (e.g. in 'se_module.weight') is part of the real name.
                name = k[len(prefix):] if has_prefix else k
            new_state_dict[name] = v
        self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the state dict of the EMA copy."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy, in place.

        Floating-point (and complex) entries are updated as
        ``decay * ema + (1 - decay) * current``.  Other entries (integer or
        bool buffers) are copied directly, because writing a blended float
        back into an integer tensor would truncate it.

        Raises:
            KeyError: if an EMA entry has no counterpart in ``model``'s state.
        """
        prefix = self._DP_PREFIX
        model_is_dp = hasattr(model, 'module')
        add_prefix = model_is_dp and not self.ema_is_dp
        strip_prefix = self.ema_is_dp and not model_is_dp
        with torch.no_grad():
            curr_msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                # Translate the EMA key into the naming used by ``model``.
                if add_prefix:
                    src_key = prefix + k
                elif strip_prefix and k.startswith(prefix):
                    src_key = k[len(prefix):]
                else:
                    src_key = k
                model_v = curr_msd[src_key].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(
                        ema_v * self.decay + (1. - self.decay) * model_v)
                else:
                    ema_v.copy_(model_v)