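# NOTE: the following header is residue from a repository web viewer, kept
# here as a comment so that the module parses as Python.
#   JustinLin610 | update | 10b0761 | raw | history blame | 26 kB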
from math import sqrt
import torch
import torch.distributions as distr
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from .layers import ConvNorm, LinearNorm, GlobalAvgPool
from .utils import to_gpu, get_mask_from_lengths
class LocationLayer(nn.Module):
    """Turn stacked attention-weight tracks into location features.

    The two input channels (see ``Decoder.decode``: previous and cumulative
    attention weights) are convolved along time and then projected, without
    bias, to ``attention_dim``.
    """

    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        # Keeps the time axis length unchanged for odd kernel sizes.
        same_padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2, attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=same_padding, bias=False, stride=1, dilation=1)
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim,
            bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        """Return location features for ``attention_weights_cat``.

        The convolution output is transposed on dims 1 and 2 so that the
        dense layer acts on the filter axis.
        """
        conv_features = self.location_conv(attention_weights_cat)
        return self.location_dense(conv_features.transpose(1, 2))
class Attention(nn.Module):
    """Additive attention with an extra location term.

    Energies are ``v(tanh(query_layer(q) + location_layer(w) + processed_memory))``
    and are normalised with a softmax over the memory time axis.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters,
            attention_location_kernel_size,
            attention_dim)
        # Masked positions get -inf so the softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Compute unnormalised attention energies.

        PARAMS
        ------
        query: attention query, one vector per batch item; it is unsqueezed
            on dim 1 so it broadcasts over the memory time axis
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        query_term = self.query_layer(query.unsqueeze(1))
        location_term = self.location_layer(attention_weights_cat)
        combined = query_term + location_term + processed_memory
        return self.v(torch.tanh(combined)).squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Attend over ``memory`` and return the context and the weights.

        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data (may be None)

        RETURNS
        -------
        attention_context: weighted sum of ``memory`` over its time axis
        attention_weights: softmax-normalised weights
        """
        energies = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            # In-place fill on .data, exactly as the energies are consumed
            # by the softmax below.
            energies.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(energies, dim=1)
        attention_context = torch.bmm(
            attention_weights.unsqueeze(1), memory).squeeze(1)
        return attention_context, attention_weights
class Prenet(nn.Module):
    """Stack of bias-free linear layers, each followed by ReLU and dropout.

    ``sizes`` lists the output width of every layer; the first layer reads
    ``in_dim`` features and each later layer reads the previous width.
    """

    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        input_widths = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(input_widths, sizes):
            self.layers.append(LinearNorm(fan_in, fan_out, bias=False))

    def forward(self, x):
        """Apply every layer in order and return the final activation."""
        hidden = x
        for layer in self.layers:
            hidden = F.relu(layer(hidden))
            # training=True is hard-coded: dropout (p=0.5) is applied
            # regardless of self.training, i.e. also in eval mode.
            hidden = F.dropout(hidden, p=0.5, training=True)
        return hidden
class Postnet(nn.Module):
    """Postnet: a stack of 1-d convolutions, each followed by batch norm.

    Layer count, channel width and kernel size come from
    ``hparams.postnet_n_convolutions``, ``hparams.postnet_embedding_dim`` and
    ``hparams.postnet_kernel_size``. The first layer maps mel channels to the
    postnet width and the last layer maps back to mel channels.
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        n_mel = hparams.n_mel_channels
        width = hparams.postnet_embedding_dim
        kernel = hparams.postnet_kernel_size

        # (in_channels, out_channels, init gain) for every layer, in order.
        plan = [(n_mel, width, 'tanh')]
        plan += [(width, width, 'tanh')
                 for _ in range(1, hparams.postnet_n_convolutions - 1)]
        plan.append((width, n_mel, 'linear'))

        self.convolutions = nn.ModuleList()
        for in_channels, out_channels, gain in plan:
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(in_channels, out_channels,
                             kernel_size=kernel, stride=1,
                             padding=int((kernel - 1) / 2),
                             dilation=1, w_init_gain=gain),
                    nn.BatchNorm1d(out_channels)))

    def forward(self, x):
        """Run ``x`` through all layers.

        Every layer except the last is followed by tanh; dropout (p=0.5,
        active only in training mode) follows every layer including the last.
        """
        blocks = list(self.convolutions)
        for block in blocks[:-1]:
            x = F.dropout(torch.tanh(block(x)), 0.5, self.training)
        return F.dropout(blocks[-1](x), 0.5, self.training)
class Encoder(nn.Module):
    """Encoder module:
        - ``hparams.encoder_n_convolutions`` 1-d convolution + batch-norm layers
        - one bidirectional LSTM layer
    """

    def __init__(self, hparams):
        super(Encoder, self).__init__()
        dim = hparams.encoder_embedding_dim
        kernel = hparams.encoder_kernel_size

        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(dim, dim,
                         kernel_size=kernel, stride=1,
                         padding=int((kernel - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(dim))
            for _ in range(hparams.encoder_n_convolutions)])

        # Each direction gets half the width so the concatenated output
        # matches the embedding width (for even widths).
        self.lstm = nn.LSTM(dim, int(dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def _apply_convolutions(self, x):
        """Conv stack with ReLU and dropout, then swap dims 1 and 2."""
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
        return x.transpose(1, 2)

    def forward(self, x, input_lengths):
        """Encode a padded batch.

        ``input_lengths`` is handed to ``pack_padded_sequence`` with its
        default settings, so the batch must be sorted by decreasing length.
        """
        x = self._apply_convolutions(x)

        # The lengths are handed to the packing utility as a NumPy array.
        lengths = input_lengths.cpu().numpy()
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True)

        self.lstm.flatten_parameters()
        packed_outputs, _ = self.lstm(packed)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True)
        return outputs

    def inference(self, x):
        """Encode without packing (no length information is used)."""
        x = self._apply_convolutions(x)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
class AudioEncoder(nn.Module):
    """Encode a mel-spectrogram into a Gaussian latent vector.

    Pipeline: 1-d conv stack (tanh + dropout) -> bidirectional LSTM ->
    length-aware pooling over time -> two linear heads producing the mean
    and the log-variance of a diagonal Gaussian, from which ``z`` is drawn
    with the reparameterisation trick.
    """

    def __init__(self, hparams):
        super(AudioEncoder, self).__init__()
        assert hparams.lat_dim > 0

        convolutions = []
        inp_dim = hparams.n_mel_channels
        for _ in range(hparams.lat_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(inp_dim, hparams.lat_n_filters,
                         kernel_size=hparams.lat_kernel_size, stride=1,
                         padding=int((hparams.lat_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.lat_n_filters))
            # Only the first layer reads mel channels.
            inp_dim = hparams.lat_n_filters
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.lat_n_filters,
                            int(hparams.lat_n_filters / 2),
                            hparams.lat_n_blstms, batch_first=True,
                            bidirectional=True)
        self.pool = GlobalAvgPool()

        self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim)
        self.lat_dim = hparams.lat_dim

    def forward(self, x, lengths):
        """Sample a latent for every item of the batch.

        Args:
            x (torch.Tensor): (B, F, T)
            lengths (torch.Tensor): (B,) valid frame counts; the longest one
                must equal T. The batch does not need to be sorted.

        Returns:
            (z, mu, logvar): the sampled latent, the mean and the
            log-variance, each of shape (B, lat_dim).
        """
        for conv in self.convolutions:
            # torch.tanh replaces the deprecated F.tanh.
            x = F.dropout(torch.tanh(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)  # (B, T, D)

        # x may not be sorted by length. Sort->process->unsort
        max_len = x.size(1)
        assert max_len == torch.max(lengths).item()

        sorted_lengths, perm_idx = lengths.sort(0, descending=True)
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            x[perm_idx], sorted_lengths.cpu(), batch_first=True)

        self.lstm.flatten_parameters()
        packed_outputs, _ = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True)

        # Undo the sort so that row i again belongs to input item i.
        _, unperm_idx = perm_idx.sort(0)
        outputs = outputs[unperm_idx]  # (B, T, D)

        # ``lengths`` (B,) is still in the caller's order.
        outputs = self.pool(outputs, lengths)  # (B, D)

        mu = self.mu_proj(outputs)
        logvar = self.logvar_proj(outputs)
        # Normal expects a standard deviation, not a log-variance:
        # std = sqrt(exp(logvar)) = exp(0.5 * logvar), which is always > 0.
        std = torch.exp(0.5 * logvar)
        z = distr.Normal(mu, std).rsample()
        return z, mu, logvar
class Decoder(nn.Module):
    """Autoregressive attention decoder.

    Each step runs: prenet output + previous context -> attention LSTM cell
    -> attention over the encoder memory -> decoder LSTM cell -> linear
    projections to a mel frame group and a stop-gate energy. An optional
    conditioning vector (``obs_and_lat``) is concatenated to the decoder
    LSTM input and to the projection input.

    The per-utterance state (hidden/cell states, attention weights, context,
    memory, mask) is stored on ``self`` by ``initialize_decoder_states`` and
    updated by every ``decode`` call.
    """

    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.obs_dim = hparams.obs_dim
        self.lat_dim = hparams.lat_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        # Width of the attention context plus the conditioning vector.
        encoder_tot_dim = (hparams.encoder_embedding_dim +
                           hparams.lat_dim + hparams.obs_dim)

        # The third positional argument of LSTMCell is ``bias``; the former
        # positional ``1`` is spelled out as the boolean it stood for.
        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + encoder_tot_dim,
            hparams.decoder_rnn_dim, bias=True)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + encoder_tot_dim, 1,
            bias=True, w_init_gain='sigmoid')

    @staticmethod
    def _zeros(memory, *size):
        """Return a zero tensor of ``size`` with memory's dtype and device."""
        return memory.data.new(*size).zero_()

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs (only batch size, dtype and device are used)

        RETURNS
        -------
        decoder_input: all zeros frames,
            (B, n_mel_channels * n_frames_per_step)
        """
        return self._zeros(
            memory, memory.size(0),
            self.n_mel_channels * self.n_frames_per_step)

    def initialize_decoder_states(self, memory, obs_and_lat, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = self._zeros(memory, B, self.attention_rnn_dim)
        self.attention_cell = self._zeros(memory, B, self.attention_rnn_dim)

        self.decoder_hidden = self._zeros(memory, B, self.decoder_rnn_dim)
        self.decoder_cell = self._zeros(memory, B, self.decoder_rnn_dim)

        self.attention_weights = self._zeros(memory, B, MAX_TIME)
        self.attention_weights_cum = self._zeros(memory, B, MAX_TIME)
        self.attention_context = self._zeros(
            memory, B, self.encoder_embedding_dim)

        self.memory = memory
        # Computed once per utterance and reused by every decode step.
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.obs_and_lat = obs_and_lat
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs,
            (T_out / n_frames_per_step, B, n_mel_channels * n_frames_per_step)
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # The transpose is not contiguous; grouping several frames per step
        # with view() needs contiguous memory.
        decoder_inputs = decoder_inputs.contiguous().view(
            decoder_inputs.size(0),
            decoder_inputs.size(1) // self.n_frames_per_step, -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list with one mel tensor per decoder step
        gate_outputs: list with the gate output energies of every step
        alignments: list with the attention weights of every step

        RETURNS
        -------
        mel_outputs: (B, n_mel_channels, T_out)
        gate_outputs: gate output energies, batch first
        alignments: attention weights, batch first
        """
        # (T_out, B, T_in) -> (B, T_out, T_in)
        alignments = torch.stack(alignments).transpose(0, 1)
        # steps first -> batch first
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output (after the prenet)

        RETURNS
        -------
        mel_output: projected mel frame group for this step
        gate_output: gate output energies
        attention_weights: attention weights of this step
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        # Channel 0: weights of the previous step, channel 1: running sum.
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights

        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        if self.obs_and_lat is not None:
            decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        if self.obs_and_lat is not None:
            decoder_hidden_attention_context = torch.cat(
                (decoder_hidden_attention_context, self.obs_and_lat), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((go_frame, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        # The go frame shifts the inputs by one step, so the last teacher
        # frame is never fed to the decoder.
        for step in range(decoder_inputs.size(0) - 1):
            mel_output, gate_output, attention_weights = self.decode(
                decoder_inputs[step])
            mel_outputs.append(mel_output.squeeze(1))
            # squeeze(1), not squeeze(): a bare squeeze also drops the batch
            # dimension when B == 1 and breaks parse_decoder_outputs.
            gate_outputs.append(gate_output.squeeze(1))
            alignments.append(attention_weights)

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory, obs_and_lat, ret_has_eos=False):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        obs_and_lat: Observed and latent attribute embeddings
        ret_has_eos: also return whether the stop gate fired

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        has_eos: only if ``ret_has_eos``; False when decoding was cut off
            at ``max_decoder_steps``

        NOTE: the stop test evaluates the truth value of a tensor holding one
        gate per batch item, so only a batch of one is supported.
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, obs_and_lat, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        has_eos = False
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs.append(mel_output.squeeze(1))
            gate_outputs.append(gate_output)
            alignments.append(alignment)

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                has_eos = True
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                # Reached max decoder steps without the gate firing.
                break

            # Feed the prediction back as the next input.
            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        if ret_has_eos:
            return mel_outputs, gate_outputs, alignments, has_eos
        else:
            return mel_outputs, gate_outputs, alignments
class Tacotron2(nn.Module):
    """Text-to-mel model: symbol embedding, encoder, decoder and postnet.

    Two optional conditioning sources are concatenated and handed to the
    decoder: an embedding of observed attribute labels (``obs_dim > 0``) and
    a latent vector from an ``AudioEncoder`` (``lat_dim > 0``).
    """

    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step

        # initialize text encoder embedding
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        bound = self._uniform_bound(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        self.embedding.weight.data.uniform_(-bound, bound)

        # initialize observed attribute embedding
        self.obs_embedding = None
        if hparams.obs_dim > 0:
            self.obs_embedding = nn.Embedding(
                hparams.obs_n_class, hparams.obs_dim)
            bound = self._uniform_bound(hparams.obs_n_class, hparams.obs_dim)
            self.obs_embedding.weight.data.uniform_(-bound, bound)

        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

        self.lat_encoder = None
        if hparams.lat_dim > 0:
            self.lat_encoder = AudioEncoder(hparams)

    @staticmethod
    def _uniform_bound(n_rows, n_cols):
        """Half-width of the uniform init range for an (n_rows, n_cols) table."""
        std = sqrt(2.0 / (n_rows + n_cols))
        return sqrt(3.0) * std  # uniform bounds for std

    @staticmethod
    def _join_conditioning(obs, lat):
        """Concatenate the non-None conditioning tensors on the last dim.

        Returns None when both are None.
        """
        present = [item for item in (obs, lat) if item is not None]
        if not present:
            return None
        return torch.cat(present, dim=-1)

    def parse_batch(self, batch):
        """Move a collated batch to the GPU and split it into (inputs, targets).

        ``batch`` is unpacked as (text_padded, input_lengths, obs_labels,
        mel_padded, gate_padded, output_lengths).
        """
        (text_padded, input_lengths, obs_labels,
         mel_padded, gate_padded, output_lengths) = batch

        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        obs_labels = to_gpu(obs_labels).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        model_inputs = (text_padded, input_lengths, obs_labels,
                        mel_padded, max_len, output_lengths)
        targets = (mel_padded, gate_padded)
        return (model_inputs, targets)

    def parse_output(self, outputs, output_lengths=None):
        """Overwrite padded frames of the outputs in place and return them.

        Nothing is changed unless ``self.mask_padding`` is set and
        ``output_lengths`` is given. Otherwise both mel outputs are zeroed
        and the gate energies are set to 1e3 wherever the mask is set.
        """
        if not self.mask_padding or output_lengths is None:
            return outputs

        pad_mask = ~get_mask_from_lengths(output_lengths)
        # Replicate over the mel channels: (B, T) -> (B, n_mel_channels, T)
        pad_mask = pad_mask.expand(
            self.n_mel_channels, pad_mask.size(0), pad_mask.size(1))
        pad_mask = pad_mask.permute(1, 0, 2)

        for mel in outputs[:2]:
            mel.data.masked_fill_(pad_mask, 0.0)
        outputs[2].data.masked_fill_(pad_mask[:, 0, :], 1e3)  # gate energies
        return outputs

    def forward(self, inputs):
        """Teacher-forced training pass.

        ``inputs`` is the first element returned by ``parse_batch``. Returns
        the list [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
        lat_mu, lat_logvar] after ``parse_output``; the last two entries are
        None when there is no latent encoder.
        """
        (text_inputs, text_lengths, obs_labels,
         mels, max_len, output_lengths) = inputs
        text_lengths = text_lengths.data
        output_lengths = output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        lat = lat_mu = lat_logvar = None
        if self.lat_encoder is not None:
            lat, lat_mu, lat_logvar = self.lat_encoder(mels, output_lengths)

        obs_and_lat = self._join_conditioning(obs, lat)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths)

        # The postnet predicts a residual that is added to the decoder mels.
        residual = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + residual

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments,
             lat_mu, lat_logvar],
            output_lengths)

    def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False):
        """Synthesise mels for ``inputs`` without teacher forcing.

        Missing ``obs_labels`` default to class 0 for every item. A missing
        ``lat`` defaults to an all-zero latent when a latent encoder exists.
        Returns [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
        with ``has_eos`` appended when ``ret_has_eos`` is set.
        """
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

        batch_size = len(inputs)
        if obs_labels is None:
            obs_labels = torch.LongTensor(batch_size)
            obs_labels = obs_labels.to(inputs.device).zero_()

        obs = None
        if self.obs_embedding is not None:
            obs = self.obs_embedding(obs_labels)

        if self.lat_encoder is not None and lat is None:
            lat = torch.FloatTensor(batch_size, self.lat_encoder.lat_dim)
            # Match the tensor type of the encoder outputs.
            lat = lat.to(inputs.device).zero_().type(encoder_outputs.type())

        obs_and_lat = self._join_conditioning(obs, lat)

        mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference(
            encoder_outputs, obs_and_lat, ret_has_eos=True)

        residual = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + residual

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        if ret_has_eos:
            return outputs + [has_eos]
        return outputs