mrfakename's picture
Upload 43 files
d93aca0 verified
raw
history blame
2.5 kB
# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
import torch.nn as nn
from sparktts.modules.blocks.layers import (
Snake1d,
WNConv1d,
ResidualUnit,
WNConvTranspose1d,
init_weights,
)
class DecoderBlock(nn.Module):
    """One upsampling stage of the wave generator.

    The stage applies, in order:
      * a ``Snake1d`` activation over ``input_dim`` channels,
      * a ``WNConvTranspose1d`` mapping ``input_dim`` to ``output_dim`` channels
        with the given ``kernel_size`` and ``stride``,
      * three ``ResidualUnit`` layers on ``output_dim`` channels with
        dilations 1, 3 and 9.

    Args:
        input_dim: Number of channels of the incoming feature map.
        output_dim: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        kernel_size: int = 2,
        stride: int = 1,
    ):
        super().__init__()
        # Submodules are created in the same order they appear in the
        # Sequential, so parameter initialisation order is unchanged.
        activation = Snake1d(input_dim)
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - stride) // 2,
        )
        residual_stack = [ResidualUnit(output_dim, dilation=d) for d in (1, 3, 9)]
        # The position of each submodule inside the Sequential determines its
        # state_dict key (block.0, block.1, ...); keep this order fixed so
        # existing checkpoints still load.
        self.block = nn.Sequential(activation, upsample, *residual_stack)

    def forward(self, x):
        """Pass ``x`` through the activation, transposed conv and residual units."""
        out = self.block(x)
        return out
class WaveGenerator(nn.Module):
    """Waveform decoder: input conv, a stack of DecoderBlocks, output conv.

    The network is an ``nn.Sequential`` made of:
      * a ``WNConv1d`` (kernel 7, padding 3) from ``input_channel`` to
        ``channels`` channels,
      * one ``DecoderBlock`` per ``(kernel_size, stride)`` pair. Block ``i``
        maps ``channels // 2**i`` to ``channels // 2**(i + 1)`` channels, so
        each stage halves the width.
      * ``Snake1d``, then a ``WNConv1d`` (kernel 7, padding 3) down to ``d_out``
        channels, then ``nn.Tanh``, so outputs lie in (-1, 1).

    Args:
        input_channel: Number of channels of the input feature sequence.
        channels: Channel width after the first convolution.
        rates: Stride of each DecoderBlock's transposed convolution. These are
            presumably the per-stage upsampling factors; confirm against
            ``WNConvTranspose1d``.
        kernel_sizes: Kernel size of each DecoderBlock's transposed
            convolution. Must have the same length as ``rates``.
        d_out: Number of output channels.

    Raises:
        ValueError: If ``rates`` and ``kernel_sizes`` have different lengths.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        kernel_sizes,
        d_out: int = 1,
    ):
        super().__init__()
        # Materialise so len() works for any iterable (lists, tuples, config
        # sequences, generators).
        rates = list(rates)
        kernel_sizes = list(kernel_sizes)
        # zip() would silently drop the unmatched tail, producing a shallower
        # network than the configuration describes.
        if len(rates) != len(kernel_sizes):
            raise ValueError(
                "rates and kernel_sizes must have the same length, "
                f"got {len(rates)} and {len(kernel_sizes)}"
            )

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Channel width of the most recently added layer. Initialised to the
        # first conv's width so the final layers are still well-defined when
        # no upsampling stages are requested; otherwise this name would be
        # unbound after an empty loop.
        output_dim = channels

        # Add upsampling + MRF blocks
        for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)
        self.apply(init_weights)

    def forward(self, x):
        """Decode ``x`` through the full stack; the final Tanh bounds output to (-1, 1)."""
        return self.model(x)