from collections import OrderedDict

import torch


@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module,
    model: torch.nn.Module,
    optimizer=None,
    decay: float = 0.9999,
    sharded: bool = True,
) -> None:
    """
    Step the EMA model towards the current model.

    With ``sharded=True``, low-precision parameters are read from the
    optimizer's fp32 master copies so the EMA accumulates at full precision.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # Positional embeddings are fixed; no need to track them in the EMA.
        if name == "pos_embed":
            continue
        # Skip frozen parameters.
        if not param.requires_grad:
            continue

        if sharded and param.data.dtype != torch.float32:
            # The working parameter is low precision (e.g. fp16/bf16);
            # update the EMA from the optimizer's fp32 master copy instead.
            master_param = optimizer._param_store.working_to_master_param[id(param)]
            param_data = master_param.data
        else:
            param_data = param.data

        # ema = decay * ema + (1 - decay) * param
        ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
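
# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal single-step example on the non-sharded path, where weights are
# already fp32 and no optimizer handle is needed. The model, data, and loss
# below are hypothetical placeholders.
if __name__ == "__main__":
    import copy

    model = torch.nn.Linear(8, 8)
    # The EMA model starts as an exact copy and is never trained directly.
    ema_model = copy.deepcopy(model)
    ema_model.requires_grad_(False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # One hypothetical training step.
    x = torch.randn(4, 8)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Non-sharded path: the EMA reads straight from the model's fp32 weights.
    update_ema(ema_model, model, decay=0.9999, sharded=False)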