import numpy as np
import torch
from torch import nn

from modules.commons.layers import Embedding
from modules.commons.nar_tts_modules import EnergyPredictor, PitchPredictor
from modules.tts.commons.align_ops import expand_states
from modules.tts.fs import FastSpeech
from utils.audio.cwt import cwt2f0, get_lf0_cwt
from utils.audio.pitch.utils import denorm_f0, f0_to_coarse, norm_f0
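

# FastSpeech 2 variant with the original paper's explicit variance adaptors:
# an energy predictor feeding a quantized energy embedding and, when
# `pitch_type == 'cwt'`, a pitch predictor that models log-f0 through a
# 10-scale continuous wavelet transform (CWT) plus a voiced/unvoiced channel.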
class FastSpeech2Orig(FastSpeech):
    def __init__(self, dict_size, hparams, out_dims=None):
        super().__init__(dict_size, hparams, out_dims)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        if hparams['use_energy_embed']:
            # 300-way quantized energy embedding (index 0 is the padding idx)
            self.energy_embed = Embedding(300, self.hidden_size, 0)
            self.energy_predictor = EnergyPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=2,
                kernel_size=hparams['predictor_kernel'])
        if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']:
            # 11 output channels: 10 CWT scales of log-f0 + 1 UV channel
            self.pitch_predictor = PitchPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=11,
                kernel_size=hparams['predictor_kernel'])
            # predicts utterance-level [mean, std] of log-f0 from the first
            # encoder state, used to de-normalize the CWT reconstruction
            self.cwt_stats_layers = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 2))
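
    # Forward pipeline: encode text, add the style/speaker embedding, predict
    # durations and expand phoneme states to frame level (mel2ph), add pitch
    # and energy embeddings, optionally concatenate a noise vector z, then
    # decode to a mel-spectrogram. Intermediate predictions are collected in
    # the `ret` dict.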
    def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
                f0=None, uv=None, energy=None, infer=False, **kwargs):
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        style_embed = self.forward_style_embed(spk_embed, spk_id)

        # add dur
        dur_inp = (encoder_out + style_embed) * src_nonpadding
        mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)

        # add pitch embed
        if self.hparams['use_pitch_embed']:
            pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)

        # add energy embed
        if self.hparams['use_energy_embed']:
            energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)

        # decoder input
        ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
        if self.hparams['dec_inp_add_noise']:
            B, T, _ = decoder_inp.shape
            z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device)
            ret['adv_z'] = z
            decoder_inp = torch.cat([decoder_inp, z], -1)
            decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
        ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret

    def forward_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        if self.hparams['pitch_type'] == 'cwt':
            # keep the forward value but scale the gradient reaching the
            # decoder input by predictor_grad (see the sketch at the end of
            # this file)
            decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
            pitch_padding = mel2ph == 0
            ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp)
            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
            mean = ret['f0_mean'] = stats_out[:, 0]
            std = ret['f0_std'] = stats_out[:, 1]
            cwt_spec = cwt_out[:, :, :10]
            if f0 is None:  # no ground-truth f0: reconstruct it from the predicted CWT
                std = std * self.hparams['cwt_std_scale']
                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
                if self.hparams['use_uv']:
                    assert cwt_out.shape[-1] == 11
                    uv = cwt_out[:, :, -1] > 0  # last channel predicts voiced/unvoiced
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv if self.hparams['use_uv'] else None,
                                                     pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0
            pitch_embed = self.pitch_embed(pitch)
            return pitch_embed
        else:
            return super(FastSpeech2Orig, self).forward_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out)

    def forward_energy(self, decoder_inp, energy, ret):
        decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        # use ground-truth energy when provided (training), else the prediction
        energy_embed_inp = energy_pred if energy is None else energy
        # quantize energy (roughly [0, 4)) into 256 embedding buckets
        energy_embed_inp = torch.clamp(energy_embed_inp * 256 // 4, min=0, max=255).long()
        energy_embed = self.energy_embed(energy_embed_inp)
        return energy_embed

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        # recover log-f0 from the 10-scale CWT spectrum, pad to the mel length
        # by repeating the last frame, then re-normalize
        _, cwt_scales = get_lf0_cwt(np.ones(10))
        f0 = cwt2f0(cwt_spec, mean, std, cwt_scales)
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None)
        return f0_norm
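

# ---------------------------------------------------------------------------
# Illustrative sketches (assumptions, not part of the original module).
#
# `expand_states` (imported above) expands phoneme-level states to frame
# level via `mel2ph`, whose entries index phonemes 1-based, with 0 reserved
# for padding frames. A minimal gather-based equivalent could look like this
# hypothetical helper:
def _expand_states_sketch(h, mel2ph):
    # h: [B, T_txt, C]; mel2ph: [B, T_mel] with values in [0, T_txt]
    h = torch.cat([h.new_zeros(h.shape[0], 1, h.shape[2]), h], 1)  # row 0 = pad state
    idx = mel2ph[:, :, None].expand(-1, -1, h.shape[-1])
    return torch.gather(h, 1, idx)  # [B, T_mel, C]


if __name__ == '__main__':
    # The variance predictors above consume
    #     x.detach() + g * (x - x.detach())
    # which equals x in the forward pass but scales the gradient reaching x
    # by g (= hparams['predictor_grad']), limiting how strongly the pitch and
    # energy losses perturb the shared encoder. Quick check:
    x = torch.ones(3, requires_grad=True)
    g = 0.1
    y = x.detach() + g * (x - x.detach())
    y.sum().backward()
    print(x.grad)  # tensor([0.1000, 0.1000, 0.1000]) -> gradient scaled by g

    # Expansion sketch: phoneme 1 covers frames 0-1, phoneme 2 covers frames 2-3.
    h = torch.tensor([[[1., 1.], [2., 2.]]])  # [1, 2, 2]
    mel2ph = torch.tensor([[1, 1, 2, 2]])     # [1, 4]
    print(_expand_states_sketch(h, mel2ph))   # frames take rows 1, 1, 2, 2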