# Source snapshot: commit 22871e7 ("init") by yerfor, 6.21 kB.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.commons.layers import LayerNorm, Embedding
class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Useful for dropping a parameter-free function (e.g. a scaling step) into an
    ``nn.Sequential``. ``forward`` just applies the stored callable to its input.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
def init_weights_func(m):
    """Xavier-uniform initialise ``m.weight`` if ``m``'s class name contains "Conv1d".

    Meant to be passed to ``nn.Module.apply``. The match is on the class *name*
    (substring), not on ``isinstance``, so any module whose class name includes
    "Conv1d" is affected; every other module is left untouched.
    """
    if "Conv1d" not in m.__class__.__name__:
        return
    torch.nn.init.xavier_uniform_(m.weight)
class ResidualBlock(nn.Module):
    """Stack of ``n`` pre-norm residual convolution sub-blocks.

    Each sub-block computes
    ``norm -> Conv1d(kernel_size, dilation) -> * kernel_size ** -0.5 -> GELU -> Conv1d(1x1)``
    and adds the result back onto its input. Tensors are channel-first: the
    channel axis (dim 1) has size ``channels``.

    Args:
        channels: input/output channel count.
        kernel_size: kernel size of the first conv in each sub-block.
        dilation: dilation of the first conv (also passed to the 1x1 conv).
        n: number of sub-blocks.
        norm_type: 'bn', 'in', 'gn' or 'ln'; any other value means no
            normalisation (``nn.Identity``).
        dropout: dropout probability applied to each sub-block output while training.
        c_multiple: expansion factor of the hidden channel count.
        ln_eps: epsilon for the 'ln' normalisation.
    """

    def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
                 c_multiple=2, ln_eps=1e-12):
        super().__init__()

        def make_norm():
            if norm_type == 'bn':
                return nn.BatchNorm1d(channels)
            if norm_type == 'in':
                return nn.InstanceNorm1d(channels, affine=True)
            if norm_type == 'gn':
                return nn.GroupNorm(8, channels)
            if norm_type == 'ln':
                return LayerNorm(channels, dim=1, eps=ln_eps)
            return nn.Identity()

        hidden = c_multiple * channels
        # Symmetric padding keeps the length unchanged when dilation*(kernel_size-1) is even.
        padding = (dilation * (kernel_size - 1)) // 2
        scale = kernel_size ** -0.5

        def make_sub_block():
            return nn.Sequential(
                make_norm(),
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
                LambdaLayer(lambda t: t * scale),
                nn.GELU(),
                nn.Conv1d(hidden, channels, 1, dilation=dilation),
            )

        self.blocks = nn.ModuleList([make_sub_block() for _ in range(n)])
        self.dropout = dropout

    def forward(self, x):
        # Frames whose features are all zero are treated as padding and re-zeroed
        # after every residual update.
        mask = (x.abs().sum(1) > 0).float()[:, None, :]
        use_dropout = self.dropout > 0 and self.training
        for block in self.blocks:
            residual = block(x)
            if use_dropout:
                residual = F.dropout(residual, self.dropout, training=self.training)
            x = (x + residual) * mask
        return x
class ConvBlocks(nn.Module):
    """Stack of ResidualBlocks followed by a final norm and an output conv.

    Originally used to decode the expanded phoneme encoding into spectrograms,
    but it is a generic convolutional sequence model.

    Args:
        hidden_size: channel width of the residual stack.
        out_dims: channels produced by the output conv ``post_net1``.
        dilations: one ResidualBlock is built per entry, with that dilation.
        kernel_size: kernel size of the residual convs.
        norm_type: 'bn', 'in', 'gn' or 'ln'; any other value means no
            normalisation (``nn.Identity``), matching ResidualBlock.
        layers_in_block: sub-blocks per ResidualBlock.
        c_multiple: hidden-channel expansion inside each ResidualBlock.
        dropout: dropout inside the residual blocks.
        ln_eps: epsilon for 'ln' normalisation.
        init_weights: if True, Xavier-initialise every Conv1d via ``init_weights_func``.
        is_BTC: if True, ``forward`` takes/returns [B, T, C]; otherwise [B, C, T].
        num_layers: if given, overrides ``dilations`` with ``[1] * num_layers``.
        post_net_kernel: kernel size of ``post_net1``.
    """

    def __init__(self, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5,
                 init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3):
        super(ConvBlocks, self).__init__()
        self.is_BTC = is_BTC
        if num_layers is not None:
            dilations = [1] * num_layers
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_size, kernel_size, d,
                            n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
                            dropout=dropout, ln_eps=ln_eps)
              for d in dilations],
        )
        if norm_type == 'bn':
            norm = nn.BatchNorm1d(hidden_size)
        elif norm_type == 'in':
            norm = nn.InstanceNorm1d(hidden_size, affine=True)
        elif norm_type == 'gn':
            norm = nn.GroupNorm(8, hidden_size)
        elif norm_type == 'ln':
            norm = LayerNorm(hidden_size, dim=1, eps=ln_eps)
        else:
            # Previously `norm` was left unbound here (UnboundLocalError); fall
            # back to no normalisation exactly as ResidualBlock does.
            norm = nn.Identity()
        self.last_norm = norm
        self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
                                   padding=post_net_kernel // 2)
        if init_weights:
            self.apply(init_weights_func)

    def forward(self, x, nonpadding=None):
        """
        :param x: [B, T, H] if ``is_BTC`` else [B, H, T]
        :param nonpadding: optional frame mask. When ``is_BTC`` it is transposed
            with ``transpose(1, 2)``, so presumably [B, T, 1] (TODO confirm with
            callers); otherwise it is used as-is against [B, H, T] activations.
            When None, frames whose features are all zero are treated as padding.
        :return: [B, T, out_dims] if ``is_BTC`` else [B, out_dims, T]
        """
        if self.is_BTC:
            x = x.transpose(1, 2)
        if nonpadding is None:
            nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        elif self.is_BTC:
            nonpadding = nonpadding.transpose(1, 2)
        x = self.res_blocks(x) * nonpadding
        x = self.last_norm(x) * nonpadding
        x = self.post_net1(x) * nonpadding
        if self.is_BTC:
            x = x.transpose(1, 2)
        return x
class TextConvEncoder(ConvBlocks):
    """Token embedding (scaled by sqrt(hidden_size)) followed by ConvBlocks.

    All constructor arguments other than ``dict_size`` are forwarded to
    ConvBlocks; the module always runs in batch-first (``is_BTC=True``) mode.
    """

    def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
        super().__init__(
            hidden_size, out_dims, dilations, kernel_size,
            norm_type=norm_type,
            layers_in_block=layers_in_block,
            c_multiple=c_multiple,
            dropout=dropout,
            ln_eps=ln_eps,
            init_weights=init_weights,
            num_layers=num_layers,
            post_net_kernel=post_net_kernel,
        )
        # Third argument is presumably padding_idx=0 (TODO confirm against the
        # project's Embedding); the parent infers padding from all-zero frames,
        # so it relies on the padding token embedding to be zero.
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)

    def forward(self, txt_tokens):
        """
        :param txt_tokens: [B, T] token ids
        :return: encoder output tensor [B, T, out_dims]
        """
        embedded = self.embed_tokens(txt_tokens) * self.embed_scale
        return super().forward(embedded)
class ConditionalConvBlocks(ConvBlocks):
    """ConvBlocks whose input is additively conditioned on a second sequence.

    ``cond`` is projected to ``hidden_size`` channels by a kernel-3 conv
    (``g_prenet``) and added to ``x`` before the residual stack runs.

    Args:
        hidden_size: channel width of ``x`` and of the residual stack.
        c_cond: channel count of ``cond``.
        c_out: output channel count.
        is_BTC: if True, ``forward`` takes/returns [B, T, C]; otherwise [B, C, T].
        Remaining arguments are forwarded to ConvBlocks.
    """

    def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
        # The parent is always built channel-first (is_BTC=False); the layout
        # conversion for this class is done in forward() via self.is_BTC_.
        super().__init__(hidden_size, c_out, dilations, kernel_size,
                         norm_type, layers_in_block, c_multiple,
                         dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
        self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
        self.is_BTC_ = is_BTC
        if init_weights:
            self.g_prenet.apply(init_weights_func)

    def forward(self, x, cond, nonpadding=None):
        """
        :param x: [B, T, hidden_size] if ``is_BTC_`` else [B, hidden_size, T]
        :param cond: [B, T, c_cond] if ``is_BTC_`` else [B, c_cond, T]
        :param nonpadding: optional frame mask; transposed when ``is_BTC_``, so
            presumably [B, T, 1] (TODO confirm). When None, frames of ``x`` whose
            features are all zero are treated as padding.
        :return: [B, T, c_out] if ``is_BTC_`` else [B, c_out, T]
        """
        if self.is_BTC_:
            x = x.transpose(1, 2)
            cond = cond.transpose(1, 2)
            if nonpadding is not None:
                nonpadding = nonpadding.transpose(1, 2)
        if nonpadding is None:
            # Binary 0/1 mask, same as ConvBlocks/ResidualBlock. The previous
            # `x.abs().sum(1)[:, None]` was the per-frame L1 magnitude, so
            # `x * nonpadding` rescaled every frame instead of masking padding.
            nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        x = x + self.g_prenet(cond)
        x = x * nonpadding
        # Parent was built with is_BTC=False, so it expects [B, C, T]; pass the
        # mask explicitly so it is not re-inferred from the conditioned activations.
        x = super().forward(x, nonpadding)
        if self.is_BTC_:
            x = x.transpose(1, 2)
        return x