import contextlib
import copy
import itertools

import torch
from pytorch_lightning import Callback
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Maintains an exponential moving average (EMA) copy of a sub-module.

    The live module is looked up on the LightningModule via ``module_attr_name``;
    the EMA copy is created and stored under ``ema_module_attr_name``.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random

    def on_train_start(self, trainer, pl_module):
        # Create and initialise the EMA module at the start of a fresh run.
        if pl_module.global_step == 0:
            if not hasattr(pl_module, self.module_attr_name):
                raise ValueError(
                    f"Module {pl_module} does not have attribute {self.module_attr_name}"
                )
            if not hasattr(pl_module, self.ema_module_attr_name):
                pl_module.add_module(
                    self.ema_module_attr_name,
                    copy.deepcopy(getattr(pl_module, self.module_attr_name))
                    .eval()
                    .requires_grad_(False),
                )
            self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if pl_module.global_step == self.start_ema_step:
            # Re-initialise the EMA weights at the step where averaging begins.
            self.reset_ema(pl_module)
        elif (
            pl_module.global_step < self.start_ema_step
            and pl_module.global_step % 100 == 0
        ):
            # Slow EMA updates before the start step, for visualisation only.
            self.update_ema(pl_module, decay=0.9)
        elif pl_module.global_step > self.start_ema_step:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        context_manager = self.get_model_context_manager(module)
        with context_manager, torch.no_grad():
            ema_params = ema_module.state_dict()
            for name, param in itertools.chain(
                module.named_parameters(), module.named_buffers()
            ):
                # Only trainable tensors tracked by the EMA copy are updated.
                if name in ema_params and param.requires_grad:
                    # EMA update: ema = decay * ema + (1 - decay) * param
                    ema_params[name].copy_(
                        ema_params[name].detach().lerp(param.detach(), 1 - decay)
                    )

    def get_model_context_manager(self, module):
        # FSDP shards parameters across ranks, so gather the full parameters
        # before reading them for the EMA update.
        if is_model_fsdp(module):
            return FullyShardedDataParallel.summon_full_params(module)
        return contextlib.nullcontext()

    def reset_ema(self, pl_module):
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            # Assumes the wrapped module implements an `init_weights()` method.
            ema_module.init_weights()
        else:
            # Copy the live module's current weights into the EMA module.
            module = getattr(pl_module, self.module_attr_name)
            context_manager = self.get_model_context_manager(module)
            with context_manager, torch.no_grad():
                ema_params = ema_module.state_dict()
                for name, param in itertools.chain(
                    module.named_parameters(), module.named_buffers()
                ):
                    if name in ema_params:
                        ema_params[name].copy_(param.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    # FullyShardedDataParallel is imported unconditionally at module scope,
    # so no ImportError guard is needed here.
    if isinstance(model, FullyShardedDataParallel):
        return True
    # Check whether any direct child is wrapped with FSDP.
    for _, obj in model.named_children():
        if isinstance(obj, FullyShardedDataParallel):
            return True
    return False
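

# Usage sketch (illustrative, not part of the callback's contract): one way this
# callback might be attached to a pytorch_lightning Trainer. `MyLightningModule`
# and the attribute names "model" / "model_ema" are assumptions for the example;
# only the attribute named by `module_attr_name` must already exist on the
# LightningModule, and the EMA attribute is created in `on_train_start`.
#
#     import pytorch_lightning as pl
#
#     ema_callback = EMACallback(
#         module_attr_name="model",          # live network attribute (assumed name)
#         ema_module_attr_name="model_ema",  # attribute the EMA copy is stored under
#         decay=0.999,
#         start_ema_step=1_000,
#         init_ema_random=False,             # start the EMA from a copy of the live weights
#     )
#     trainer = pl.Trainer(max_steps=10_000, callbacks=[ema_callback])
#     trainer.fit(MyLightningModule())       # MyLightningModule is hypothetical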