|
import matplotlib.pyplot as plt |
|
import json |
|
import torch |
|
import torchaudio |
|
|
|
def configure_args(config, args):
    """Override entries of a nested config dict with command-line arguments.

    Every option below is copied only when its value is not ``None``, except
    ``load_pretrained`` and ``early_stopping``, which are always copied.

    Args:
        config: Nested dict with ``"general"``, ``"preprocess"`` and
            ``"train"`` sections. ``config["train"]["feature_loss"]`` must
            exist when ``args.feature_loss_type`` is set.
        args: Namespace-like object (presumably an ``argparse.Namespace`` --
            confirm against the caller) exposing every attribute read below.

    Returns:
        The ``(config, args)`` pair. ``config`` is mutated in place and is
        returned as the same object.

    Raises:
        AttributeError: If ``args`` lacks one of the expected attributes.
        KeyError: If a required config section is missing.
    """

    def _override(section, keys, convert=None):
        # Copy each non-None ``args.<key>`` into ``config[section][key]``,
        # passing it through ``convert`` first when one is given.
        for key in keys:
            value = getattr(args, key)
            if value is not None:
                config[section][key] = convert(value) if convert else value

    # General options are stored as strings via str().
    _override(
        "general",
        ("stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"),
        convert=str,
    )
    _override("preprocess", ("n_train", "n_val", "n_test"))
    _override("train", ("alpha", "beta", "learning_rate", "epoch"))

    # Copied even when None, so the CLI value always replaces the config value
    # (presumably boolean flags -- confirm against the argument parser).
    for key in ("load_pretrained", "early_stopping"):
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    _override("train", ("pretrained_path",), convert=str)

    return config, args
|
|