# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import itertools
import logging
import math
from contextlib import contextmanager
from typing import Any, Dict

import torch

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine.train_loop import HookBase

logger = logging.getLogger(__name__)
class EMADetectionCheckpointer(DetectionCheckpointer):
    def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]:
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint, if one exists. Otherwise, load the checkpoint from the given
        path. This is useful when restarting an interrupted training job.

        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists
                and load the model together with all the checkpointables. Otherwise
                only load the model without loading any checkpointables.

        Returns:
            same as :meth:`load`.
        """
        if resume and self.has_checkpoint():
            path = self.get_checkpoint_file()
            return self.load(path)
        else:
            # not resuming: load only the model weights from `path` and skip the
            # other checkpointables (optimizer, scheduler, EMA state, ...)
            return self.load(path, checkpointables=None)
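

# A minimal usage sketch (not part of the original file; names are illustrative):
# with resume=False, EMADetectionCheckpointer loads only the model weights from
# the given path and skips other checkpointables such as the optimizer or the
# EMA state.
def _example_ema_checkpointer(model, output_dir, weights_path):
    checkpointer = EMADetectionCheckpointer(model, save_dir=output_dir)
    return checkpointer.resume_or_load(weights_path, resume=False)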
class EMAState(object):
    def __init__(self):
        self.state = {}

    @classmethod
    def FromModel(cls, model: torch.nn.Module, device: str = ""):
        ret = cls()
        ret.save_from(model, device)
        return ret

    def save_from(self, model: torch.nn.Module, device: str = ""):
        """Save model state from `model` to this object"""
        for name, val in self.get_model_state_iterator(model):
            val = val.detach().clone()
            self.state[name] = val.to(device) if device else val

    def apply_to(self, model: torch.nn.Module):
        """Apply the state stored in this object to `model`"""
        with torch.no_grad():
            for name, val in self.get_model_state_iterator(model):
                assert (
                    name in self.state
                ), f"Name {name} does not exist, available names are {self.state.keys()}"
                val.copy_(self.state[name])

    @contextmanager
    def apply_and_restore(self, model):
        old_state = EMAState.FromModel(model, self.device)
        self.apply_to(model)
        yield old_state
        old_state.apply_to(model)

    def get_ema_model(self, model):
        ret = copy.deepcopy(model)
        self.apply_to(ret)
        return ret

    @property
    def device(self):
        if not self.has_inited():
            return None
        return next(iter(self.state.values())).device

    def to(self, device):
        for name in self.state:
            self.state[name] = self.state[name].to(device)
        return self

    def has_inited(self) -> bool:
        return bool(self.state)

    def clear(self):
        self.state.clear()
        return self

    def get_model_state_iterator(self, model):
        param_iter = model.named_parameters()
        buffer_iter = model.named_buffers()
        return itertools.chain(param_iter, buffer_iter)

    def state_dict(self) -> Dict[str, Any]:
        return self.state

    def load_state_dict(self, state_dict, strict: bool = True):
        self.clear()
        for x, y in state_dict.items():
            self.state[x] = y
        return torch.nn.modules.module._IncompatibleKeys(
            missing_keys=[], unexpected_keys=[]
        )

    def __repr__(self):
        return f"EMAState(state=[{','.join(self.state.keys())}])"
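

# A minimal usage sketch (not part of the original file): snapshot a model with
# EMAState.FromModel, build a copy carrying that snapshot with get_ema_model, and
# temporarily swap the snapshot in with apply_and_restore. The toy Linear module
# is only for illustration.
def _example_ema_state_usage():
    model = torch.nn.Linear(4, 2)
    snapshot = EMAState.FromModel(model)       # capture parameters and buffers
    ema_model = snapshot.get_ema_model(model)  # deep copy with the snapshot applied
    with torch.no_grad():
        model.weight.add_(1.0)                 # mutate the live model
    with snapshot.apply_and_restore(model):
        pass                                   # inside: `model` holds the snapshot
    # on exit, the mutated weights are restored
    return ema_model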
class EMAUpdater(object):
    """Model Exponential Moving Average
    Keep a moving average of everything in the model state_dict (parameters and
    buffers). This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    Note: It's very important to set EMA for ALL network parameters (instead of
    only the parameters that require gradient), including batch-norm moving average
    mean and variance. This leads to a significant improvement in accuracy.
    For example, for EfficientNetB3, with the default setting (no mixup, lr exponential
    decay) and without bn_sync, the accuracy with EMA only on params that require
    gradient is 79.87%, while the corresponding accuracy with EMA on all params
    is 80.61%.

    Also, bn sync should be switched on for EMA.
    """

    def __init__(
        self,
        state: EMAState,
        decay: float = 0.999,
        device: str = "",
        yolox: bool = False,
    ):
        self.decay = decay
        self.device = device
        self.state = state
        self.updates = 0
        self.yolox = yolox
        if yolox:
            decay = 0.9998
            # YOLOX-style decay: ramps up from 0 toward `decay` as updates accumulate
            self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

    def init_state(self, model):
        self.state.clear()
        self.state.save_from(model, self.device)

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates) if self.yolox else self.decay
            for name, val in self.state.get_model_state_iterator(model):
                ema_val = self.state.state[name]
                if self.device:
                    val = val.to(self.device)
                ema_val.copy_(ema_val * d + val * (1.0 - d))
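

# A hedged sketch (not part of the original file) of how EMAUpdater is driven:
# after init_state, each call to update() blends the current weights into the EMA
# via ema = d * ema + (1 - d) * value, where d is the fixed decay or, in YOLOX
# mode, decay * (1 - exp(-updates / 2000)).
def _example_ema_updater_usage():
    model = torch.nn.Linear(4, 2)
    state = EMAState()
    updater = EMAUpdater(state, decay=0.999)
    updater.init_state(model)              # EMA starts as a copy of the model
    for _ in range(10):                    # stand-in for training iterations
        with torch.no_grad():
            model.weight.mul_(0.9)         # stand-in for an optimizer step
        updater.update(model)              # blend current weights into the EMA
    return state.get_ema_model(model)      # copy of the model with averaged weights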
def add_model_ema_configs(_C):
    _C.MODEL_EMA = type(_C)()
    _C.MODEL_EMA.ENABLED = False
    _C.MODEL_EMA.DECAY = 0.999
    # use the same device as MODEL.DEVICE when empty
    _C.MODEL_EMA.DEVICE = ""
    # when True, load the EMA weights into the model when eval_only=True in build_model()
    _C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
    # when True, use YOLOX EMA: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/ema.py#L22
    _C.MODEL_EMA.YOLOX = False
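

# A hedged sketch of wiring these options into a detectron2 config; the keys come
# from add_model_ema_configs above, and get_cfg() is detectron2's standard config
# factory. The chosen values are only illustrative.
def _example_ema_config():
    from detectron2.config import get_cfg

    cfg = get_cfg()
    add_model_ema_configs(cfg)
    cfg.MODEL_EMA.ENABLED = True
    cfg.MODEL_EMA.DECAY = 0.9998
    cfg.MODEL_EMA.DEVICE = "cpu"  # keep the EMA copy off the GPU
    return cfg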
def _remove_ddp(model):
    from torch.nn.parallel import DistributedDataParallel

    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


def may_build_model_ema(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return
    model = _remove_ddp(model)
    assert not hasattr(
        model, "ema_state"
    ), "Name `ema_state` is reserved for model ema."
    model.ema_state = EMAState()
    logger.info("Using Model EMA.")


def may_get_ema_checkpointer(cfg, model):
    if not cfg.MODEL_EMA.ENABLED:
        return {}
    model = _remove_ddp(model)
    return {"ema_state": model.ema_state}


def get_model_ema_state(model):
    """Return the ema state stored in `model`"""
    model = _remove_ddp(model)
    assert hasattr(model, "ema_state")
    ema = model.ema_state
    return ema
def apply_model_ema(model, state=None, save_current=False):
    """Apply the EMA state stored in `model` (or the given `state`) to the model
    weights. If `save_current` is True, return the previous model state so that
    it can be restored later; otherwise return None.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    if save_current:
        # save the current model state before overwriting it with the EMA state
        old_state = EMAState.FromModel(model, state.device)
    state.apply_to(model)
    if save_current:
        return old_state
    return None
@contextmanager
def apply_model_ema_and_restore(model, state=None):
    """Context manager that applies the EMA state stored in `model` (or the given
    `state`) to the model weights, and restores the original weights on exit.
    """
    model = _remove_ddp(model)
    if state is None:
        state = get_model_ema_state(model)
    old_state = EMAState.FromModel(model, state.device)
    state.apply_to(model)
    yield old_state
    old_state.apply_to(model)
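

# A minimal sketch (names below are illustrative) of evaluating with the EMA
# weights: the context manager swaps the EMA state in, inference runs with it,
# and the original weights are restored on exit. Assumes `model.ema_state` was
# created via may_build_model_ema and populated during training.
def _example_eval_with_ema(model, data_loader, evaluator):
    from detectron2.evaluation import inference_on_dataset

    with apply_model_ema_and_restore(model):
        return inference_on_dataset(model, data_loader, evaluator)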
class EMAHook(HookBase):
    def __init__(self, cfg, model):
        model = _remove_ddp(model)
        assert cfg.MODEL_EMA.ENABLED
        assert hasattr(
            model, "ema_state"
        ), "Call `may_build_model_ema` first to initialize the model ema"
        self.model = model
        self.ema = self.model.ema_state
        self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
        self.ema_updater = EMAUpdater(
            self.model.ema_state,
            decay=cfg.MODEL_EMA.DECAY,
            device=self.device,
            yolox=cfg.MODEL_EMA.YOLOX,
        )

    def before_train(self):
        if self.ema.has_inited():
            self.ema.to(self.device)
        else:
            self.ema_updater.init_state(self.model)

    def after_train(self):
        pass

    def before_step(self):
        pass

    def after_step(self):
        # check the module's `training` flag (not the `train` method, which is
        # always truthy): only update the EMA while the model is in training mode
        if not self.model.training:
            return
        self.ema_updater.update(self.model)
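

# A hedged end-to-end sketch (assumed trainer wiring, not part of the original
# file): attach an EMA state to the model, register EMAHook so the EMA is updated
# after each step, and hand the EMA state to the checkpointer so it is saved and
# loaded alongside the model. Accessing `checkpointer.checkpointables` directly
# is an assumption about the fvcore Checkpointer internals.
def _example_trainer_with_ema(cfg):
    from detectron2.engine import DefaultTrainer

    trainer = DefaultTrainer(cfg)
    may_build_model_ema(cfg, trainer.model)  # attaches `trainer.model.ema_state`
    if cfg.MODEL_EMA.ENABLED:
        trainer.register_hooks([EMAHook(cfg, trainer.model)])
        # make the EMA state part of what the checkpointer saves/loads
        trainer.checkpointer.checkpointables.update(
            may_get_ema_checkpointer(cfg, trainer.model)
        )
    trainer.resume_or_load(resume=False)
    return trainer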