Spaces:
Paused
Paused
File size: 4,646 Bytes
45ee559 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import unittest
import torch
from tests import get_tests_input_path
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC
# pylint: disable=unused-variable
# pylint: disable=no-self-use
# Fix the RNG seed first so every random tensor built below is reproducible.
torch.manual_seed(1)

# Run on the first CUDA device when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Shared test constants.
BATCH_SIZE = 3
c = FreeVCConfig()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
def count_parameters(model):
    r"""Return the total element count of the parameters that require gradients.

    Every item yielded by ``model.parameters()`` whose ``requires_grad`` flag
    is truthy contributes its ``numel()`` to the total; all others are skipped.
    """
    trainable = 0
    for weight in model.parameters():
        if weight.requires_grad:
            trainable += weight.numel()
    return trainable
class TestFreeVC(unittest.TestCase):
    """Smoke tests that run the FreeVC model on randomly generated inputs."""

    def _create_inputs(self, config, batch_size=2):
        """Build a batch of random tensors on ``device`` sized from ``config.audio``.

        Args:
            config: Config object whose ``audio`` entry provides ``hop_length``,
                ``filter_length`` and ``n_mel_channels``.
            batch_size: Size of the leading dimension of every tensor.

        Returns:
            Tuple ``(input_dummy, input_lengths, mel, spec, spec_lengths, waveform)``.
        """
        hop_length = config.audio["hop_length"]
        # NOTE: the creation order below is kept fixed so the seeded RNG stream
        # yields the same values for each tensor.
        input_dummy = torch.rand(batch_size, 30 * hop_length).to(device)
        input_lengths = torch.randint(100, 30 * hop_length, (batch_size,)).long().to(device)
        # Force the last item to span the full dummy input.
        input_lengths[-1] = 30 * hop_length
        spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
        mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        # NOTE(review): spec.size(2) is the last axis of ``spec``
        # (filter_length // 2 + 1), not the 30-long middle axis - confirm this
        # is the intended length here and for ``waveform`` below.
        spec_lengths[-1] = spec.size(2)
        waveform = torch.rand(batch_size, spec.size(2) * hop_length).to(device)
        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform

    @staticmethod
    def _create_inputs_inference():
        """Return two random 1-D tensors of 16000 values: (source_wav, target_wav)."""
        source_wav = torch.rand(16000)
        target_wav = torch.rand(16000)
        return source_wav, target_wav

    @staticmethod
    def _check_parameter_changes(model, model_ref):
        """Assert that every parameter of ``model`` differs from its peer in ``model_ref``.

        Parameters are paired positionally; the check fails on the first pair
        whose tensors are element-wise identical.
        """
        for count, (param, param_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
            assert (
                param != param_ref
            ).any(), f"param {count} with shape {param.shape} not updated!! \n{param}\n{param_ref}"

    def test_methods(self):
        # Exercise the helper methods and pin the WavLM feature shape.
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.load_pretrained_speaker_encoder()
        model.init_multispeaker(config)
        wavlm_feats = model.extract_wavlm_features(torch.rand(1, 16000))
        assert wavlm_feats.shape == (1, 1024, 49), wavlm_feats.shape

    def test_load_audio(self):
        # Loading from a path and re-loading the returned tensor must agree.
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        wav = model.load_audio(WAV_FILE)
        wav2 = model.load_audio(wav)
        # One vectorised comparison instead of iterating the tensor in Python;
        # this also works for tensors with more than one dimension.
        assert torch.allclose(wav, wav2)

    def _test_forward(self, batch_size):
        """Run one training-mode forward pass with ``batch_size`` random items."""
        # create model
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.train()
        print(f" > Num parameters for FreeVC model:{count_parameters(model)}")
        _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
        wavlm_vec = model.extract_wavlm_features(waveform)
        # Created on ``device`` so it sits with the other inputs of the call.
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long, device=device)
        # NOTE(review): ``spec_lengths`` is passed before ``wavlm_vec_lengths``;
        # confirm this matches the parameter order of ``FreeVC.forward``.
        model.forward(wavlm_vec, spec, None, mel, spec_lengths, wavlm_vec_lengths)
        # TODO: assert with training implementation

    def test_forward(self):
        self._test_forward(1)
        self._test_forward(3)

    def _test_inference(self, batch_size):
        """Run inference on ``batch_size`` random items and check the output length."""
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()
        _, _, mel, _, _, waveform = self._create_inputs(config, batch_size)
        wavlm_vec = model.extract_wavlm_features(waveform)
        # Created on ``device`` so it sits with the other inputs of the call.
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long, device=device)
        output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths)
        # The output, counted in hops, must match the last axis of the features.
        assert (
            output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1]
        ), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"

    def test_inference(self):
        self._test_inference(1)
        self._test_inference(3)

    def test_voice_conversion(self):
        # Converted audio is expected to be exactly one hop shorter than the source.
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()
        source_wav, target_wav = self._create_inputs_inference()
        output_wav = model.voice_conversion(source_wav, target_wav)
        assert (
            output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
        ), f"{output_wav.shape} != {source_wav.shape}"

    # The remaining tests are empty placeholders: each has no body to run yet.
    def test_train_step(self):
        ...

    def test_train_eval_log(self):
        ...

    def test_test_run(self):
        ...

    def test_load_checkpoint(self):
        ...

    def test_get_criterion(self):
        ...

    def test_init_from_config(self):
        ...
|