File size: 4,018 Bytes
e8e478e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import time
from datetime import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np

import options
from validate import validate, calculate_acc
from datasets import *
from utils.logger import create_logger
from utils.earlystop import EarlyStopping
from networks.trainer import Trainer



if __name__ == '__main__':
    # Entry point: train a feature-space classifier (a Transformer over CLIP
    # features, per the logger name) with periodic validation, TensorBoard
    # logging, checkpointing, and early stopping with LR-drop restarts.
    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()
    
    # logger
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

 
    # Trainer wraps the model, optimizer, loss, and device placement.
    model = Trainer(train_opt)
    # logger.info(opt.gpu_ids[0])
    logger.info(model.device)
    
    # extract_feature_model = model.extract_feature_model
    # Images are preprocessed with the CLIP model's own transform; the dataset
    # is split 80/20 into train/val via k_split.
    # NOTE(review): clip_model=None means feature extraction presumably happens
    # inside the model/dataset, not here — confirm against create_train_val_dataloader.
    train_loader, val_loader = create_train_val_dataloader(train_opt, clip_model = None, transform = model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train    {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")


    # Separate TensorBoard writers for train and val curves, under
    # <checkpoints_dir>/<name>/{train,val}.
    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))
        
    # NOTE(review): delta is negative (-0.001), i.e. a small *decrease* in the
    # monitored accuracy is still tolerated as an "improvement" — confirm this
    # matches EarlyStopping's delta semantics.
    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch, delta=-0.001, verbose=True)
    
    start_time = time.time()
    # Log a layer-by-layer summary of the classifier head.
    # NOTE(review): input_size hard-codes (batch, 16, 768) — presumably 16
    # feature tokens of dimension 768 (ViT/CLIP hidden size); verify if the
    # backbone or token count changes.
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768), col_width=20,
                  col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0))
    
    
    logger.info("Length of train loader: %d" %(len(train_loader)))
    for epoch in range(train_opt.niter):
        # Accumulate labels/scores across the whole epoch to report train accuracy.
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

            # Global step counter persists across epochs; it gates the
            # loss-logging frequency below and keys all TensorBoard scalars.
            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()

            # Scores are sigmoid probabilities; data[1] is assumed to hold the
            # binary real/fake labels for the batch.
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())

            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                # Average wall-clock time per step since training started.
                logger.info("Iter time: {}".format((time.time()-start_time)/model.total_steps)  )

            # NOTE(review): the trailing `and False` makes this branch dead code
            # (iteration-checkpoint saving is intentionally disabled).
            if model.total_steps in [10,30,50,100,1000,5000,10000] and False: # save models at these iters 
                model.save_networks('model_iters_%s.pth' % model.total_steps)
            # logger.info("trained one batch")
            pbar.set_postfix_str(f"loss: {model.loss}, ")
        # Epoch-level train accuracy at a fixed 0.5 threshold:
        # r_acc0 = accuracy on real samples, f_acc0 = on fake, acc0 = overall.
        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")

        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % (epoch))
            model.save_networks( 'model_epoch_%s.pth' % epoch )

        # Validation
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

        # Early stopping monitors validation accuracy. On a patience breach,
        # first try dropping the learning rate and restarting the stopper
        # (with a slightly looser delta of -0.002); only stop for good once
        # the LR schedule is exhausted.
        early_stopping(acc, model.model)
        if early_stopping.early_stop:
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch, delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break
        
        # Restore train mode after the eval() call above before the next epoch.
        model.train()