# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf

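"""Training solver for the SWave speech separation model: runs the
optimization loop and handles checkpointing, LR scheduling, evaluation
and sample separation."""
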
import json
import logging
import time
from pathlib import Path

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from . import distrib
from .separate import separate
from .evaluate import evaluate
from .models.sisnr_loss import cal_loss
from .models.swave import SWave
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress


logger = logging.getLogger(__name__)


class Solver(object):
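    """Runs the training loop: per-epoch optimization, cross validation,
    checkpointing, LR scheduling and periodic test-set evaluation."""
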
    def __init__(self, data, model, optimizer, args):
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
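        # wrapped copy (possibly DistributedDataParallel) used for forward passes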
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer
        if args.lr_sched == 'step':
            self.sched = StepLR(
                self.optimizer, step_size=args.step.step_size, gamma=args.step.gamma)
        elif args.lr_sched == 'plateau':
            self.sched = ReduceLROnPlateau(
                self.optimizer, factor=args.plateau.factor, patience=args.plateau.patience)
        else:
            self.sched = None

        # Training config
        self.device = args.device
        self.epochs = args.epochs
        self.max_norm = args.max_norm

        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = Path(
            args.checkpoint_file) if args.checkpoint else None
        if self.checkpoint:
            logger.debug("Checkpoint will be saved to %s",
                         self.checkpoint.resolve())
        self.history_file = args.history_file

        self.best_state = None
        self.restart = args.restart
        # keep track of losses
        self.history = []

        # Where to save samples
        self.samples_dir = args.samples_dir

        # logging
        self.num_prints = args.num_prints

        # for separation tests
        self.args = args
        self._reset()

    def _serialize(self, path):
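        """Serialize model, optimizer state, history, best state and args to `path`."""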
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        torch.save(package, path)

    def _reset(self):
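        """Load a previous checkpoint, or the model given by `continue_from`,
        unless a restart was requested."""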
        load_from = None
        # Reset
        if self.checkpoint and self.checkpoint.exists() and not self.restart:
            load_from = self.checkpoint
        elif self.continue_from:
            load_from = self.continue_from

        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, map_location='cpu')
            if load_from == self.continue_from and self.args.continue_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])

            if 'optimizer' in package and not self.args.continue_best:
                self.optimizer.load_state_dict(package['optimizer'])
            self.history = package['history']
            self.best_state = package['best_state']

    def train(self):
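        """Train for `self.epochs` epochs, resuming from the saved history."""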
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
            logger.info(f"Epoch {epoch}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()  # train mode: enable Dropout & BatchNorm updates
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # eval mode: disable Dropout, freeze BatchNorm stats
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))

            # learning rate scheduling
            if self.sched:
                if self.args.lr_sched == 'plateau':
                    self.sched.step(valid_loss)
                else:
                    self.sched.step()
                lr = self.optimizer.param_groups[0]['lr']
                logger.info(f'Learning rate adjusted: {lr:.5f}')

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss,
                       'valid': valid_loss, 'best': best_loss}
            # save the best state (or always the latest, when keep_last is set)
            if valid_loss == best_loss or self.args.keep_last:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # evaluate and separate samples every `eval_every` epochs,
            # as well as on the last epoch
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    sisnr, pesq, stoi = evaluate(
                        self.args, self.model, self.tt_loader, self.args.sample_rate)
                metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})

                # separate some samples
                logger.info('Separate and save samples...')
                separate(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(
                f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                with open(self.history_file, "w") as f:
                    json.dump(self.history, f, indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize(self.checkpoint)
                    logger.debug("Checkpoint saved to %s",
                                 self.checkpoint.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
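        """Run one pass over the train or validation set and return the
        average loss (averaged across distributed workers)."""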
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # give the sampler a new shuffle order for distributed training;
        # ignored in the non-distributed case
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader,
                              updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
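            # each batch holds the mixture waveforms, their valid lengths
            # and the ground-truth source signals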
            mixture, lengths, sources = [x.to(self.device) for x in data]
            estimate_source = self.dmodel(mixture)

            # during validation, only score the final layer's output
            if cross_valid:
                estimate_source = estimate_source[-1:]

            loss = 0
            cnt = len(estimate_source)
            # apply the loss to every layer's output, with a weight that
            # grows linearly with depth so later layers dominate;
            # anomaly detection aids debugging but slows training down
            with torch.autograd.set_detect_anomaly(True):
                for c_idx, est_src in enumerate(estimate_source):
                    coeff = (c_idx + 1) / cnt
                    # SI-SNR loss (permutation invariant)
                    sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
                        sources, est_src, lengths)
                    loss += coeff * sisnr_loss
                loss /= cnt

                if not cross_valid:
                    # optimize model in training mode
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_norm)
                    self.optimizer.step()

            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))

            # Just in case, clear some memory
            del loss, estimate_source
        return distrib.average([total_loss / (i + 1)], i + 1)[0]