import contextlib
import copy
import itertools

import torch
from pytorch_lightning import Callback
from torch.distributed.fsdp import FullyShardedDataParallel

class EMACallback(Callback):
    """Maintains an exponential moving average (EMA) copy of a sub-module of the LightningModule.

    Args:
        module_attr_name: name of the attribute on the LightningModule holding the trained module.
        ema_module_attr_name: name of the attribute holding (or created to hold) the EMA copy.
        decay: EMA decay rate; each update keeps ``decay`` of the previous EMA value and takes
            ``1 - decay`` of the current parameters.
        start_ema_step: global step at which the EMA is reset and starts tracking the trained
            module at the configured decay.
        init_ema_random: if True, ``reset_ema`` re-initialises the EMA copy randomly via its
            ``init_weights()`` method; otherwise it copies the current weights of the trained module.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            # Create the EMA copy on first use: a frozen, eval-mode deep copy of the trained module.
            if not hasattr(pl_module, self.ema_module_attr_name):
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # Re-initialise the EMA at the step where full-decay tracking begins.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Before the start step, only occasional slow EMA updates for visualisation.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager, torch.no_grad():
            ema_params = ema_module.state_dict()
            for name, param in itertools.chain(
                module.named_parameters(), module.named_buffers()
            ):
                if name in ema_params and param.requires_grad:
                    # EMA update: new_ema = decay * ema + (1 - decay) * param
                    ema_params[name].copy_(
                        ema_params[name].detach().lerp(param.detach(), 1.0 - decay)
                    )

    def get_model_context_manager(self, module):
        # Under FSDP the parameters are sharded, so gather the full parameters before reading them.
        if is_model_fsdp(module):
            return FullyShardedDataParallel.summon_full_params(module)
        return contextlib.nullcontext()

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # Assumes the wrapped module implements an ``init_weights()`` method.
            ema_module.init_weights()
        else:
            # Copy the current weights and buffers of the trained module into the EMA copy.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager:
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    # Check whether the model itself, or any of its direct children, is wrapped with FSDP.
    if isinstance(model, FullyShardedDataParallel):
        return True
    for _, child in model.named_children():
        if isinstance(child, FullyShardedDataParallel):
            return True
    return False
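

# Minimal usage sketch (not part of the original callback): ``LitModel``, the random
# TensorDataset, and the attribute names "model" / "ema_model" are illustrative
# assumptions, not prescribed by the callback itself. ``init_ema_random=False`` is
# used so the wrapped nn.Sequential does not need an ``init_weights()`` method.
if __name__ == "__main__":
    import pytorch_lightning as pl
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.model(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    train_loader = DataLoader(
        TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,))),
        batch_size=32,
    )
    ema_callback = EMACallback(
        module_attr_name="model",
        ema_module_attr_name="ema_model",
        decay=0.999,
        start_ema_step=100,
        init_ema_random=False,  # start the EMA from the current weights
    )
    lit_module = LitModel()
    trainer = pl.Trainer(max_steps=500, callbacks=[ema_callback], logger=False)
    trainer.fit(lit_module, train_dataloaders=train_loader)
    # After training, ``lit_module.ema_model`` holds the smoothed weights for evaluation.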