Spaces:
Runtime error
Runtime error
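"""
Utility functions for loading Tacotron and WaveGRU configurations and
training checkpoints, and for constructing a Tacotron model.
"""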
import pickle | |
from pathlib import Path | |
import pax | |
import toml | |
import yaml | |
from tacotron import Tacotron | |
def load_tacotron_config(config_file=Path("tacotron.toml")):
    """Read the ``tacotron`` table from a TOML configuration file.

    Args:
        config_file: path passed straight to ``toml.load``; by default the
            relative path ``tacotron.toml``.

    Returns:
        The value stored under the top-level ``"tacotron"`` key.
    """
    parsed = toml.load(config_file)
    section = parsed["tacotron"]
    return section
def load_tacotron_ckpt(
    net: "pax.Module | None",
    optim: "pax.Module | None",
    path,
):
    """Load a Tacotron training checkpoint from disk.

    The file is read with ``pickle.load`` and indexed as a dict. It must
    contain ``"step"``. It must also contain ``"model_state_dict"`` when
    ``net`` is given, and ``"optim_state_dict"`` when ``optim`` is given.

    Args:
        net: model to restore, or ``None`` to skip restoring it.
        optim: optimizer to restore, or ``None`` to skip restoring it.
        path: location of the checkpoint file.

    Returns:
        ``(step, net, optim)``. ``net`` and ``optim`` are the values returned
        by their ``load_state_dict`` calls, or ``None`` if ``None`` was
        passed in.

    Raises:
        KeyError: if a required key is missing from the checkpoint.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only load checkpoints
        from trusted sources.
    """
    # The annotations are strings so the signature admits the ``None`` that
    # the body handles, without evaluating ``pax`` at definition time.
    with open(path, "rb") as f:
        dic = pickle.load(f)
    # Keep the return value of load_state_dict instead of discarding it;
    # the code does not assume the object is updated in place.
    if net is not None:
        net = net.load_state_dict(dic["model_state_dict"])
    if optim is not None:
        optim = optim.load_state_dict(dic["optim_state_dict"])
    return dic["step"], net, optim
def create_tacotron_model(config):
    """Construct a randomly initialized Tacotron model from a config mapping.

    Each ``Tacotron`` keyword argument is read from ``config`` under the
    upper-cased form of its name (``mel_dim`` -> ``MEL_DIM``, ``rr`` ->
    ``RR``, and so on). A missing entry raises ``KeyError``.
    """
    # Listed in the same order as the original explicit keyword list, so
    # config lookups happen in an unchanged order.
    hyperparams = (
        "mel_dim",
        "attn_bias",
        "rr",
        "max_rr",
        "mel_min",
        "sigmoid_noise",
        "pad_token",
        "prenet_dim",
        "attn_hidden_dim",
        "attn_rnn_dim",
        "rnn_dim",
        "postnet_dim",
        "text_dim",
    )
    kwargs = {name: config[name.upper()] for name in hyperparams}
    return Tacotron(**kwargs)
def load_wavegru_config(config_file):
    """Parse a YAML configuration file for WaveGRU.

    The file is decoded as UTF-8 and parsed with ``yaml.safe_load``. The
    parsed document is returned unchanged.
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        settings = yaml.safe_load(stream)
    return settings
def load_wavegru_ckpt(net, optim, ckpt_file):
    """Restore WaveGRU training state from a pickled checkpoint file.

    The unpickled object is indexed as a dict with keys ``"step"``,
    ``"net_state_dict"`` (read only when ``net`` is given) and
    ``"optim_state_dict"`` (read only when ``optim`` is given).

    Returns:
        ``(step, net, optim)``. The last two are the results of
        ``load_state_dict``, or ``None`` when ``None`` was passed in.

    Note: ``pickle.load`` can run arbitrary code, so only pass trusted files.
    """
    with open(ckpt_file, "rb") as handle:
        checkpoint = pickle.load(handle)

    restored_net = (
        None if net is None
        else net.load_state_dict(checkpoint["net_state_dict"])
    )
    restored_optim = (
        None if optim is None
        else optim.load_state_dict(checkpoint["optim_state_dict"])
    )
    return checkpoint["step"], restored_net, restored_optim