# regionspot/util/model_ema.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import math
import itertools
import logging
from typing import Dict, Any
from contextlib import contextmanager

import torch

from detectron2.engine.train_loop import HookBase
from detectron2.checkpoint import DetectionCheckpointer

logger = logging.getLogger(__name__)

class EMADetectionCheckpointer(DetectionCheckpointer):
    def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]:
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint, if one exists. Otherwise, it loads the checkpoint from the
        given path. This is useful when restarting an interrupted training job.

        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists
                and load the model together with all the checkpointables.
                Otherwise load the checkpoint from `path`; unlike the base class,
                any checkpointables found in it (e.g. the EMA state) are loaded
                as well.

        Returns:
            same as :meth:`load`.
        """
        if resume and self.has_checkpoint():
            path = self.get_checkpoint_file()
            return self.load(path)
        else:
            # Unlike the base class, which passes `checkpointables=[]` here to load
            # only the model, pass None so that every checkpointable present in the
            # file (e.g. `ema_state`) is loaded too.
            return self.load(path, checkpointables=None)
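
# Usage sketch: when loading released weights (resume=False), this checkpointer
# also restores any checkpointables found in the file, notably `ema_state`.
# `build_model` and `cfg` are assumed to come from the usual detectron2 setup;
# this is an illustration, not part of the module:
#
#   model = build_model(cfg)
#   may_build_model_ema(cfg, model)
#   checkpointer = EMADetectionCheckpointer(
#       model, save_dir=cfg.OUTPUT_DIR, **may_get_ema_checkpointer(cfg, model)
#   )
#   checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
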
class EMAState(object):
    def __init__(self):
        self.state = {}

    @classmethod
    def FromModel(cls, model: torch.nn.Module, device: str = ""):
        ret = cls()
        ret.save_from(model, device)
        return ret

    def save_from(self, model: torch.nn.Module, device: str = ""):
        """Save model state from `model` to this object"""
        for name, val in self.get_model_state_iterator(model):
            val = val.detach().clone()
            self.state[name] = val.to(device) if device else val

    def apply_to(self, model: torch.nn.Module):
        """Apply the state stored in this object to `model`"""
        with torch.no_grad():
            for name, val in self.get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} does not exist, available names are {self.state.keys()}"
                val.copy_(self.state[name])

    @contextmanager
    def apply_and_restore(self, model):
        old_state = EMAState.FromModel(model, self.device)
        self.apply_to(model)
        yield old_state
        old_state.apply_to(model)

    def get_ema_model(self, model):
        ret = copy.deepcopy(model)
        self.apply_to(ret)
        return ret

    @property
    def device(self):
        if not self.has_inited():
            return None
        return next(iter(self.state.values())).device

    def to(self, device):
        for name in self.state:
            self.state[name] = self.state[name].to(device)
        return self

    def has_inited(self):
        return len(self.state) > 0

    def clear(self):
        self.state.clear()
        return self

    def get_model_state_iterator(self, model):
        # Iterate over both parameters and buffers so that e.g. batch-norm
        # running statistics are included in the EMA.
        param_iter = model.named_parameters()
        buffer_iter = model.named_buffers()
        return itertools.chain(param_iter, buffer_iter)

    def state_dict(self):
        return self.state

    def load_state_dict(self, state_dict, strict: bool = True):
        self.clear()
        for x, y in state_dict.items():
            self.state[x] = y
        return torch.nn.modules.module._IncompatibleKeys(
            missing_keys=[], unexpected_keys=[]
        )

    def __repr__(self):
        ret = f"EMAState(state=[{','.join(self.state.keys())}])"
        return ret
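
# Minimal sketch of using EMAState on its own; `model` is any torch.nn.Module and
# `evaluate` is a placeholder for the caller's own evaluation code:
#
#   ema_state = EMAState.FromModel(model)      # snapshot parameters + buffers
#   ...                                        # keep training / modifying the model
#   with ema_state.apply_and_restore(model):
#       evaluate(model)                        # runs with the snapshot weights
#   # the original (pre-snapshot) weights are restored on exit
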
class EMAUpdater(object):
    """Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and
    buffers). This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    Note: it's very important to apply EMA to ALL network parameters (instead of
    only the parameters that require gradient), including the batch-norm moving
    average mean and variance. This leads to a significant improvement in accuracy.
    For example, for EfficientNetB3 with the default setting (no mixup, exponential
    lr decay) and without bn_sync, the EMA accuracy when averaging only the
    parameters that require gradient is 79.87%, while the corresponding accuracy
    when averaging all parameters is 80.61%.

    Also, bn sync should be switched on for EMA.
    """

    def __init__(self, state: EMAState, decay: float = 0.999, device: str = "", yolox: bool = False):
        self.decay = decay
        self.device = device
        self.state = state
        self.updates = 0
        self.yolox = yolox
        if yolox:
            # YOLOX-style EMA: the effective decay ramps up from 0 towards
            # `decay` as the number of updates grows.
            decay = 0.9998
            self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

    def init_state(self, model):
        self.state.clear()
        self.state.save_from(model, self.device)

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates) if self.yolox else self.decay
            for name, val in self.state.get_model_state_iterator(model):
                ema_val = self.state.state[name]
                if self.device:
                    val = val.to(self.device)
                # ema = d * ema + (1 - d) * current_value
                ema_val.copy_(ema_val * d + val * (1.0 - d))
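
# Hand-rolled training-loop sketch for EMAUpdater outside the EMAHook below;
# `loader`, `optimizer` and `model` are placeholders, not part of this module:
#
#   ema_state = EMAState.FromModel(model)
#   updater = EMAUpdater(ema_state, decay=0.999)
#   for batch in loader:
#       loss = model(batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       updater.update(model)   # ema <- d * ema + (1 - d) * current weights
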
def add_model_ema_configs(_C):
    _C.MODEL_EMA = type(_C)()
    _C.MODEL_EMA.ENABLED = False
    _C.MODEL_EMA.DECAY = 0.999
    # Use the same device as MODEL.DEVICE when empty.
    _C.MODEL_EMA.DEVICE = ""
    # When True, load the EMA weights into the model when eval_only=True in build_model().
    _C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
    # When True, use YOLOX-style EMA:
    # https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/ema.py#L22
    _C.MODEL_EMA.YOLOX = False
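
# To enable EMA from a config, one would typically call add_model_ema_configs(cfg)
# right after get_cfg() and then override the defaults; the values below are only
# an illustration:
#
#   cfg = get_cfg()
#   add_model_ema_configs(cfg)
#   cfg.MODEL_EMA.ENABLED = True
#   cfg.MODEL_EMA.DECAY = 0.9998
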
def _remove_ddp(model):
    from torch.nn.parallel import DistributedDataParallel

    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


def may_build_model_ema(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return
    model = _remove_ddp(model)
    assert not hasattr(
        model, "ema_state"
    ), "Name `ema_state` is reserved for model ema."
    model.ema_state = EMAState()
    logger.info("Using Model EMA.")


def may_get_ema_checkpointer(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return {}
    model = _remove_ddp(model)
    return {"ema_state": model.ema_state}


def get_model_ema_state(model):
    """Return the ema state stored in `model`"""
    model = _remove_ddp(model)
    assert hasattr(model, "ema_state")
    ema = model.ema_state
    return ema


def apply_model_ema(model, state=None, save_current=False):
    """Apply the EMA state stored in `model` (or the given `state`) to the model.
    If `save_current` is True, return the current (pre-EMA) state so that it can
    be restored later; otherwise return None.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    if save_current:
        # save the current model state before overwriting it with the EMA weights
        old_state = EMAState.FromModel(model, state.device)
    state.apply_to(model)
    if save_current:
        return old_state
    return None


@contextmanager
def apply_model_ema_and_restore(model, state=None):
    """Context manager that applies the EMA state stored in `model` (or the given
    `state`) to the model, yields the saved original state, and restores the
    original weights when the context exits.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    old_state = EMAState.FromModel(model, state.device)
    state.apply_to(model)
    yield old_state
    old_state.apply_to(model)
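
# Sketch of an evaluation pass with EMA weights (original weights are restored when
# the context exits); `test_loader` and `evaluator` are placeholders for whatever
# detectron2 evaluation objects the caller builds:
#
#   with apply_model_ema_and_restore(model):
#       results = inference_on_dataset(model, test_loader, evaluator)
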
class EMAHook(HookBase):
    def __init__(self, cfg, model):
        model = _remove_ddp(model)
        assert cfg.MODEL_EMA.ENABLED
        assert hasattr(
            model, "ema_state"
        ), "Call `may_build_model_ema` first to initialize the model ema"
        self.model = model
        self.ema = self.model.ema_state
        self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
        self.ema_updater = EMAUpdater(
            self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device, yolox=cfg.MODEL_EMA.YOLOX
        )

    def before_train(self):
        if self.ema.has_inited():
            self.ema.to(self.device)
        else:
            self.ema_updater.init_state(self.model)

    def after_train(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        # `training` (the mode flag, not the `train` method) tells whether the
        # model is in training mode; only update the EMA during training.
        if not self.model.training:
            return
        self.ema_updater.update(self.model)
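
# Typical wiring with a detectron2 trainer (a sketch; `trainer` and `cfg` are the
# caller's objects): build the EMA state and register the hook.
#
#   may_build_model_ema(cfg, trainer.model)
#   trainer.register_hooks([EMAHook(cfg, trainer.model)])
#   # and pass **may_get_ema_checkpointer(cfg, trainer.model) when building the
#   # checkpointer so the EMA tensors are saved and resumed with the model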