import math
from typing import Literal

import torch
from torch import nn

import comfy.ops

ops = comfy.ops.disable_weight_init


def vae_sample(mean, scale):
    # Reparameterization trick: sample latents as mean + stdev * eps.
    # softplus keeps stdev positive; the small epsilon avoids log(0) below.
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl
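
# For reference: the KL term computed above is the closed-form divergence
# KL(N(mean, var) || N(0, 1)) = 0.5 * sum(mean^2 + var - log(var) - 1),
# summed over channels and averaged over the batch; the code omits the 1/2
# factor, presumably absorbing it into the loss weight applied by the trainer.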


class VAEBottleneck(nn.Module):
    def __init__(self):
        super().__init__()
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        # The encoder emits 2 * latent_dim channels: split into mean and scale.
        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x
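
# Minimal usage sketch for the bottleneck (hypothetical shapes):
#
#   bottleneck = VAEBottleneck()
#   h = torch.randn(1, 128, 256)   # 2 * latent_dim channels from an encoder
#   z, info = bottleneck.encode(h, return_info=True)
#   # z: [1, 64, 256]; info["kl"] holds the scalar KL penalty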


def snake_beta(x, alpha, beta):
    # SnakeBeta activation: x + (1 / beta) * sin^2(alpha * x), with a small
    # epsilon guarding against division by zero.
    return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)


class SnakeBeta(nn.Module):
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features

        # Initialize the per-channel alpha and beta parameters. In log scale
        # they start at zero (so exp(alpha) == 1); in linear scale they start
        # at the given alpha value.
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.no_div_by_zero = 1e-9

    def forward(self, x):
        # Broadcast the per-channel parameters against x of shape [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device)
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
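
# Usage sketch: SnakeBeta acts elementwise on [B, C, T] tensors with one
# learned (alpha, beta) pair per channel:
#
#   act = SnakeBeta(in_features=16)
#   y = act(torch.randn(2, 16, 100))   # output shape matches the input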


def WNConv1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
    except AttributeError:
        # Fall back to the deprecated weight_norm API on older PyTorch releases.
        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
    except AttributeError:
        # Fall back to the deprecated weight_norm API on older PyTorch releases.
        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs))


def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        # Activation1d (an anti-aliased activation wrapper, e.g. from
        # alias_free_torch) is not defined in this file; antialias=True
        # will raise NameError unless it is provided elsewhere.
        act = Activation1d(act)  # noqa: F821

    return act


class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "Same" padding for a kernel of size 7 at the given dilation.
        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x
        x = self.layers(x)
        return x + res


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        # Three dilated residual units followed by a strided downsampling conv.
        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)
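
# Note: with kernel_size = 2 * stride and padding = ceil(stride / 2), the
# downsampling conv maps a length-L input to exactly L / stride frames for
# the even strides used here (when L is divisible by the stride).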


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        # Upsample either with nearest-neighbor interpolation plus a conv,
        # or with a strided transposed conv (the default).
        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        # Stack of downsampling blocks; each EncoderBlock shrinks the time
        # axis by its stride and widens the channel count.
        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
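
# Note: the encoder's overall downsampling factor is the product of the
# strides (2 * 4 * 8 * 8 = 512 with the defaults above), so one latent frame
# summarizes that many input samples.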


class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        # Mirror the encoder: walk the channel multipliers in reverse,
        # upsampling by the matching stride at each step.
        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            )]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class AudioOobleckVAE(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults=[1, 2, 4, 8, 16],
                 strides=[2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # The encoder emits 2 * latent_dim channels (mean and scale), which
        # the bottleneck samples down to latent_dim. The decoder's output
        # channel count is set to match the input audio's channel count.
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        return self.decoder(self.bottleneck.decode(x))
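

# Minimal end-to-end sketch (hypothetical shapes; with the defaults above the
# total stride is 2 * 4 * 4 * 8 * 8 = 2048 samples per latent frame):
#
#   vae = AudioOobleckVAE()
#   audio = torch.randn(1, 2, 2048 * 32)   # [batch, stereo, samples]
#   latents = vae.encode(audio)            # [1, 64, 32]
#   recon = vae.decode(latents)            # [1, 2, 2048 * 32]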