|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import Conv1d |
|
from torch.nn.utils import remove_weight_norm, weight_norm |
|
|
|
from . import LRELU_SLOPE |
|
from tools.commons import get_padding, init_weights |
|
|
|
|
|
class ResBlock1(torch.nn.Module):
    """Residual block with three dilated convolution branches.

    Branch ``i`` computes ``leaky_relu -> convs1[i] -> leaky_relu -> convs2[i]``
    and adds the result onto the running input. ``convs1[i]`` uses dilation
    ``dilation[i]``; every ``convs2`` layer uses dilation 1. All six layers are
    weight-normalised ``Conv1d(channels, channels, kernel_size, stride=1)``
    with padding computed by ``get_padding(kernel_size, dilation)``.

    Args:
        channels: Number of input channels; the output has the same count.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilations for the three ``convs1`` layers. Exactly the first
            three entries are read; a sequence shorter than three raises
            ``IndexError``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        # Build and initialise convs1 before creating convs2: layer
        # construction draws from the global RNG, so this order fixes which
        # random values each layer receives under a fixed seed.
        self.convs1 = nn.ModuleList(
            [
                self._make_conv(channels, kernel_size, dilation[i])
                for i in range(3)
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [self._make_conv(channels, kernel_size, 1) for _ in range(3)]
        )
        self.convs2.apply(init_weights)
        # NOTE(review): init_weights runs *after* weight_norm. weight_norm
        # recomputes ``weight`` from ``weight_g``/``weight_v`` before every
        # forward, so if init_weights writes to ``m.weight`` directly that
        # write may not persist -- confirm against tools.commons.init_weights.

    @staticmethod
    def _make_conv(channels, kernel_size, dilation):
        """Return a weight-normalised stride-1 Conv1d with the given dilation."""
        return weight_norm(
            Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                padding=get_padding(kernel_size, dilation),
            )
        )

    def forward(self, x, x_mask=None):
        """Apply the three residual branches to ``x``.

        Args:
            x: Input tensor; presumably shaped (batch, channels, time) as
                expected by ``Conv1d`` -- TODO confirm with callers. The
                residual add requires each branch to preserve ``x``'s shape.
            x_mask: Optional mask multiplied into the activations before each
                convolution and into the final output. Assumed broadcastable
                against ``x`` (e.g. (batch, 1, time)) -- TODO confirm.

        Returns:
            ``x`` with all three branch outputs added, masked by ``x_mask``
            when one is given.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Remove the weight-norm reparameterisation from every convolution."""
        # The bare name below resolves to the imported
        # torch.nn.utils.remove_weight_norm, not to this method: a class body
        # is not an enclosing scope for the functions defined in it.
        for conv in (*self.convs1, *self.convs2):
            remove_weight_norm(conv)
|
|
|
|
|
class ResBlock2(torch.nn.Module):
    """Residual block with two single-convolution branches.

    For each layer ``conv`` in ``self.convs`` (in order) the block computes
    ``x = conv(leaky_relu(x)) + x``. When ``x_mask`` is given, it is applied
    to the activation before each convolution and to the final output.

    Args:
        channels: Number of input channels; the output keeps the same count.
        kernel_size: Kernel size of both convolutions.
        dilation: Dilations of the two convolutions; only ``dilation[0]`` and
            ``dilation[1]`` are read.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()

        def build(rate):
            # Weight-normalised Conv1d, stride 1, padding from get_padding.
            layer = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=rate,
                padding=get_padding(kernel_size, rate),
            )
            return weight_norm(layer)

        self.convs = nn.ModuleList([build(dilation[idx]) for idx in range(2)])
        self.convs.apply(init_weights)
        # NOTE(review): init_weights is applied after weight_norm, which
        # recomputes ``weight`` from ``weight_g``/``weight_v`` before every
        # forward -- confirm the initialisation actually takes effect.

    def forward(self, x, x_mask=None):
        """Run both residual branches over ``x``.

        Args:
            x: Input tensor; presumably (batch, channels, time) -- TODO
                confirm with callers.
            x_mask: Optional mask, assumed broadcastable against ``x`` --
                TODO confirm.

        Returns:
            The accumulated residual output, masked if ``x_mask`` is given.
        """
        use_mask = x_mask is not None
        for conv in self.convs:
            h = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                h = h * x_mask
            x = conv(h) + x
        return x * x_mask if use_mask else x

    def remove_weight_norm(self):
        """Remove the weight-norm reparameterisation from both convolutions."""
        # This call targets the imported torch.nn.utils.remove_weight_norm;
        # the method's own name is not visible inside its body.
        for layer in self.convs:
            remove_weight_norm(layer)
|
|