File size: 6,018 Bytes
71de706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
import typing

import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel

from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler


class Accelerator:  # pragma: no cover
    """Prepare models and dataloaders for single-device, DP, or DDP training.

    Use :py:meth:`prepare_model` and :py:meth:`prepare_dataloader` to
    prepare the respective objects. Models are moved to ``self.device`` and,
    when DDP is active, converted to ``SyncBatchNorm`` and wrapped in
    ``DistributedDataParallel``. Dataloaders are built around a resumable
    sampler (distributed when DDP is active, sequential otherwise).

    The world size is taken from ``torch.cuda.device_count()``. If it is
    greater than 1 and the environment variable ``LOCAL_RANK`` is set
    (i.e. the script was launched with ``torchrun``), DDP is used. If it is
    greater than 1 and ``LOCAL_RANK`` is not set, ``DataParallel`` is used
    instead (not recommended). Otherwise no wrapping is applied: the model
    is only moved to the device and the dataloader uses the sequential
    sampler.

    The instance is also a context manager: entering it enters
    ``torch.cuda.device(self.local_rank)`` when CUDA is available.

    Parameters
    ----------
    amp : bool, optional
        Whether or not to enable automatic mixed precision, by default False

    Raises
    ------
    ValueError
        If ``LOCAL_RANK`` is set to something that is not an integer.
    """

    def __init__(self, amp: bool = False):
        env_rank = os.getenv("LOCAL_RANK", None)
        # Parse the rank up front so that ``self.local_rank`` is always an
        # int. Previously it stayed a string (e.g. "0") whenever LOCAL_RANK
        # was set but DDP was not enabled, and that string was then handed
        # to ``torch.cuda.device``.
        local_rank = None if env_rank is None else int(env_rank)
        self.world_size = torch.cuda.device_count()

        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

        class DummyScaler:
            """Stand-in exposing the GradScaler methods used by this class,
            without any loss scaling, for when ``amp`` is disabled."""

            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        # Only created when CUDA is available; ``None`` makes the context
        # manager protocol below a pass-through.
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares model for DDP or DP. The model is moved to
        ``self.device`` first.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.
        **kwargs
            Forwarded to ``DistributedDataParallel`` or ``DataParallel``.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or original model if DDP and DP are turned off.
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. ``self.amp`` is passed as the
        first argument; remaining arguments go to
        ``torch.cuda.amp.autocast``.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backwards pass, after scaling the loss if ``amp`` is
        enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using a ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor."""
        self.scaler.update()

    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        start_idx: typing.Optional[int] = None,
        **kwargs,
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
        enabled.

        Under DDP, ``num_workers`` and ``batch_size`` (when given) are
        floor-divided by the world size, with a minimum of 1 each.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build Dataloader around.
        start_idx : int, optional
            Start index of sampler, useful if resuming from some epoch,
            by default None
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader over ``dataset`` using the selected sampler.
        """

        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            # Both settings are optional for the caller; only rescale the
            # ones that were actually supplied. (An omitted batch_size keeps
            # DataLoader's default, which is already the minimum of 1.)
            for key in ("num_workers", "batch_size"):
                if key in kwargs:
                    kwargs[key] = max(kwargs[key] // self.world_size, 1)
        else:
            sampler = SequentialSampler(dataset, start_idx)

        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
        return dataloader

    @staticmethod
    def unwrap(model):
        """Unwraps the model if it was wrapped in DDP or DP, otherwise
        just returns the model. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.

        Wrapping is detected by the presence of a ``module`` attribute.
        """
        if hasattr(model, "module"):
            return model.module
        return model