Spaces:
Sleeping
Sleeping
File size: 1,419 Bytes
6faf7e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import argparse
import random
import numpy as np
import torch
from trainer import CBHGTrainer, Seq2SeqTrainer, GPTTrainer
# Global seed shared by every RNG the training run may touch, so that runs
# with the same config are reproducible.
SEED = 1234

# Seed Python's `random`, NumPy's global RNG, torch's default generator and
# torch's current CUDA device generator, in that order.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# Trade cuDNN autotuning speed for deterministic kernel selection.
_cudnn = torch.backends.cudnn
_cudnn.deterministic = True
_cudnn.benchmark = False
del _cudnn
def train_parser():
    """Build the command-line parser for the training entry point.

    Returns:
        argparse.ArgumentParser: a parser producing ``model_kind``,
        ``model_desc``, ``config`` and ``clear_dir`` (set by ``--reset_dir``).
    """
    # Kinds the launcher knows how to dispatch; rejecting anything else here
    # gives the user a usage message instead of a traceback after parsing.
    model_kinds = ("seq2seq", "tacotron_based", "baseline", "cbhg", "gpt")

    parser = argparse.ArgumentParser(
        description="Build the trainer for the selected model kind and run it."
    )
    parser.add_argument(
        "--model_kind",
        dest="model_kind",
        type=str,
        required=True,
        choices=model_kinds,
        help="Which model/trainer to run.",
    )
    parser.add_argument(
        "--model_desc",
        dest="model_desc",
        type=str,
        required=False,
        default="",
        help="Optional free-form description forwarded to the trainer.",
    )
    parser.add_argument(
        "--config",
        dest="config",
        type=str,
        required=True,
        help="Configuration forwarded to the trainer.",
    )
    parser.add_argument(
        "--reset_dir",
        dest="clear_dir",
        action="store_true",
        help="deletes everything under this config's folder.",
    )
    return parser
if __name__ == "__main__":
    # Only parse the command line and start training when run as a script,
    # not when this module is imported.
    parser = train_parser()
    args = parser.parse_args()

    # Each supported --model_kind mapped to the trainer class that handles it.
    # seq2seq/tacotron_based share a trainer, as do baseline/cbhg.
    trainer_classes = {
        "seq2seq": Seq2SeqTrainer,
        "tacotron_based": Seq2SeqTrainer,
        "baseline": CBHGTrainer,
        "cbhg": CBHGTrainer,
        "gpt": GPTTrainer,
    }
    try:
        trainer_cls = trainer_classes[args.model_kind]
    except KeyError:
        raise ValueError(
            f"The model kind is not supported: {args.model_kind!r}"
        ) from None

    # NOTE(review): --reset_dir is parsed into args.clear_dir but is never
    # used here; confirm whether the trainer is meant to receive it.
    trainer = trainer_cls(args.config, args.model_kind, args.model_desc)
    trainer.run()
|