# File metadata captured from the hosting page:
#   Author: Ababababababbababa
#   Duplicated from: arbml/Ashaar
#   Commit: 6faf7e7
#   Size: 1.42 kB
#   Page labels: raw | history | blame | "No virus"
import argparse
import random
import numpy as np
import torch
from trainer import CBHGTrainer, Seq2SeqTrainer, GPTTrainer
# Fixed seed shared by every random number generator the training code may touch.
SEED = 1234


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + current CUDA device) RNGs with *seed*.

    Also sets cuDNN to deterministic mode and turns off its benchmark
    autotuning. Per PyTorch's reproducibility notes, benchmark autotuning can
    pick different kernels from run to run.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# Runs at import time so every later import/run of the trainers sees seeded RNGs.
_seed_everything(SEED)
def train_parser():
    """Build the command-line parser for this training script.

    Options:
        --model_kind (required): model family name, stored as ``model_kind``.
        --model_desc (optional): free-form string, defaults to ``""``.
        --config (required): config value, stored as ``config``.
        --reset_dir (switch): stored on the namespace as ``clear_dir``
            (``False`` unless the flag is given).

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    # (flag, add_argument keyword options), in the order they appear in --help.
    option_specs = (
        ("--model_kind", dict(dest="model_kind", type=str, required=True)),
        ("--model_desc", dict(dest="model_desc", type=str, required=False, default="")),
        ("--config", dict(dest="config", type=str, required=True)),
        (
            "--reset_dir",
            # Note the dest differs from the flag name: callers read args.clear_dir.
            dict(
                dest="clear_dir",
                action="store_true",
                help="deletes everything under this config's folder.",
            ),
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
if __name__ == "__main__":
    # Guard the entry point. Without it, any re-import of this module re-parses
    # argv and launches another training run. That includes the re-import done
    # by multiprocessing's "spawn" start method, e.g. DataLoader worker
    # processes on spawn-based platforms.
    parser = train_parser()
    args = parser.parse_args()

    # Each supported --model_kind mapped to the trainer class that handles it.
    # Several kinds share one trainer implementation.
    trainer_classes = {
        "seq2seq": Seq2SeqTrainer,
        "tacotron_based": Seq2SeqTrainer,
        "baseline": CBHGTrainer,
        "cbhg": CBHGTrainer,
        "gpt": GPTTrainer,
    }

    trainer_class = trainer_classes.get(args.model_kind)
    if trainer_class is None:
        raise ValueError(
            f"The model kind {args.model_kind!r} is not supported; "
            f"expected one of {sorted(trainer_classes)}"
        )

    # NOTE(review): args.clear_dir (set by --reset_dir) is parsed but never
    # forwarded here, so the flag currently has no visible effect. Confirm
    # whether the trainers are meant to receive it.
    trainer = trainer_class(args.config, args.model_kind, args.model_desc)
    trainer.run()