|
import torch |
|
import torch.nn as nn |
|
|
|
from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder |
|
|
|
from .wavenet import WaveNet |
|
|
|
|
|
def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
    """Build ``DVAEEncoder`` keyword arguments from an existing decoder.

    The encoder's input width is taken from the decoder's output conv and its
    output width from the decoder's first input conv; the layer count,
    bottleneck width, hidden width and block kernel/dilation are copied from
    the decoder's layers. The returned keys match ``DVAEEncoder.__init__``'s
    parameter names, so ``DVAEEncoder(**get_encoder_config(decoder))`` works.

    Args:
        decoder: A ``DVAEDecoder``. ``conv_in`` must be indexable with convs at
            positions 0 and 2, and ``decoder_block`` must be non-empty with a
            ``dwconv`` (presumably ``nn.Conv1d``) on its first element.

    Returns:
        Mapping with keys ``idim``, ``odim``, ``n_layer``, ``bn_dim``,
        ``hidden``, ``kernel``, ``dilation`` and ``down``.
    """
    into_bottleneck = decoder.conv_in[0]
    into_hidden = decoder.conv_in[2]
    blocks = decoder.decoder_block
    depthwise = blocks[0].dwconv
    return dict(
        idim=decoder.conv_out.out_channels,
        odim=into_bottleneck.in_channels,
        n_layer=len(blocks),
        bn_dim=into_bottleneck.out_channels,
        hidden=into_hidden.out_channels,
        kernel=depthwise.kernel_size[0],
        dilation=depthwise.dilation[0],
        # The decoder's ``up`` flag is forwarded as the encoder's ``down``.
        down=decoder.up,
    )
|
|
|
|
|
class DVAEEncoder(nn.Module):
    """Encoder from mel spectrograms to per-step features (two frames per step).

    Pipeline (see ``forward``): ``WaveNet`` -> attention masking ->
    ``conv_in_transpose`` (1x1, ``idim`` -> ``hidden``) -> ``n_layer``
    ``ConvNeXtBlock`` layers -> ``conv_out_transpose`` (``hidden`` ->
    ``bn_dim`` -> ``odim``) -> fold of each adjacent frame pair into one step.

    Args:
        idim: Residual width of the WaveNet. ``conv_in_transpose`` consumes
            ``idim`` channels, so the WaveNet output is expected to be that
            wide (NOTE(review): depends on the WaveNet implementation).
        odim: Channels per frame produced by ``conv_out_transpose``; each
            output step carries ``2 * odim`` features.
        n_layer: Number of ``ConvNeXtBlock`` layers.
        bn_dim: Intermediate width inside ``conv_out_transpose``.
        hidden: Channel width of the ConvNeXt stack; ``hidden * 4`` is passed
            as each block's second argument (presumably its inner width).
        kernel: Kernel size passed to every ``ConvNeXtBlock``.
        dilation: Dilation passed to every ``ConvNeXtBlock``.
        down: Recorded as ``self.down``; not otherwise used by this module.
        mel_channels: ``input_channels`` of the WaveNet, i.e. presumably the
            last dimension of ``audio_mel_specs``. Defaults to 100.
        wavenet_layers: ``residual_layers`` of the WaveNet. Defaults to 20.
        wavenet_dilation_cycle: ``dilation_cycle`` of the WaveNet.
            Defaults to 4.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer: int = 12,
        bn_dim: int = 64,
        hidden: int = 256,
        kernel: int = 7,
        dilation: int = 2,
        down: bool = False,
        *,
        mel_channels: int = 100,
        wavenet_layers: int = 20,
        wavenet_dilation_cycle: int = 4,
    ) -> None:
        super().__init__()
        # Kept for parity with the decoder's ``up`` flag (read by
        # ``get_encoder_config``); the forward pass does not branch on it.
        self.down = down
        self.wavenet = WaveNet(
            input_channels=mel_channels,
            residual_channels=idim,
            residual_layers=wavenet_layers,
            dilation_cycle=wavenet_dilation_cycle,
        )
        self.conv_in_transpose = nn.ConvTranspose1d(
            idim, hidden, kernel_size=1, bias=False
        )
        self.encoder_block = nn.ModuleList(
            [
                ConvNeXtBlock(
                    hidden,
                    hidden * 4,
                    kernel,
                    dilation,
                )
                for _ in range(n_layer)
            ]
        )
        self.conv_out_transpose = nn.Sequential(
            nn.Conv1d(hidden, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, odim, 3, 1, 1),
        )

    def forward(
        self,
        audio_mel_specs: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        conditioning=None,
    ) -> torch.Tensor:
        """Encode a batch of mel spectrograms.

        Args:
            audio_mel_specs: Mel features; dims 1 and 2 are swapped before the
                WaveNet, so presumably ``(batch, 2 * T, mel_channels)`` --
                confirm against callers.
            audio_attention_mask: Per-step mask, presumably ``(batch, T)``;
                step ``t`` covers mel frames ``2t`` and ``2t + 1``.
            conditioning: Forwarded unchanged to every ``ConvNeXtBlock``.

        Returns:
            Tensor of shape ``(batch, frames // 2, 2 * odim)``. Step ``t``
            holds frames ``2t`` and ``2t + 1``; feature ``2 * c + k`` is
            channel ``c`` of frame ``2t + k``.
        """
        # Expand the step mask to mel-frame resolution: [m0, m0, m1, m1, ...].
        mel_attention_mask = (
            audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1)
        )
        x: torch.Tensor = self.wavenet(audio_mel_specs.transpose(1, 2))
        x = x * mel_attention_mask.unsqueeze(1)
        x = self.conv_in_transpose(x)
        for block in self.encoder_block:
            x = block(x, conditioning)
        x = self.conv_out_transpose(x)
        # Fold adjacent frame pairs (2t, 2t + 1) into step t so the output
        # lines up with the interleaved mask above. Splitting the time axis
        # into two halves instead would pair frame t with frame t + frames/2
        # and let padded steps pick up data from valid frames.
        # NOTE(review): checkpoints trained with the halves-split fold will
        # see a different output layout.
        batch, channels, frames = x.shape
        x = (
            x.reshape(batch, channels, frames // 2, 2)  # [b, c, t, k] = frame 2t+k
            .permute(0, 2, 1, 3)  # -> (batch, t, channel, k)
            .flatten(2)  # feature index 2 * c + k
        )
        return x
|
|