# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0


import torch.nn as nn

from sparktts.modules.blocks.layers import (
    Snake1d,
    WNConv1d,
    ResidualUnit,
    WNConvTranspose1d,
    init_weights,
)


class DecoderBlock(nn.Module):
    """One decoder stage: activation, transposed conv, then three residual units.

    The stage is an ``nn.Sequential`` of ``Snake1d(input_dim)``, a
    ``WNConvTranspose1d`` mapping ``input_dim`` to ``output_dim`` channels,
    and three ``ResidualUnit(output_dim)`` modules with dilations 1, 3 and 9.

    Args:
        input_dim: Channel count fed to ``Snake1d`` and the transposed conv.
        output_dim: Channel count produced by the transposed conv and used by
            every residual unit.
        kernel_size: Kernel size passed to the transposed conv.
        stride: Stride passed to the transposed conv.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        kernel_size: int = 2,
        stride: int = 1,
    ):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                # NOTE(review): presumably WNConvTranspose1d wraps
                # nn.ConvTranspose1d, in which case this padding makes the
                # output length exactly `stride * input_length` whenever
                # (kernel_size - stride) is even -- confirm against the helper.
                padding=(kernel_size - stride) // 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        """Run ``x`` through the sequential stage and return the result."""
        return self.block(x)


class WaveGenerator(nn.Module):
    """Stack of ``DecoderBlock`` stages between a stem conv and a Tanh head.

    Layout of ``self.model``:

    1. ``WNConv1d(input_channel, channels, kernel_size=7, padding=3)``.
    2. One ``DecoderBlock`` per ``(kernel_size, stride)`` pair taken from
       ``zip(kernel_sizes, rates)``; stage ``i`` maps ``channels // 2**i``
       channels to ``channels // 2**(i + 1)``.
    3. ``Snake1d`` -> ``WNConv1d(..., d_out, kernel_size=7, padding=3)`` ->
       ``nn.Tanh``.

    After construction ``init_weights`` is applied to this module and all of
    its submodules via ``nn.Module.apply``.

    Args:
        input_channel: Channel count of the input to the stem conv.
        channels: Channel count produced by the stem conv; halved (integer
            division) at every decoder stage.
        rates: Per-stage strides for the decoder blocks.
        kernel_sizes: Per-stage kernel sizes for the decoder blocks. It is
            paired with ``rates`` through ``zip``, so surplus entries in the
            longer of the two sequences are ignored.
        d_out: Channel count of the final conv output.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        kernel_sizes,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # With no decoder stages the head consumes the stem's output directly.
        # Seeding `output_dim` here avoids an UnboundLocalError when `rates`
        # or `kernel_sizes` is empty.
        output_dim = channels

        # Add upsampling + MRF blocks
        for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.apply(init_weights)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.model(x)