# File size: 3,732 Bytes
# 2252f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py
from __future__ import division
import logging
from utils import CheckpointSaver
from tensorboardX import SummaryWriter

import torch
from tqdm import tqdm

# A monitor_interval of 0 turns off tqdm's background monitor thread.
tqdm.monitor_interval = 0

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class BaseTrainer(object):
    """Base class for Trainer objects.

    Takes care of checkpointing/logging/resuming training.

    Subclasses must implement ``init_fn`` and ``train``; ``init_fn`` is
    expected to populate ``self.models_dict`` and ``self.optimizers_dict``,
    which are read here when resuming from a checkpoint and when loading
    pretrained weights. ``train_step``, ``train_summaries`` and ``visualize``
    raise ``NotImplementedError`` until overridden, while ``validate``,
    ``test`` and ``evaluate`` default to doing nothing.
    """

    def __init__(self, options):
        """Set up the device, checkpoint saver, summary writer and resume state.

        Args:
            options: configuration object. The attributes read here are
                ``multiprocessing_distributed``, ``gpu``, ``checkpoint_dir``,
                ``overwrite``, ``rank``, ``summary_dir`` and ``resume``.
        """
        self.options = options
        if options.multiprocessing_distributed:
            self.device = torch.device('cuda', options.gpu)
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir,
                                     overwrite=options.overwrite)
        # Only the rank-0 process gets a summary writer; on every other rank
        # the ``summary_writer`` attribute is not set at all.
        if options.rank == 0:
            self.summary_writer = SummaryWriter(self.options.summary_dir)
        # override this function to define your model, optimizers etc.
        self.init_fn()

        # Resume only when requested AND a checkpoint is actually available.
        self.checkpoint = None
        if options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict)

        if self.checkpoint is None:
            # Fresh run: start all counters from zero.
            self.epoch_count = 0
            self.step_count = 0
            self.checkpoint_batch_idx = 0
        else:
            # Resumed run: restore the counters stored in the checkpoint.
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']
            self.checkpoint_batch_idx = self.checkpoint['batch_idx']

        self.best_performance = float('inf')

    def load_pretrained(self, checkpoint_file=None):
        """Load a pretrained checkpoint.

        This is different from resuming training using --resume.

        Args:
            checkpoint_file: path (or file-like object) accepted by
                ``torch.load``. When ``None`` nothing is loaded.

        Every entry of ``self.models_dict`` whose key is also present in the
        checkpoint gets its state dict loaded with ``strict=True``; models
        without a matching key are left untouched.
        """
        if checkpoint_file is None:
            return

        # NOTE: torch.load unpickles the file, so only load trusted
        # checkpoints.
        # Deserialize onto the CPU: load_state_dict copies the values into
        # the model's existing parameters, so the checkpoint does not need
        # (and should not require) the device it was saved from.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        for name, model in self.models_dict.items():
            if name in checkpoint:
                model.load_state_dict(checkpoint[name], strict=True)
                print(f'Checkpoint {name} loaded')

    def move_dict_to_device(self, dict, device, tensor2float=False):
        """Move every tensor value of a dictionary to ``device``, in place.

        Args:
            dict: mapping to update; non-tensor values are left as they are.
                (The name shadows the builtin but is kept so that existing
                keyword callers continue to work.)
            device: target passed to ``Tensor.to``.
            tensor2float: when True, tensors are converted with ``.float()``
                before being moved.
        """
        for key, value in dict.items():
            if not isinstance(value, torch.Tensor):
                continue
            if tensor2float:
                value = value.float()
            # Re-assigning an existing key does not change the dict's size,
            # so this is safe while iterating over items().
            dict[key] = value.to(device)

    # The following methods (with the possible exception of test) have to be implemented in the derived classes
    def train(self, epoch):
        """Run one epoch of training; called by ``fit`` with the epoch index."""
        raise NotImplementedError('You need to provide a train method')

    def init_fn(self):
        """Define models, optimizers etc.; called once from ``__init__``."""
        raise NotImplementedError('You need to provide an init_fn method')

    def train_step(self, input_batch):
        """Hook for processing a single batch; must be overridden to be used."""
        raise NotImplementedError('You need to provide a train_step method')

    def train_summaries(self, input_batch):
        """Hook for writing training summaries; must be overridden to be used."""
        raise NotImplementedError(
            'You need to provide a train_summaries method')

    def visualize(self, input_batch):
        """Hook for visualizing a batch; must be overridden to be used."""
        raise NotImplementedError('You need to provide a visualize method')

    def validate(self):
        """Optional hook; does nothing unless overridden."""
        pass

    def test(self):
        """Optional hook; does nothing unless overridden."""
        pass

    def evaluate(self):
        """Optional hook; does nothing unless overridden."""
        pass

    def fit(self):
        """Run training from ``self.epoch_count`` up to ``options.num_epochs``.

        ``self.epoch_count`` is updated to the current epoch index before
        each call to ``train``.
        """
        # ``initial``/``total`` make the progress bar start at the resumed
        # epoch instead of zero.
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs,
                          initial=self.epoch_count):
            self.epoch_count = epoch
            self.train(epoch)
        return