Spaces:
Runtime error
Runtime error
import math | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.nn import Linear | |
from modules.commons.conv import ConvBlocks, ConditionalConvBlocks | |
from modules.commons.layers import Embedding | |
from modules.commons.rel_transformer import RelTransformerEncoder | |
from modules.commons.transformer import MultiheadAttention, FFTBlocks | |
from modules.tts.commons.align_ops import clip_mel2token_to_multiple, build_word_mask, expand_states, mel2ph_to_mel2word | |
from modules.tts.fs import FS_DECODERS, FastSpeech | |
from modules.tts.portaspeech.fvae import FVAE | |
from utils.commons.meters import Timer | |
from utils.nn.seq_utils import group_hidden_by_segs | |
class SinusoidalPosEmb(nn.Module):
    """Map scalar (possibly fractional) positions to sinusoidal features.

    Every input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies, from 1 down to 1/10000. The sines of those products are
    followed by their cosines along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        :param x: [B, T] positions
        :return: [B, T, 2 * (dim // 2)] -- sine half first, then cosine half
        """
        n_freq = self.dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = math.log(10000) / (n_freq - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freq, device=x.device))
        angles = x[:, :, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class PortaSpeech(FastSpeech):
    """FastSpeech variant with word-level text encoding and an optional FVAE decoder.

    ``__init__`` decides from ``hparams`` which sub-modules exist:

    * ``use_word_encoder`` -- a word-token encoder whose states are added to
      the phoneme states.
    * ``dur_level == 'word'`` -- durations are predicted per word, and frames
      attend to the phonemes of their own word (see ``attention``).
      Otherwise phoneme states are expanded with ``mel2ph``.
    * ``use_fvae`` -- the base decoder and ``mel_out`` are replaced by ``FVAE``.
    * ``use_pitch_embed`` / ``add_word_pos`` -- optional additive embeddings.
    """

    def __init__(self, ph_dict_size, word_dict_size, hparams, out_dims=None):
        """Build the sub-modules selected by ``hparams``.

        :param ph_dict_size: phoneme vocabulary size (forwarded to FastSpeech).
        :param word_dict_size: vocabulary size of ``self.word_encoder``.
        :param hparams: configuration dict.
        :param out_dims: output feature size (forwarded to FastSpeech).
        :raises ValueError: if ``dur_level == 'word'`` and ``word_encoder_type``
            is neither ``'rel_fft'`` nor ``'fft'``.
        """
        super().__init__(ph_dict_size, hparams, out_dims)
        # build linguistic encoder
        if hparams['use_word_encoder']:
            self.word_encoder = RelTransformerEncoder(
                word_dict_size, self.hidden_size, self.hidden_size, self.hidden_size, 2,
                hparams['word_enc_layers'], hparams['enc_ffn_kernel_size'])
        if hparams['dur_level'] == 'word':
            # ph2word_encoder is applied to phoneme states pooled per word
            # (run_text_encoder), not to token ids -- hence vocab size 0 below.
            if hparams['word_encoder_type'] == 'rel_fft':
                self.ph2word_encoder = RelTransformerEncoder(
                    0, self.hidden_size, self.hidden_size, self.hidden_size, 2,
                    hparams['word_enc_layers'], hparams['enc_ffn_kernel_size'])
            elif hparams['word_encoder_type'] == 'fft':
                self.ph2word_encoder = FFTBlocks(
                    self.hidden_size, hparams['word_enc_layers'], 1, num_heads=hparams['num_heads'])
            else:
                # Fail here rather than with an AttributeError on the first
                # forward pass, which needs self.ph2word_encoder.
                raise ValueError(
                    "Unsupported word_encoder_type for dur_level='word': "
                    f"{hparams['word_encoder_type']!r}")
            # Modules used by get_pos_embed / attention.
            self.sin_pos = SinusoidalPosEmb(self.hidden_size)
            self.enc_pos_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.dec_query_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.dec_res_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.attn = MultiheadAttention(self.hidden_size, 1, encoder_decoder_attention=True, bias=False)
            self.attn.enable_torch_version = False
            if hparams['text_encoder_postnet']:
                self.text_encoder_postnet = ConvBlocks(
                    self.hidden_size, self.hidden_size, [1] * 3, 5, layers_in_block=2)
        else:
            # Still needed by get_pos_embed when add_word_pos is enabled.
            self.sin_pos = SinusoidalPosEmb(self.hidden_size)
        # build VAE decoder
        if hparams['use_fvae']:
            del self.decoder
            del self.mel_out
            self.fvae = FVAE(
                c_in_out=self.out_dims,
                hidden_size=hparams['fvae_enc_dec_hidden'], c_latent=hparams['latent_size'],
                kernel_size=hparams['fvae_kernel_size'],
                enc_n_layers=hparams['fvae_enc_n_layers'],
                dec_n_layers=hparams['fvae_dec_n_layers'],
                c_cond=self.hidden_size,
                use_prior_flow=hparams['use_prior_flow'],
                flow_hidden=hparams['prior_flow_hidden'],
                flow_kernel_size=hparams['prior_flow_kernel_size'],
                flow_n_steps=hparams['prior_flow_n_blocks'],
                strides=[hparams['fvae_strides']],
                encoder_type=hparams['fvae_encoder_type'],
                decoder_type=hparams['fvae_decoder_type'],
            )
        else:
            self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
            self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, 0)
        if self.hparams['add_word_pos']:
            self.word_pos_proj = Linear(self.hidden_size, self.hidden_size)

    def build_embedding(self, dictionary, embed_dim):
        """Return an ``Embedding`` with ``len(dictionary)`` rows and ``self.padding_idx`` as padding."""
        num_embeddings = len(dictionary)
        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
        return emb

    def forward(self, txt_tokens, word_tokens, ph2word, word_len, mel2word=None, mel2ph=None,
                spk_embed=None, spk_id=None, pitch=None, infer=False, tgt_mels=None,
                global_step=None, *args, **kwargs):
        """Run text encoding, style/pitch conditioning and decoding.

        :param txt_tokens: [B, T_ph] phoneme ids (0 = padding).
        :param word_tokens: word ids for ``self.word_encoder``.
        :param ph2word: [B, T_ph] word id of each phoneme.
        :param word_len: number of word slots (see ``run_text_encoder``).
        :param mel2word: optional ground-truth frame-to-word alignment.
        :param mel2ph: optional ground-truth frame-to-phoneme alignment.
        :param spk_embed: forwarded to ``forward_style_embed``.
        :param spk_id: forwarded to ``forward_style_embed``.
        :param pitch: indices for ``self.pitch_embed``; required when ``use_pitch_embed``.
        :param infer: sample from the FVAE prior instead of encoding ``tgt_mels``.
        :param tgt_mels: [B, T_mel, C] target features; needed when ``use_fvae`` and not ``infer``.
        :param global_step: training step for the FVAE noise schedule (see ``run_decoder``).
        :return: dict with at least ``nonpadding``, ``decoder_inp``, ``mel_out``
            and ``mel_out_fvae`` (the last two are the same tensor), plus keys
            written by the helpers (``dur``, ``attn``, ``kl``, ...).
        """
        ret = {}
        x, tgt_nonpadding = self.run_text_encoder(
            txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, ret)
        style_embed = self.forward_style_embed(spk_embed, spk_id)
        x = x + style_embed
        x = x * tgt_nonpadding
        ret['nonpadding'] = tgt_nonpadding
        if self.hparams['use_pitch_embed']:
            x = x + self.pitch_embed(pitch)
        ret['decoder_inp'] = x
        ret['mel_out_fvae'] = ret['mel_out'] = self.run_decoder(x, tgt_nonpadding, ret, infer, tgt_mels, global_step)
        return ret

    def run_text_encoder(self, txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, ret):
        """Encode text and expand it to frame level.

        :param txt_tokens: [B, T_ph] phoneme ids; 0 is treated as padding.
        :param word_tokens: word ids fed to ``self.word_encoder``.
        :param ph2word: [B, T_ph] word id (starting at 1) of every phoneme.
        :param word_len: used as the number of word slots (``torch.arange(word_len)``);
            presumably the max word count in the batch -- TODO confirm with caller.
        :param mel2word: [B, T_mel] frame-to-word alignment, or None to predict it.
        :param mel2ph: frame-to-phoneme alignment, or None to predict it
            (only used when ``dur_level != 'word'``).
        :param ret: output dict; receives ``dur`` and, for word-level durations, ``attn``.
        :return: ``(x, tgt_nonpadding)``: frame-level states and a ``[B, T_mel, 1]``
            float mask that is 1 on non-padding frames.
        """
        # Word ids 1..word_len, shared across the batch: [1, T_word]
        word2word = torch.arange(word_len)[None, :].to(ph2word.device) + 1
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        ph_encoder_out = self.encoder(txt_tokens) * src_nonpadding
        if self.hparams['use_word_encoder']:
            word_encoder_out = self.word_encoder(word_tokens)
            # Add each word's state to all of its phonemes.
            ph_encoder_out = ph_encoder_out + expand_states(word_encoder_out, ph2word)
        if self.hparams['dur_level'] == 'word':
            # Pool phoneme states per word, then contextualise them at word level.
            h_ph_gb_word = group_hidden_by_segs(ph_encoder_out, ph2word, word_len)[0]
            word_states = self.ph2word_encoder(h_ph_gb_word)
            if self.hparams['use_word_encoder']:
                # Reuse the word-encoder output computed above instead of
                # running the encoder a second time on the same tokens.
                word_states = word_states + word_encoder_out
            word_encoder_out = word_states
            mel2word = self.forward_dur(ph_encoder_out, mel2word, ret, ph2word=ph2word, word_len=word_len)
            mel2word = clip_mel2token_to_multiple(mel2word, self.hparams['frames_multiple'])
            tgt_nonpadding = (mel2word > 0).float()[:, :, None]
            enc_pos = self.get_pos_embed(word2word, ph2word)  # [B, T_ph, H]
            dec_pos = self.get_pos_embed(word2word, mel2word)  # [B, T_mel, H]
            dec_word_mask = build_word_mask(mel2word, ph2word)  # [B, T_mel, T_ph]
            x, weight = self.attention(ph_encoder_out, enc_pos, word_encoder_out, dec_pos, mel2word, dec_word_mask)
            if self.hparams['add_word_pos']:
                x = x + self.word_pos_proj(dec_pos)
            ret['attn'] = weight
        else:
            mel2ph = self.forward_dur(ph_encoder_out, mel2ph, ret)
            mel2ph = clip_mel2token_to_multiple(mel2ph, self.hparams['frames_multiple'])
            mel2word = mel2ph_to_mel2word(mel2ph, ph2word)
            x = expand_states(ph_encoder_out, mel2ph)
            if self.hparams['add_word_pos']:
                dec_pos = self.get_pos_embed(word2word, mel2word)  # [B, T_mel, H]
                x = x + self.word_pos_proj(dec_pos)
            tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
            if self.hparams['use_word_encoder']:
                x = x + expand_states(word_encoder_out, mel2word)
        return x, tgt_nonpadding

    def attention(self, ph_encoder_out, enc_pos, word_encoder_out, dec_pos, mel2word, dec_word_mask):
        """Let each frame attend to the phonemes allowed by ``dec_word_mask``.

        :param ph_encoder_out: [B, T_ph, H] phoneme states (keys/values source).
        :param enc_pos: [B, T_ph, H] in-word position embedding of phonemes.
        :param word_encoder_out: word-level states, expanded to frames via
            ``mel2word`` to form the queries (presumably [B, T_word, H]).
        :param dec_pos: [B, T_mel, H] in-word position embedding of frames.
        :param mel2word: [B, T_mel] word id of each frame.
        :param dec_word_mask: [B, T_mel, T_ph]; entries equal to 0 receive a
            -1e9 additive bias, i.e. are effectively masked out.
        :return: ``(x, weight)`` -- [B, T_mel, H] attended states plus a
            residual projection of the queries, and the attention weights.
        """
        ph_kv = self.enc_pos_proj(torch.cat([ph_encoder_out, enc_pos], -1))
        word_enc_out_expand = expand_states(word_encoder_out, mel2word)
        word_enc_out_expand = torch.cat([word_enc_out_expand, dec_pos], -1)
        if self.hparams['text_encoder_postnet']:
            word_enc_out_expand = self.dec_res_proj(word_enc_out_expand)
            word_enc_out_expand = self.text_encoder_postnet(word_enc_out_expand)
            dec_q = x_res = word_enc_out_expand
        else:
            dec_q = self.dec_query_proj(word_enc_out_expand)
            x_res = self.dec_res_proj(word_enc_out_expand)
        # self.attn is called with time-first tensors, hence the transposes.
        ph_kv, dec_q = ph_kv.transpose(0, 1), dec_q.transpose(0, 1)
        x, (weight, _) = self.attn(dec_q, ph_kv, ph_kv, attn_mask=(1 - dec_word_mask) * -1e9)
        x = x.transpose(0, 1)
        x = x + x_res
        return x, weight

    def run_decoder(self, x, tgt_nonpadding, ret, infer, tgt_mels=None, global_step=0):
        """Decode frame-level states into output features.

        :param x: [B, T_mel, H] decoder input.
        :param tgt_nonpadding: [B, T_mel, 1] float mask of non-padding frames.
        :param ret: output dict. Receives ``kl`` (non-FVAE path or FVAE training),
            ``z_p``/``m_q``/``logs_q`` (FVAE training) and ``pre_mel_out`` (FVAE).
        :param infer: FVAE only -- sample from the prior instead of encoding ``tgt_mels``.
        :param tgt_mels: [B, T_mel, C] targets; required for FVAE when ``infer`` is False.
        :param global_step: FVAE training only. While it is below
            ``hparams['posterior_start_steps']`` the FVAE sample is replaced by
            standard-normal noise. ``None`` (the default of ``forward``) means
            "unknown" and skips that replacement instead of raising TypeError.
        :return: [B, T_mel, out_dims] output features.
        :raises ValueError: if FVAE training is requested without ``tgt_mels``.
        """
        if not self.hparams['use_fvae']:
            x = self.decoder(x)
            x = self.mel_out(x)
            ret['kl'] = 0
            return x * tgt_nonpadding
        x = x.transpose(1, 2)  # [B, H, T]
        tgt_nonpadding_BHT = tgt_nonpadding.transpose(1, 2)  # [B, 1, T]
        if infer:
            z = self.fvae(cond=x, infer=True)
        else:
            if tgt_mels is None:
                raise ValueError("tgt_mels is required when use_fvae is enabled and infer=False")
            tgt_mels = tgt_mels.transpose(1, 2)  # [B, C, T]
            z, ret['kl'], ret['z_p'], ret['m_q'], ret['logs_q'] = self.fvae(
                tgt_mels, tgt_nonpadding_BHT, cond=x)
            if global_step is not None and global_step < self.hparams['posterior_start_steps']:
                z = torch.randn_like(z)
        x_recon = self.fvae.decoder(z, nonpadding=tgt_nonpadding_BHT, cond=x).transpose(1, 2)
        ret['pre_mel_out'] = x_recon
        return x_recon

    def forward_dur(self, dur_input, mel2word, ret, **kwargs):
        """Predict durations and, if no alignment is given, derive one from them.

        :param dur_input: [B, T_txt, H] phoneme-level states.
        :param mel2word: ground-truth alignment, or None to build one with
            ``self.length_regulator``. (Callers pass ``mel2ph`` here when
            ``dur_level != 'word'``.)
        :param ret: output dict; ``ret['dur']`` receives the predicted durations
            (summed per word when ``dur_level == 'word'``).
        :param kwargs: when ``dur_level == 'word'``: ``ph2word`` ([B, T_ph] word id
            per phoneme, 0 for padding) and ``word_len`` (number of word slots).
        :return: the alignment -- ``mel2word`` unchanged if given, else the
            detached output of ``self.length_regulator(dur)``.
        """
        # Rows whose features are all exactly zero are treated as padding.
        src_padding = dur_input.detach().abs().sum(-1) == 0
        # Same forward value, but the gradient reaching dur_input is scaled by predictor_grad.
        dur_input = dur_input.detach() + self.hparams['predictor_grad'] * (dur_input - dur_input.detach())
        dur = self.dur_predictor(dur_input, src_padding)
        if self.hparams['dur_level'] == 'word':
            word_len = kwargs['word_len']
            ph2word = kwargs['ph2word']
            n_words = int(torch.as_tensor(word_len).max())
            # Sum phoneme durations into word slots. Column 0 collects padding
            # phonemes (ph2word == 0) and is dropped. new_zeros matches dur's
            # dtype/device, which scatter_add requires (e.g. under fp16).
            word_dur = dur.new_zeros([ph2word.shape[0], n_words + 1])
            dur = word_dur.scatter_add(1, ph2word, dur)[:, 1:]
        ret['dur'] = dur
        if mel2word is None:
            mel2word = self.length_regulator(dur).detach()
        return mel2word

    def get_pos_embed(self, word2word, x2word):
        """Sinusoidal embedding of each position's relative offset inside its word.

        :param word2word: [1, T_word] word ids 1..T_word.
        :param x2word: [B, T_x] word id of each position (phoneme or frame).
        :return: [B, T_x, H]
        """
        # [B, T_word, T_x]; assumed to be 1 where x2word matches the word id -- TODO confirm
        # against build_word_mask.
        x_pos = build_word_mask(word2word, x2word).float()
        # cumsum over T_x counts 1..len inside each word; dividing by the word
        # length (clamped to avoid /0) gives a fraction in (0, 1]. Summing over
        # T_word collapses the word axis.
        x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
        return self.sin_pos(x_pos)  # [B, T_x, H]

    def store_inverse_all(self):
        """Prepare sub-modules for inference.

        Calls ``store_inverse()`` on every sub-module that has it, then strips
        weight normalisation. A ``ValueError`` raised by either call (e.g. a
        module without weight norm) is ignored for that module.
        """
        def remove_weight_norm(m):
            try:
                if hasattr(m, 'store_inverse'):
                    m.store_inverse()
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)