|
import argparse |
|
import random |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from trainer import CBHGTrainer, Seq2SeqTrainer, GPTTrainer |
|
|
|
SEED = 1234 |
|
random.seed(SEED) |
|
np.random.seed(SEED) |
|
torch.manual_seed(SEED) |
|
torch.cuda.manual_seed(SEED) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
|
|
def train_parser(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--model_kind", dest="model_kind", type=str, required=True) |
|
parser.add_argument( |
|
"--model_desc", dest="model_desc", type=str, required=False, default="" |
|
) |
|
parser.add_argument("--config", dest="config", type=str, required=True) |
|
parser.add_argument( |
|
"--reset_dir", |
|
dest="clear_dir", |
|
action="store_true", |
|
help="deletes everything under this config's folder.", |
|
) |
|
return parser |
|
|
|
|
|
parser = train_parser() |
|
args = parser.parse_args() |
|
|
|
|
|
if args.model_kind in ["seq2seq"]: |
|
trainer = Seq2SeqTrainer(args.config, args.model_kind, args.model_desc) |
|
elif args.model_kind in ["tacotron_based"]: |
|
trainer = Seq2SeqTrainer(args.config, args.model_kind, args.model_desc) |
|
elif args.model_kind in ["baseline", "cbhg"]: |
|
trainer = CBHGTrainer(args.config, args.model_kind, args.model_desc) |
|
elif args.model_kind in ["gpt"]: |
|
trainer = GPTTrainer(args.config, args.model_kind, args.model_desc) |
|
else: |
|
raise ValueError("The model kind is not supported") |
|
|
|
trainer.run() |
|
|