File size: 3,872 Bytes
c4c7cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from pytorch_lightning import Callback
import copy
import itertools
import torch
import contextlib
from torch.distributed.fsdp import FullyShardedDataParallel


class EMACallback(Callback):
    """Maintain an exponential moving average (EMA) copy of a submodule.

    The callback tracks ``getattr(pl_module, module_attr_name)`` (the "live"
    module) and keeps an averaged copy in
    ``getattr(pl_module, ema_module_attr_name)``. If the EMA attribute does
    not exist when a run starts at ``global_step == 0``, a frozen, eval-mode
    deep copy of the live module is registered under that name.

    Schedule, evaluated once per optimizer step in ``on_train_batch_end``:

    * ``global_step < start_ema_step``: every 100 steps, a slow update with
      decay 0.9 (kept for visualisation).
    * ``global_step == start_ema_step``: the EMA is reset (see ``reset_ema``).
    * ``global_step > start_ema_step``: a regular update with ``decay``.

    Args:
        module_attr_name: Attribute name of the live module on the
            LightningModule.
        ema_module_attr_name: Attribute name of the EMA module on the
            LightningModule.
        decay: Fraction of the previous EMA value kept per update, i.e.
            ``ema = decay * ema + (1 - decay) * param``.
        start_ema_step: Global step at which the EMA is reset and regular
            updates begin.
        init_ema_random: If True, a reset calls ``ema_module.init_weights()``.
            Otherwise the live module's parameters and buffers are copied
            into the EMA.
    """

    def __init__(
        self,
        module_attr_name,
        ema_module_attr_name,
        decay=0.999,
        start_ema_step=0,
        init_ema_random=True,
    ):
        super().__init__()
        self.decay = decay
        self.module_attr_name = module_attr_name
        self.ema_module_attr_name = ema_module_attr_name
        self.start_ema_step = start_ema_step
        self.init_ema_random = init_ema_random
        # Last global_step handled in on_train_batch_end. With gradient
        # accumulation several batches share one global_step, and we want to
        # act only once per optimizer step.
        self._last_seen_step = None

    def on_train_start(self, trainer, pl_module):
        """Create the EMA module if missing and reset it on a fresh run.

        Creation and reset only happen when ``pl_module.global_step == 0``.

        Raises:
            ValueError: If ``pl_module`` has no attribute ``module_attr_name``.
        """
        self._last_seen_step = None
        if pl_module.global_step != 0:
            return
        if not hasattr(pl_module, self.module_attr_name):
            raise ValueError(
                f"Module {pl_module} does not have attribute {self.module_attr_name}"
            )
        if not hasattr(pl_module, self.ema_module_attr_name):
            # The EMA copy is never trained directly: freeze it and keep it
            # in eval mode.
            pl_module.add_module(
                self.ema_module_attr_name,
                copy.deepcopy(getattr(pl_module, self.module_attr_name))
                .eval()
                .requires_grad_(False),
            )
        self.reset_ema(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Reset or update the EMA according to the current global step."""
        step = pl_module.global_step
        # Under gradient accumulation global_step only advances on optimizer
        # steps. Skip repeat calls so the EMA is not updated several times
        # with the same weights (which would silently change the decay).
        if step == self._last_seen_step:
            return
        self._last_seen_step = step

        if step == self.start_ema_step:
            self.reset_ema(pl_module)
        elif step < self.start_ema_step:
            if step % 100 == 0:
                ## slow ema updates for visualisation
                self.update_ema(pl_module, decay=0.9)
        else:
            self.update_ema(pl_module, decay=self.decay)

    def update_ema(self, pl_module, decay=0.999):
        """Blend the live module's state into the EMA module in place.

        Update rules:

        * Parameters with ``requires_grad`` are updated as
          ``ema = decay * ema + (1 - decay) * param``.
        * Parameters without ``requires_grad`` are left unchanged.
        * Buffers present in the EMA ``state_dict`` are copied verbatim, so
          that e.g. running statistics stay in sync.

        Args:
            pl_module: LightningModule holding both modules.
            decay: Fraction of the previous EMA value kept.
        """
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        module = getattr(pl_module, self.module_attr_name)
        with self.get_model_context_manager(module), torch.no_grad():
            # state_dict() tensors share storage with the EMA module, so
            # in-place ops on them update the EMA module itself.
            ema_state = ema_module.state_dict()
            for name, param in module.named_parameters():
                if param.requires_grad and name in ema_state:
                    # lerp(start, end, w) = start + w * (end - start), so the
                    # weight on the new value must be (1 - decay).
                    ema_state[name].lerp_(param.detach(), 1.0 - decay)
            for name, buf in module.named_buffers():
                if name in ema_state:
                    ema_state[name].copy_(buf.detach())

    def get_model_context_manager(self, module):
        """Return a context that exposes full (unsharded) parameters of ``module``.

        For FSDP models this is ``FullyShardedDataParallel.summon_full_params``.
        It is called through the class because the root may be a plain module
        whose children are FSDP-wrapped. Otherwise a no-op context is returned.
        """
        if is_model_fsdp(module):
            return FullyShardedDataParallel.summon_full_params(module)
        return contextlib.nullcontext()

    def reset_ema(self, pl_module):
        """Re-initialise the EMA module.

        If ``init_ema_random`` is set, calls ``ema_module.init_weights()``.
        Otherwise copies every parameter and buffer of the live module whose
        name appears in the EMA ``state_dict``.
        """
        ema_module = getattr(pl_module, self.ema_module_attr_name)
        if self.init_ema_random:
            ema_module.init_weights()
            return
        module = getattr(pl_module, self.module_attr_name)
        with self.get_model_context_manager(module), torch.no_grad():
            ema_state = ema_module.state_dict()
            for name, tensor in itertools.chain(
                module.named_parameters(), module.named_buffers()
            ):
                if name in ema_state:
                    ema_state[name].copy_(tensor.detach())


def is_model_fsdp(model: torch.nn.Module) -> bool:
    """Return True if ``model`` or one of its direct children is FSDP-wrapped.

    Only the module itself and its immediate children are inspected. Deeper
    descendants are not searched.

    Args:
        model: Module to inspect.

    Returns:
        True when ``model`` is a ``FullyShardedDataParallel`` instance or has a
        direct child that is one, otherwise False.
    """
    # FullyShardedDataParallel is imported unconditionally at module import
    # time, so no ImportError can surface here and no try/except is needed.
    if isinstance(model, FullyShardedDataParallel):
        return True
    return any(
        isinstance(child, FullyShardedDataParallel) for child in model.children()
    )