Spaces:
Paused
Paused
import os | |
import random | |
import unittest | |
from copy import deepcopy | |
import torch | |
from tests import get_tests_output_path | |
from TTS.tts.configs.overflow_config import OverflowConfig | |
from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils | |
from TTS.tts.layers.overflow.decoder import Decoder | |
from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel | |
from TTS.tts.models.overflow import Overflow | |
from TTS.tts.utils.helpers import sequence_mask | |
from TTS.utils.audio import AudioProcessor | |
# pylint: disable=unused-variable

# Seed torch so the tensor fixtures built by the tests are reproducible between runs.
torch.manual_seed(1)

# Use the first GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Config and audio processor shared by every test case in this module.
config_global = OverflowConfig(num_chars=24)
ap = AudioProcessor.init_from_config(config_global)

# Paths under the shared tests output directory.
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")

# Write a small statistics file; the tests point `mel_statistics_parameter_path` at it
# (presumably the model reads mean/std/initial transition probability from it — confirm in Overflow).
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
def _create_inputs(batch_size=8):
    """Build a random, zero-padded (text, mel) batch on ``device``.

    Padded lengths are drawn with ``random`` (text in [25, 50], mel in [50, 80]). Per-item lengths
    are sorted in descending order and the first item always spans the full padded length.
    Positions past each item's length are zeroed.

    Args:
        batch_size: number of items in the batch.

    Returns:
        ``(input_dummy, input_lengths, mel_spec, mel_lengths)`` where ``input_dummy`` holds token ids
        in ``[0, 24)`` and ``mel_spec`` is ``(batch, frames, num_mels)``.
    """
    max_text_len = random.randint(25, 50)
    max_mel_len = random.randint(50, 80)

    # Text side: random token ids, masked past each item's length.
    tokens = torch.randint(0, 24, (batch_size, max_text_len)).long().to(device)
    text_lens, _ = torch.randint(20, max_text_len, (batch_size,)).long().to(device).sort(descending=True)
    text_lens[0] = max_text_len
    tokens = tokens * sequence_mask(text_lens)

    # Mel side: gaussian frames, zeroed past each item's length (mask broadcast over the mel axis).
    num_mels = config_global.audio["num_mels"]
    mels = torch.randn(batch_size, max_mel_len, num_mels).to(device)
    mel_lens, _ = torch.randint(40, max_mel_len, (batch_size,)).long().to(device).sort(descending=True)
    mel_lens[0] = max_mel_len
    mels = mels * sequence_mask(mel_lens).unsqueeze(2)

    return tokens, text_lens, mels, mel_lens
def get_model(config=None):
    """Instantiate an ``Overflow`` model and move it to ``device``.

    Args:
        config: Overflow config to build from; ``config_global`` is used when omitted.
            NOTE: the config object is modified in place — its ``mel_statistics_parameter_path``
            is set to the module's ``parameter_path``.

    Returns:
        The constructed ``Overflow`` model.
    """
    cfg = config_global if config is None else config
    cfg.mel_statistics_parameter_path = parameter_path
    return Overflow(cfg).to(device)
def reset_all_weights(model):
    """Re-initialise, in place, every submodule of ``model`` that defines ``reset_parameters``.

    Submodules without a callable ``reset_parameters`` are left untouched. Traversal uses
    ``torch.nn.Module.apply``, which visits children before the module itself.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    def _reset_if_supported(module):
        # Only modules exposing a callable `reset_parameters` know how to re-initialise themselves.
        reset_fn = getattr(module, "reset_parameters", None)
        if not callable(reset_fn):
            return
        reset_fn()

    model.apply(fn=_reset_if_supported)
class TestOverflow(unittest.TestCase):
    """End-to-end checks of the ``Overflow`` model's forward, inference and construction paths."""

    def test_forward(self):
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        outputs = model(input_dummy, input_lengths, mel_spec, mel_lengths)
        # One log-probability per item in the batch.
        self.assertEqual(outputs["log_probs"].shape, (input_dummy.shape[0],))
        # The alignment's last axis spans `state_per_phone` states for every token of the longest input.
        self.assertEqual(model.state_per_phone * int(input_lengths.max()), outputs["alignments"].shape[2])

    def test_inference(self):
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        output_dict = model.inference(input_dummy)
        self.assertEqual(output_dict["model_outputs"].shape[2], config_global.out_channels)

    def test_init_from_config(self):
        config = deepcopy(config_global)
        config.mel_statistics_parameter_path = parameter_path
        config.prenet_dim = 256
        # Build from the modified copy, not `config_global`: otherwise the assertion would not test
        # the config prepared above. It would also silently depend on `get_model()` having already
        # mutated `config_global.mel_statistics_parameter_path` in an earlier test.
        model = Overflow.init_from_config(config)
        self.assertEqual(model.prenet_dim, config.prenet_dim)
class TestOverflowEncoder(unittest.TestCase):
    """Checks that the encoder expands its time axis by ``state_per_phone``."""

    @staticmethod
    def get_encoder(state_per_phone):
        """Build an ``Encoder`` on ``device`` from a copy of ``config_global``.

        Args:
            state_per_phone: number of states each input token is expanded to.
        """
        # Must be a staticmethod: it takes no `self`, and the tests call it as `self.get_encoder(s_p_p)`.
        config = deepcopy(config_global)
        config.state_per_phone = state_per_phone
        config.num_chars = 24
        return Encoder(config.num_chars, config.state_per_phone, config.prenet_dim, config.encoder_n_convolutions).to(
            device
        )

    def test_forward_with_state_per_phone_multiplication(self):
        for s_p_p in [1, 2, 3]:
            input_dummy, input_lengths, _, _ = _create_inputs()
            model = self.get_encoder(s_p_p)
            x, x_len = model(input_dummy, input_lengths)
            self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p)

    def test_inference_with_state_per_phone_multiplication(self):
        for s_p_p in [1, 2, 3]:
            input_dummy, input_lengths, _, _ = _create_inputs()
            model = self.get_encoder(s_p_p)
            x, x_len = model.inference(input_dummy, input_lengths)
            self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p)
class TestOverflowUtils(unittest.TestCase):
    """Checks ``OverflowUtils.logsumexp`` against ``torch.logsumexp``."""

    def test_logsumexp(self):
        cases = {
            "random numbers": torch.randn(10),
            "all zeros": torch.zeros(10),
            "all ones": torch.ones(10),
        }
        for name, a in cases.items():
            with self.subTest(name):
                # Two independent implementations of the same reduction: compare within float
                # tolerance rather than bit-for-bit, which depends on summation order and device.
                self.assertTrue(torch.allclose(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)))
class TestOverflowDecoder(unittest.TestCase):
    """Checks that the flow ``Decoder`` is invertible (forward followed by reverse recovers the input)."""

    @staticmethod
    def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True):
        """Build a ``Decoder`` on ``device`` from a copy of ``config_global``.

        Args:
            num_flow_blocks_dec: overrides ``config.num_flow_blocks_dec`` when not ``None``.
            hidden_channels_dec: overrides ``config.hidden_channels_dec`` when not ``None``.
            reset_weights: re-initialise every submodule via ``reset_all_weights`` after construction.
        """
        # Must be a staticmethod: it takes no `self`, and the test calls it as `self._get_decoder(...)`.
        config = deepcopy(config_global)
        config.num_flow_blocks_dec = (
            num_flow_blocks_dec if num_flow_blocks_dec is not None else config.num_flow_blocks_dec
        )
        config.hidden_channels_dec = (
            hidden_channels_dec if hidden_channels_dec is not None else config.hidden_channels_dec
        )
        config.dropout_p_dec = 0.0  # turn off dropout to check invertibility
        decoder = Decoder(
            config.out_channels,
            config.hidden_channels_dec,
            config.kernel_size_dec,
            config.dilation_rate,
            config.num_flow_blocks_dec,
            config.num_block_layers,
            config.dropout_p_dec,
            config.num_splits,
            config.num_squeeze,
            config.sigmoid_scale,
            config.c_in_channels,
        ).to(device)
        if reset_weights:
            reset_all_weights(decoder)
        return decoder

    def test_decoder_forward_backward(self):
        for num_flow_blocks_dec in [8, None]:
            for hidden_channels_dec in [100, None]:
                with self.subTest(num_flow_blocks_dec=num_flow_blocks_dec, hidden_channels_dec=hidden_channels_dec):
                    decoder = self._get_decoder(num_flow_blocks_dec, hidden_channels_dec)
                    _, _, mel_spec, mel_lengths = _create_inputs()
                    z, z_len, _ = decoder(mel_spec.transpose(1, 2), mel_lengths)
                    mel_spec_, mel_lengths_, _ = decoder(z, z_len, reverse=True)
                    mask = sequence_mask(z_len).unsqueeze(1)
                    # The forward pass may return fewer frames than it received (z.shape[2]); crop the
                    # reference to that length and zero its padding before comparing.
                    mel_spec = mel_spec[:, : z.shape[2], :].transpose(1, 2) * mask
                    self.assertTrue(
                        torch.isclose(mel_spec, mel_spec_, atol=1e-2).all(),
                        f"num_flow_blocks_dec={num_flow_blocks_dec}, hidden_channels_dec={hidden_channels_dec}",
                    )
class TestNeuralHMM(unittest.TestCase):
    """Unit tests for ``NeuralHMM`` and its ``EmissionModel`` / ``TransitionModel`` components."""

    # The factory helpers below take no `self` and are called as `self._get_...()`, so they must be
    # staticmethods; otherwise the implicit `self` argument causes a TypeError.

    @staticmethod
    def _get_neural_hmm(deterministic_transition=None):
        """Build a ``NeuralHMM`` on ``device`` from a copy of ``config_global``.

        Args:
            deterministic_transition: overrides ``config.deterministic_transition`` when not ``None``.
        """
        config = deepcopy(config_global)
        neural_hmm = NeuralHMM(
            config.out_channels,
            config.ar_order,
            config.deterministic_transition if deterministic_transition is None else deterministic_transition,
            config.encoder_in_out_features,
            config.prenet_type,
            config.prenet_dim,
            config.prenet_n_layers,
            config.prenet_dropout,
            config.prenet_dropout_at_inference,
            config.memory_rnn_dim,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return neural_hmm

    @staticmethod
    def _get_emission_model():
        return EmissionModel().to(device)

    @staticmethod
    def _get_transition_model():
        return TransitionModel().to(device)

    @staticmethod
    def _get_embedded_input():
        """Return a random batch whose token ids are replaced by embeddings.

        The embeddings have size ``encoder_in_out_features`` and come from a freshly initialised
        ``nn.Embedding``.
        """
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        return input_dummy, input_lengths, mel_spec, mel_lengths

    def test_neural_hmm_forward(self):
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        self.assertEqual(log_prob.shape, (input_dummy.shape[0],))
        self.assertEqual(log_alpha_scaled.shape, transition_matrix.shape)

    def test_mask_lengths(self):
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        log_c = torch.randn(mel_spec.shape[0], mel_spec.shape[1], device=device)
        log_c, log_alpha_scaled = neural_hmm._mask_lengths(  # pylint: disable=protected-access
            mel_lengths, log_c, log_alpha_scaled
        )
        # Everything past each item's mel length must be zeroed.
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_c[i, mel_lengths[i] :].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking")
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_alpha_scaled[i, mel_lengths[i] :, : input_lengths[i]].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking")

    def test_process_ar_timestep(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        h_post_prenet, c_post_prenet = model._init_lstm_states(  # pylint: disable=protected-access
            input_dummy.shape[0], config_global.memory_rnn_dim, mel_spec
        )
        h_post_prenet, c_post_prenet = model._process_ar_timestep(  # pylint: disable=protected-access
            1,
            mel_spec,
            h_post_prenet,
            c_post_prenet,
        )
        self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))
        self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))

    def test_add_go_token(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        out = model._add_go_token(mel_spec)  # pylint: disable=protected-access
        self.assertEqual(out.shape, mel_spec.shape)
        # The output is the input shifted right by one frame along dim 1.
        self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not appended properly")

    def test_forward_algorithm_variables(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(  # pylint: disable=protected-access
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1]))
        self.assertEqual(
            log_alpha_scaled.shape,
            (
                mel_spec.shape[0],
                mel_spec.shape[1],
                input_dummy.shape[1] * config_global.state_per_phone,
            ),
        )
        self.assertEqual(
            transition_matrix.shape,
            (mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone),
        )

    def test_get_absorption_state_scaling_factor(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        input_lengths = input_lengths * config_global.state_per_phone
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(  # pylint: disable=protected-access
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3)
        transition_matrix = torch.randn_like(transition_matrix).sigmoid().log()
        sum_final_log_c = model.get_absorption_state_scaling_factor(
            mel_lengths, log_alpha_scaled, input_lengths, transition_matrix
        )

        # Reference computation, one item at a time, at each item's last mel frame.
        text_mask = ~sequence_mask(input_lengths)
        transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device)
        outputs = []
        for i in range(input_dummy.shape[0]):
            last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf"))
            log_last_transition_probability = OverflowUtils.log_clamped(
                torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1])
            ).masked_fill(transition_prob_mask[i], -float("inf"))
            outputs.append(last_log_alpha_scaled + log_last_transition_probability)
        sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1)
        self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all())

    def test_inference(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        for temp in [0.334, 0.667, 1.0]:
            outputs = model.inference(
                input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold
            )
            self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1])
            self.assertEqual(
                outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]
            )
            self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0])

    def test_emission_model(self):
        model = self._get_emission_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device)
        means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device)
        std = torch.rand_like(means).to(device).clamp_(1e-3)  # std should be positive
        out = model(x_t, means, std, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1]))

        # Sampling at each temperature. Pass the loop's `temp` (previously a hard-coded 0, so the
        # non-zero temperatures were never exercised); at temperature 0 the means must come back exactly.
        for temp in [0, 0.334, 0.667]:
            out = model.sample(means, std, temp)
            self.assertEqual(out.shape, means.shape)
            if temp == 0:
                self.assertTrue(torch.isclose(out, means).all())

    def test_transition_model(self):
        model = self._get_transition_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device)
        transition_vector = torch.randn(input_lengths.max()).to(device)
        out = model(prev_t_log_scaled_alph, transition_vector, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max()))
class TestOverflowOutputNet(unittest.TestCase):
    """Checks that a freshly built ``Outputnet`` emits its configured flat-start parameters."""

    # The helpers take no `self` and are called as `self._get_...()`, so they must be staticmethods.

    @staticmethod
    def _get_outputnet():
        """Build an ``Outputnet`` on ``device`` from a copy of ``config_global``."""
        config = deepcopy(config_global)
        outputnet = Outputnet(
            config.encoder_in_out_features,
            config.memory_rnn_dim,
            config.out_channels,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return outputnet

    @staticmethod
    def _get_embedded_input():
        """Return ``(embedded_tokens, one_timestep_frame)``.

        ``embedded_tokens`` comes from a freshly initialised ``nn.Embedding``. ``one_timestep_frame``
        is a random ``(batch, memory_rnn_dim)`` tensor.
        """
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        one_timestep_frame = torch.randn(input_dummy.shape[0], config_global.memory_rnn_dim).to(device)
        return input_dummy, one_timestep_frame

    def test_outputnet_forward_with_flat_start(self):
        model = self._get_outputnet()
        input_dummy, one_timestep_frame = self._get_embedded_input()
        mean, std, transition_vector = model(one_timestep_frame, input_dummy)
        self.assertTrue(torch.isclose(mean, torch.tensor(model.flat_start_params["mean"] * 1.0)).all())
        self.assertTrue(torch.isclose(std, torch.tensor(model.flat_start_params["std"] * 1.0)).all())
        self.assertTrue(
            torch.isclose(
                transition_vector.sigmoid(), torch.tensor(model.flat_start_params["transition_p"] * 1.0)
            ).all()
        )