File size: 1,419 Bytes
6faf7e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import random

import numpy as np
import torch

from trainer import CBHGTrainer, Seq2SeqTrainer, GPTTrainer

# Fix every RNG this script touches so training runs are reproducible.
SEED = 1234
random.seed(SEED)        # Python stdlib RNG
np.random.seed(SEED)     # NumPy global RNG
torch.manual_seed(SEED)  # torch CPU RNG
torch.cuda.manual_seed(SEED)  # torch RNG on the current CUDA device (no-op without CUDA)
# Force cuDNN to pick deterministic kernels and disable its autotuner;
# trades some speed for run-to-run reproducibility on GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def train_parser():
    """Build the command-line parser for this training script.

    Flags:
        --model_kind  (required) which trainer family to run
        --model_desc  optional free-text run description, defaults to ""
        --config      (required) path to the run's config
        --reset_dir   stored as ``clear_dir``; wipes this config's folder
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_kind", dest="model_kind", type=str, required=True)
    cli.add_argument("--model_desc", dest="model_desc", type=str, required=False, default="")
    cli.add_argument("--config", dest="config", type=str, required=True)
    cli.add_argument(
        "--reset_dir",
        dest="clear_dir",
        action="store_true",
        help="deletes everything under this config's folder.",
    )
    return cli


parser = train_parser()
args = parser.parse_args()


# Map each supported --model_kind value to its trainer class.  This replaces
# an if/elif chain in which the "seq2seq" and "tacotron_based" branches were
# duplicated verbatim; both intentionally share Seq2SeqTrainer.
_TRAINER_BY_KIND = {
    "seq2seq": Seq2SeqTrainer,
    "tacotron_based": Seq2SeqTrainer,
    "baseline": CBHGTrainer,
    "cbhg": CBHGTrainer,
    "gpt": GPTTrainer,
}

trainer_cls = _TRAINER_BY_KIND.get(args.model_kind)
if trainer_cls is None:
    raise ValueError("The model kind is not supported")

trainer = trainer_cls(args.config, args.model_kind, args.model_desc)
trainer.run()