Spaces:
Runtime error
Runtime error
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
import torch | |
import torch.nn as nn | |
from torch.nn.utils import weight_norm | |
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` with weight normalization applied.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``.

    Returns:
        The constructed ``nn.Conv1d`` module, as returned by ``weight_norm``.
    """
    # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent PyTorch
    # in favour of torch.nn.utils.parametrizations.weight_norm. It is kept
    # here because switching changes the parameter names stored in the state
    # dict (weight_g / weight_v), which would affect checkpoint loading.
    return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` with weight normalization applied.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.ConvTranspose1d``.

    Returns:
        The constructed ``nn.ConvTranspose1d`` module, as returned by
        ``weight_norm``.
    """
    # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent PyTorch
    # in favour of torch.nn.utils.parametrizations.weight_norm. It is kept
    # here because switching changes the parameter names stored in the state
    # dict (weight_g / weight_v), which would affect checkpoint loading.
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Per the original note, scripting this function (e.g. with torch.jit.script)
# brings a model speed-up of about 1.4x. No such decorator is applied in this
# file, so the function runs eagerly here.
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Apply ``x + sin(alpha * x) ** 2 / (alpha + 1e-9)`` element-wise.

    Args:
        x: Input tensor with at least two dimensions. Every dimension after
            the second is flattened for the computation and restored
            afterwards.
        alpha: Tensor that must broadcast against ``x`` reshaped to
            ``(x.shape[0], x.shape[1], -1)``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    shape = x.shape
    # Collapse all trailing dimensions so alpha broadcasts against a 3-D view.
    x = x.reshape(shape[0], shape[1], -1)
    # The 1e-9 offset keeps the reciprocal finite when alpha is exactly zero.
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x
class Snake1d(nn.Module):
    """Module wrapper around ``snake`` with a learnable ``alpha`` per channel."""

    def __init__(self, channels):
        """Create the ``alpha`` parameter.

        Args:
            channels: Number of channels; ``alpha`` is initialised to ones
                with shape ``(1, channels, 1)``.
        """
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        """Return ``snake(x, self.alpha)``."""
        return snake(x, self.alpha)
class ResidualUnit(nn.Module):
    """Residual block: Snake1d -> dilated conv (k=7) -> Snake1d -> conv (k=1)."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        """Build the block.

        Args:
            dim: Number of input and output channels of both convolutions.
            dilation: Dilation of the kernel-size-7 convolution.
        """
        super().__init__()
        kernel_size = 7
        # (kernel_size - 1) * dilation is the span the dilated kernel removes
        # from the sequence; padding half of it on each side keeps the length.
        pad = ((kernel_size - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(
                dim, dim, kernel_size=kernel_size, dilation=dilation, padding=pad
            ),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        """Return ``x + block(x)``, centre-cropping ``x`` if the block output is shorter."""
        y = self.block(x)
        # With the padding chosen in __init__ the lengths already match, so
        # this crop only takes effect if the block shortens the sequence.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y
def init_weights(m):
    """Initialise an ``nn.Conv1d`` module in place; ignore any other module.

    The weight is filled from a truncated normal distribution with
    ``std=0.02`` and the bias, when the layer has one, is set to zero.

    Args:
        m: Any module. Only instances of ``nn.Conv1d`` are modified.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers built with bias=False have m.bias set to None, which
        # nn.init.constant_ cannot fill.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)