import math

import torch
from torch import nn
from torch.nn import Linear

from utils.hparams import hparams
from modules.commons.conv import ConvBlocks, ConditionalConvBlocks
from modules.commons.common_layers import Embedding
from modules.commons.rel_transformer import RelTransformerEncoder
from modules.commons.transformer import MultiheadAttention, FFTBlocks
from modules.commons.align_ops import clip_mel2token_to_multiple, build_word_mask, expand_states, mel2ph_to_mel2word
from modules.tts.fastspeech import FS_DECODERS, FastSpeech
from modules.portaspeech.fvae import SyntaFVAE, FVAE
from utils.nn.seq_utils import group_hidden_by_segs
from modules.fastspeech.tts_modules import SyntaDurationPredictor


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        :param x: [B, T]
        :return: [B, T, H]
        """
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, :, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
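
# Example for reference: with x = torch.arange(T)[None, :].float() (shape [1, T]),
# SinusoidalPosEmb(H)(x) returns a [1, T, H] tensor whose first H // 2 channels are
# sin(x * w_k) and whose last H // 2 channels are cos(x * w_k), with geometrically
# spaced frequencies w_k = 10000 ** (-k / (H // 2 - 1)), i.e. a standard sinusoidal
# positional encoding with the sin channels followed by the cos channels.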


class SyntaSpeech(FastSpeech):
    def __init__(self, ph_dict_size, word_dict_size, out_dims=None):
        super().__init__(ph_dict_size, out_dims)

        # Speaker embedding, only built for multi-speaker configurations.
        if hparams['num_spk'] > 1:
            self.spk_embed_proj = Embedding(hparams['num_spk'], self.hidden_size)

        # Optional word-level text encoder on top of the phoneme encoder.
        if hparams['use_word_encoder']:
            self.word_encoder = RelTransformerEncoder(
                word_dict_size, self.hidden_size, self.hidden_size, self.hidden_size, 2,
                hparams['word_enc_layers'], hparams['enc_ffn_kernel_size'])

        if hparams['dur_level'] == 'word':
            # Word-level duration modeling: phoneme states are grouped into words,
            # re-encoded, and aligned back to mel frames by a single-head attention.
            if hparams['word_encoder_type'] == 'rel_fft':
                self.ph2word_encoder = RelTransformerEncoder(
                    0, self.hidden_size, self.hidden_size, self.hidden_size, 2,
                    hparams['word_enc_layers'], hparams['enc_ffn_kernel_size'])
            if hparams['word_encoder_type'] == 'fft':
                self.ph2word_encoder = FFTBlocks(
                    self.hidden_size, hparams['word_enc_layers'], 1, num_heads=hparams['num_heads'])
            self.sin_pos = SinusoidalPosEmb(self.hidden_size)
            self.enc_pos_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.dec_query_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.dec_res_proj = nn.Linear(2 * self.hidden_size, self.hidden_size)
            self.attn = MultiheadAttention(self.hidden_size, 1, encoder_decoder_attention=True, bias=False)
            self.attn.enable_torch_version = False
            if hparams['text_encoder_postnet']:
                self.text_encoder_postnet = ConvBlocks(
                    self.hidden_size, self.hidden_size, [1] * 3, 5, layers_in_block=2)
        else:
            self.sin_pos = SinusoidalPosEmb(self.hidden_size)

        # Syntax-aware duration predictor; it is fed ph2word and the syntactic graph in forward_dur.
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.dur_predictor = SyntaDurationPredictor(
            self.hidden_size,
            n_chans=predictor_hidden,
            n_layers=hparams['dur_predictor_layers'],
            dropout_rate=hparams['predictor_dropout'],
            kernel_size=hparams['dur_predictor_kernel'])

        if hparams['use_fvae']:
            # Variational decoder: drop the deterministic FastSpeech decoder and mel projection.
            del self.decoder
            del self.mel_out
            if hparams.get("use_gae_in_prior", True):
                # Graph-augmented FVAE variant (enabled by `use_gae_in_prior`).
                self.fvae = SyntaFVAE(
                    c_in_out=self.out_dims,
                    hidden_size=hparams['fvae_enc_dec_hidden'], c_latent=hparams['latent_size'],
                    kernel_size=hparams['fvae_kernel_size'],
                    enc_n_layers=hparams['fvae_enc_n_layers'],
                    dec_n_layers=hparams['fvae_dec_n_layers'],
                    c_cond=self.hidden_size,
                    use_prior_flow=hparams['use_prior_flow'],
                    flow_hidden=hparams['prior_flow_hidden'],
                    flow_kernel_size=hparams['prior_flow_kernel_size'],
                    flow_n_steps=hparams['prior_flow_n_blocks'],
                    strides=[hparams['fvae_strides']],
                    encoder_type=hparams['fvae_encoder_type'],
                    decoder_type=hparams['fvae_decoder_type'],
                )
            else:
                self.fvae = FVAE(
                    c_in_out=self.out_dims,
                    hidden_size=hparams['fvae_enc_dec_hidden'], c_latent=hparams['latent_size'],
                    kernel_size=hparams['fvae_kernel_size'],
                    enc_n_layers=hparams['fvae_enc_n_layers'],
                    dec_n_layers=hparams['fvae_dec_n_layers'],
                    c_cond=self.hidden_size,
                    use_prior_flow=hparams['use_prior_flow'],
                    flow_hidden=hparams['prior_flow_hidden'],
                    flow_kernel_size=hparams['prior_flow_kernel_size'],
                    flow_n_steps=hparams['prior_flow_n_blocks'],
                    strides=[hparams['fvae_strides']],
                    encoder_type=hparams['fvae_encoder_type'],
                    decoder_type=hparams['fvae_decoder_type'],
                )
        else:
            self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
            self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, 0)
        if hparams['add_word_pos']:
            self.word_pos_proj = Linear(self.hidden_size, self.hidden_size)

    def build_embedding(self, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
        return emb

    def forward(self, txt_tokens, word_tokens, ph2word, word_len, mel2word=None, mel2ph=None,
                spk_embed=None, spk_id=None, pitch=None, infer=False, tgt_mels=None,
                global_step=None, graph_lst=None, etypes_lst=None, *args, **kwargs):
        # Speaker conditioning: an externally extracted embedding, a learned
        # speaker-id embedding, or nothing for single-speaker models.
        if hparams['use_spk_embed']:
            spk_embed = spk_embed
        elif hparams['use_spk_id']:
            spk_embed = self.spk_embed_proj(spk_id)[:, None, :]
        else:
            spk_embed = 0

        ret = {}
        style_embed = self.forward_style_embed(spk_embed, spk_id)
        x, tgt_nonpadding = self.run_text_encoder(
            txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph, style_embed, ret,
            graph_lst=graph_lst, etypes_lst=etypes_lst, **kwargs)
        x = x + style_embed
        x = x * tgt_nonpadding
        ret['nonpadding'] = tgt_nonpadding
        if hparams['use_pitch_embed']:
            x = x + self.pitch_embed(pitch)
        ret['decoder_inp'] = x
        if infer and (mel2ph is None or mel2word is None):
            # At inference time, fall back to the alignment predicted by the duration module.
            mel2word = ret['mel2word']
        ret['mel_out_fvae'] = ret['mel_out'] = self.run_decoder(
            x, tgt_nonpadding, ret, infer, tgt_mels, global_step,
            mel2word=mel2word, ph2word=ph2word, graph_lst=graph_lst, etypes_lst=etypes_lst)
        return ret
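
    # For reference, after forward() the `ret` dict holds (depending on hparams):
    # 'mel2word' and 'dur' from the duration module, 'attn' from the word-to-frame
    # attention, 'decoder_inp', 'nonpadding', 'mel_out' / 'mel_out_fvae', and, when
    # the FVAE decoder is trained, 'kl', 'z_p', 'm_q', 'logs_q' and 'pre_mel_out'.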

    def run_text_encoder(self, txt_tokens, word_tokens, ph2word, word_len, mel2word, mel2ph,
                         style_embed, ret, graph_lst, etypes_lst, **kwargs):
        word2word = torch.arange(word_len)[None, :].to(ph2word.device) + 1  # [1, T_word], word indices start at 1
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        use_bert = hparams.get("use_bert") is True
        if use_bert:
            ph_encoder_out = self.encoder(txt_tokens, bert_feats=kwargs['bert_feats'], ph2word=ph2word,
                                          graph_lst=graph_lst, etypes_lst=etypes_lst,
                                          cl_feats=kwargs['cl_feats'], ret=ret) * src_nonpadding + style_embed
        else:
            ph_encoder_out = self.encoder(txt_tokens) * src_nonpadding + style_embed
        if hparams['use_word_encoder']:
            # Add word-level text features to every phoneme of the corresponding word.
            word_encoder_out = self.word_encoder(word_tokens) + style_embed
            ph_encoder_out = ph_encoder_out + expand_states(word_encoder_out, ph2word)

        dur_input = ph_encoder_out * src_nonpadding
        if hparams['dur_level'] == 'word':
            # Word-level alignment: group phoneme states into words, re-encode them,
            # predict word durations, then attend from mel frames back to phonemes.
            word_encoder_out = 0
            h_ph_gb_word = group_hidden_by_segs(ph_encoder_out, ph2word, word_len)[0]
            word_encoder_out = word_encoder_out + self.ph2word_encoder(h_ph_gb_word)
            if hparams['use_word_encoder']:
                word_encoder_out = word_encoder_out + self.word_encoder(word_tokens)
            mel2word = self.forward_dur(dur_input, mel2word, ret, ph2word=ph2word, word_len=word_len,
                                        graph_lst=graph_lst, etypes_lst=etypes_lst)
            mel2word = clip_mel2token_to_multiple(mel2word, hparams['frames_multiple'])
            ret['mel2word'] = mel2word
            tgt_nonpadding = (mel2word > 0).float()[:, :, None]
            enc_pos = self.get_pos_embed(word2word, ph2word)   # intra-word positions of phonemes
            dec_pos = self.get_pos_embed(word2word, mel2word)  # intra-word positions of mel frames
            dec_word_mask = build_word_mask(mel2word, ph2word)
            x, weight = self.attention(ph_encoder_out, enc_pos, word_encoder_out, dec_pos, mel2word, dec_word_mask)
            if hparams['add_word_pos']:
                x = x + self.word_pos_proj(dec_pos)
            ret['attn'] = weight
        else:
            # Phoneme-level alignment: classic length regulation with phoneme durations.
            mel2ph = self.forward_dur(dur_input, mel2ph, ret)
            mel2ph = clip_mel2token_to_multiple(mel2ph, hparams['frames_multiple'])
            mel2word = mel2ph_to_mel2word(mel2ph, ph2word)
            x = expand_states(ph_encoder_out, mel2ph)
            if hparams['add_word_pos']:
                dec_pos = self.get_pos_embed(word2word, mel2word)
                x = x + self.word_pos_proj(dec_pos)
            tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        if hparams['use_word_encoder']:
            x = x + expand_states(word_encoder_out, mel2word)
        return x, tgt_nonpadding

    def attention(self, ph_encoder_out, enc_pos, word_encoder_out, dec_pos, mel2word, dec_word_mask):
        # Keys/values: phoneme encodings concatenated with their intra-word positional embeddings.
        ph_kv = self.enc_pos_proj(torch.cat([ph_encoder_out, enc_pos], -1))
        # Queries/residual: word encodings expanded to frame level, concatenated with frame positions.
        word_enc_out_expanded = expand_states(word_encoder_out, mel2word)
        word_enc_out_expanded = torch.cat([word_enc_out_expanded, dec_pos], -1)
        if hparams['text_encoder_postnet']:
            word_enc_out_expanded = self.dec_res_proj(word_enc_out_expanded)
            word_enc_out_expanded = self.text_encoder_postnet(word_enc_out_expanded)
            dec_q = x_res = word_enc_out_expanded
        else:
            dec_q = self.dec_query_proj(word_enc_out_expanded)
            x_res = self.dec_res_proj(word_enc_out_expanded)
        ph_kv, dec_q = ph_kv.transpose(0, 1), dec_q.transpose(0, 1)
        x, (weight, _) = self.attn(dec_q, ph_kv, ph_kv, attn_mask=(1 - dec_word_mask) * -1e9)
        x = x.transpose(0, 1)
        x = x + x_res
        return x, weight
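
    # Notes on attention() above: after the transposes, dec_q is [T_mel, B, H] and
    # ph_kv is [T_ph, B, H] (the MultiheadAttention here follows the [T, B, C]
    # convention). Assuming build_word_mask marks frame/phoneme pairs that share a
    # word index, (1 - dec_word_mask) * -1e9 pushes cross-word logits toward -inf,
    # so each mel frame only attends to the phonemes of its own word.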

    def run_decoder(self, x, tgt_nonpadding, ret, infer, tgt_mels=None, global_step=0,
                    mel2word=None, ph2word=None, graph_lst=None, etypes_lst=None):
        if not hparams['use_fvae']:
            x = self.decoder(x)
            x = self.mel_out(x)
            ret['kl'] = 0
            return x * tgt_nonpadding
        else:
            # The FVAE works on [B, C, T] tensors.
            x = x.transpose(1, 2)  # [B, H, T_mel]
            tgt_nonpadding_BHT = tgt_nonpadding.transpose(1, 2)  # [B, 1, T_mel]
            if infer:
                # Inference: sample the latent from the prior, conditioned on the text states.
                z = self.fvae(cond=x, infer=True, mel2word=mel2word, ph2word=ph2word,
                              graph_lst=graph_lst, etypes_lst=etypes_lst)
            else:
                # Training: encode the target mels into the posterior and collect the KL term.
                tgt_mels = tgt_mels.transpose(1, 2)  # [B, C_mel, T_mel]
                z, ret['kl'], ret['z_p'], ret['m_q'], ret['logs_q'] = self.fvae(
                    tgt_mels, tgt_nonpadding_BHT, cond=x, mel2word=mel2word, ph2word=ph2word,
                    graph_lst=graph_lst, etypes_lst=etypes_lst)
                if global_step < hparams['posterior_start_steps']:
                    # Warm-up: ignore the posterior sample and decode from random noise.
                    z = torch.randn_like(z)
            x_recon = self.fvae.decoder(z, nonpadding=tgt_nonpadding_BHT, cond=x).transpose(1, 2)
            ret['pre_mel_out'] = x_recon
            return x_recon

    def forward_dur(self, dur_input, mel2word, ret, **kwargs):
        """
        :param dur_input: [B, T_txt, H]
        :param mel2word: [B, T_mel] or None (predicted by the length regulator when None)
        :param ret: dict collecting intermediate outputs; 'dur' is stored here
        :return: mel2word [B, T_mel]
        """
        word_len = kwargs['word_len']
        ph2word = kwargs['ph2word']
        graph_lst = kwargs['graph_lst']
        etypes_lst = kwargs['etypes_lst']
        src_padding = dur_input.data.abs().sum(-1) == 0
        # Scale the gradient flowing from the duration predictor back into the encoder.
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        dur = self.dur_predictor(dur_input, src_padding, ph2word, graph_lst, etypes_lst)

        # Sum phoneme-level durations into word-level durations; word indices are
        # 1-based, so slot 0 collects padding and is dropped afterwards.
        B, T_ph = ph2word.shape
        dur = torch.zeros([B, word_len.max() + 1]).to(ph2word.device).scatter_add(1, ph2word, dur)
        dur = dur[:, 1:]
        ret['dur'] = dur
        if mel2word is None:
            mel2word = self.length_regulator(dur).detach()
        return mel2word
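
    # Worked example for the scatter_add in forward_dur() above: with
    # ph2word = [[1, 1, 2, 2, 2]] and phoneme durations dur = [[3., 2., 4., 1., 5.]],
    # the scattered tensor is [[0., 5., 10.]]; slot 0 is the padding bucket and is
    # dropped by dur[:, 1:], leaving the word-level durations [[5., 10.]].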

    def get_pos_embed(self, word2word, x2word):
        x_pos = build_word_mask(word2word, x2word).float()
        x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
        x_pos = self.sin_pos(x_pos.float())
        return x_pos
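
    # What get_pos_embed() above computes, assuming build_word_mask(a, b) is nonzero
    # where a and b carry the same word index: for every position in x2word, the
    # normalized cumulative sum reduces to that position's relative location inside
    # its word, a value in (0, 1] (e.g. three frames of one word get 1/3, 2/3, 1),
    # which is then mapped through the sinusoidal embedding.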

    def store_inverse_all(self):
        def remove_weight_norm(m):
            try:
                if hasattr(m, 'store_inverse'):
                    m.store_inverse()  # let invertible layers cache their inverse for inference
                nn.utils.remove_weight_norm(m)
            except ValueError:  # raised when the module has no weight norm to remove
                return

        self.apply(remove_weight_norm)
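

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the training pipeline): SyntaSpeech
    # itself needs populated hparams, so only the self-contained SinusoidalPosEmb
    # is exercised here.
    pos_emb = SinusoidalPosEmb(dim=8)
    positions = torch.arange(6).float()[None, :].repeat(2, 1)  # [B=2, T=6]
    out = pos_emb(positions)
    assert out.shape == (2, 6, 8)
    print('SinusoidalPosEmb output shape:', out.shape)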