import argparse
import datetime
import logging
import os

import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from data.datasets import get_datasets
from models.model import NamingEnhancementModel
from utils.evaluator import Evaluator
from utils.logger import prepare_logging
from utils.setup_criterion import get_criterion
from utils.setup_optim_scheduler import get_optimizer_scheduler
from utils.trainer import Trainer


def main(config: DictConfig):
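    # Restrict visible GPUs; this must run before the first CUDA context is created.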
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.train.cuda_visible_device)
    save_path = prepare_logging()

    logging.info(f"Starting training at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logging.info(f"Saving logs to {save_path}")
    logging.info(f"Config file: {OmegaConf.to_yaml(config)}")

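    # Build the train/valid/test splits from the data config; the valid split may be None.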
    train_dataset, valid_dataset, test_dataset = get_datasets(config.data)

    msg = f"Training with {len(train_dataset)} image pairs"
    if valid_dataset is not None:
        msg += f", validating with {len(valid_dataset)} image pairs"
    msg += f" and testing with {len(test_dataset)} image pairs."
    logging.info(msg)

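    # Training shuffles full batches; validation and test iterate one image pair at a time.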
    train_loader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True)
    if valid_dataset is not None:
        valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    else:
        valid_loader = None
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

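    # Build the model and optionally warm-start it from a checkpoint before moving to GPU.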
    model = NamingEnhancementModel(config.model)
    if config.model.ckpt_path is not None:
        state = torch.load(config.model.ckpt_path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
    model.cuda()

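    # Loss, optimizer, and (optional) LR scheduler are all resolved from the config.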
    criterion = get_criterion(config.train.criterion)

    optimizer, scheduler = get_optimizer_scheduler(model,
                                                   config.train.optimizer,
                                                   config.train.scheduler if "scheduler" in config.train else None)

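    # One evaluator per split; metric_to_save picks the metric that drives result saving.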
    if valid_loader is not None:
        valid_evaluator = Evaluator(valid_loader, config.eval.metrics, 'valid',
                                    save_path, config.eval.metric_to_save)
    else:
        valid_evaluator = None

    test_evaluator = Evaluator(test_loader, config.eval.metrics, 'test',
                               save_path, config.eval.metric_to_save)

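    # The Trainer wires model, data, and evaluators together and runs the training loop.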
    trainer = Trainer(model, optimizer, criterion, scheduler, train_loader,
                      valid_evaluator, test_evaluator, config.train, config.eval)
    trainer.train()

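# Example invocation, assuming this script is saved as main.py:
#     python main.py --config configs/mit5k_dpe_config.yaml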
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/mit5k_dpe_config.yaml')
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    main(config)