# Plonk/callbacks/ema.py
import contextlib
import copy
import itertools

import torch
from pytorch_lightning import Callback
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Keeps an exponential moving average (EMA) copy of a sub-module of the
    LightningModule, updated at the end of each training batch once
    `start_ema_step` is reached."""

def __init__(
self,
module_attr_name,
ema_module_attr_name,
decay=0.999,
start_ema_step=0,
init_ema_random=True,
):
super().__init__()
self.decay = decay
self.module_attr_name = module_attr_name
self.ema_module_attr_name = ema_module_attr_name
self.start_ema_step = start_ema_step
self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        # On a fresh run, create an EMA copy of the tracked module and register it
        # on the LightningModule so it is checkpointed alongside the live weights.
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
raise ValueError(
f"Module {pl_module} does not have attribute {self.module_attr_name}"
)
if not hasattr(pl_module, self.ema_module_attr_name):
pl_module.add_module(
self.ema_module_attr_name,
copy.deepcopy(getattr(pl_module, self.module_attr_name))
.eval()
.requires_grad_(False),
)
self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # The EMA officially starts here: re-sync (or re-initialise) the EMA weights.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Before start_ema_step, update only every 100 steps (with a lower decay)
            # so the EMA weights stay usable for visualisation.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager:
            with torch.no_grad():
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        if param.requires_grad:
                            # ema <- decay * ema + (1 - decay) * param
                            ema_params[name].copy_(
                                ema_params[name].detach().lerp(param.detach(), 1 - decay)
                            )
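
    # Toy illustration of the update rule above (not part of the training flow;
    # the tensors are placeholders): with decay = 0.999, each call moves the EMA
    # tensor 0.1% of the way toward the current parameter value,
    # i.e. ema <- 0.999 * ema + 0.001 * param.
    #
    #   ema, param = torch.zeros(3), torch.ones(3)
    #   ema.copy_(ema.lerp(param, 1 - 0.999))  # -> tensor([0.0010, 0.0010, 0.0010])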

    def get_model_context_manager(self, module):
        # When the module is sharded with FSDP, gather the full parameters so the
        # EMA update sees the unsharded weights.
        fsdp_enabled = is_model_fsdp(module)
        model_context_manager = contextlib.nullcontext()
        if fsdp_enabled:
            model_context_manager = FullyShardedDataParallel.summon_full_params(module)
        return model_context_manager

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # Re-initialise the EMA weights from scratch; the wrapped module is
            # expected to expose an `init_weights()` method.
            ema_module.init_weights()
        else:
            # Hard-copy the current weights (and buffers) of the tracked module
            # into the EMA module.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Return True if `model` (or one of its direct children) is wrapped with FSDP."""
    try:
        if isinstance(model, FullyShardedDataParallel):
            return True
        # Check if the model has an FSDP-wrapped child
        for _, obj in model.named_children():
            if isinstance(obj, FullyShardedDataParallel):
                return True
        return False
    except ImportError:
        return False
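

# ---------------------------------------------------------------------------
# Usage sketch (not part of the callback itself; `MyLitModule` and its `net`
# attribute are placeholder names): the callback mirrors `pl_module.net` into
# `pl_module.ema_net`. If `init_ema_random=True`, `net` must expose an
# `init_weights()` method.
#
#   from pytorch_lightning import Trainer
#
#   model = MyLitModule()                      # has a `self.net` sub-module
#   ema_callback = EMACallback(
#       module_attr_name="net",
#       ema_module_attr_name="ema_net",
#       decay=0.999,
#       start_ema_step=1000,
#       init_ema_random=False,
#   )
#   trainer = Trainer(callbacks=[ema_callback])
#   trainer.fit(model)
#   ema_weights = model.ema_net.state_dict()   # EMA weights for evaluation
# ---------------------------------------------------------------------------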