# vta-ldm/audioldm/hifigan/utilities.py
import os
import json
import torch
import numpy as np
import audioldm.hifigan as hifigan
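
# HiFi-GAN generator configuration for the 16 kHz, 64-mel-bin vocoder built by
# get_vocoder() below.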
HIFIGAN_16K_64 = {
"resblock": "1",
"num_gpus": 6,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [5, 4, 2, 2, 2],
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
"upsample_initial_channel": 1024,
"resblock_kernel_sizes": [3, 7, 11],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"segment_size": 8192,
"num_mels": 64,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 160,
"win_size": 1024,
"sampling_rate": 16000,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": None,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1,
},
}
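# Note: the product of upsample_rates (5 * 4 * 2 * 2 * 2 = 160) equals hop_size,
# so the generator emits exactly hop_size waveform samples per mel frame.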


def get_available_checkpoint_keys(model, ckpt):
    """Return the checkpoint entries whose keys and tensor shapes match the model."""
    print("==> Attempt to reload from %s" % ckpt)
    state_dict = torch.load(ckpt)["state_dict"]
current_state_dict = model.state_dict()
new_state_dict = {}
for k in state_dict.keys():
if (
k in current_state_dict.keys()
and current_state_dict[k].size() == state_dict[k].size()
):
new_state_dict[k] = state_dict[k]
else:
print("==> WARNING: Skipping %s" % k)
print(
"%s out of %s keys are matched"
% (len(new_state_dict.keys()), len(state_dict.keys()))
)
return new_state_dict


def get_param_num(model):
    """Return the total number of parameters in `model`."""
    num_param = sum(param.numel() for param in model.parameters())
    return num_param


def get_vocoder(config, device):
    """Build the HiFi-GAN generator and move it to `device`; the passed `config`
    is ignored in favor of the hard-coded HIFIGAN_16K_64 configuration above."""
    config = hifigan.AttrDict(HIFIGAN_16K_64)
    vocoder = hifigan.Generator(config)
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    """Run the vocoder on a batch of mel spectrograms and return int16 waveforms."""
    vocoder.eval()
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
    # Scale float waveforms in [-1, 1] to the 16-bit PCM range.
    wavs = (wavs.cpu().numpy() * 32768).astype("int16")
    if lengths is not None:
        wavs = wavs[:, :lengths]
    return wavs
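

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). It wires together the
    # helpers above on a dummy mel spectrogram of shape (batch, num_mels, frames).
    # get_vocoder() does not load any checkpoint weights here, so the resulting
    # waveform is meaningless noise; the sketch only illustrates tensor shapes
    # and the call order.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocoder = get_vocoder(None, device)
    print("Vocoder parameters:", get_param_num(vocoder))

    dummy_mels = torch.randn(1, HIFIGAN_16K_64["num_mels"], 100, device=device)
    wavs = vocoder_infer(dummy_mels, vocoder)
    # Each waveform has frames * hop_size samples (100 * 160 = 16000) at 16 kHz.
    print("Waveform shape:", wavs.shape)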