# StableTTS_en-demo / models/text_encoder.py
# Uploaded by KdaiP ("Upload 238 files"), commit d358e26.
import torch
import torch.nn as nn
from models.dit import DiTConVBlock
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        length: tensor of valid lengths, one per sequence (used as shape [b]).
        max_length: width of the mask. Defaults to the largest entry of
            ``length``, or 0 when ``length`` is empty.

    Returns:
        Bool tensor of shape [b, max_length] whose entry [i, j] is True
        iff ``j < length[i]``.
    """
    if max_length is None:
        # ``Tensor.max()`` raises on an empty tensor, so an empty batch
        # gets a zero-width mask instead of crashing.
        max_length = length.max() if length.numel() > 0 else 0
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    # [1, max_length] vs [b, 1] broadcasts to [b, max_length].
    return positions.unsqueeze(0) < length.unsqueeze(1)
# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
    """Embed token ids and encode them with a stack of ``DiTConVBlock`` layers.

    Pipeline: embedding scaled by sqrt(hidden_channels), then ``n_layers``
    ``DiTConVBlock`` layers (each given the conditioning ``c`` and the padding
    mask), then a 1x1 ``Conv1d`` projection to ``out_channels``.

    Args:
        n_vocab: size of the token embedding table.
        out_channels: channel count of the projected output ``mu_x``.
        hidden_channels: embedding / encoder width.
        filter_channels, n_heads, kernel_size, p_dropout, gin_channels:
            forwarded unchanged to every ``DiTConVBlock`` (their exact meaning
            is defined in ``models.dit``; ``gin_channels`` is presumably the
            width of ``c`` — confirm there).
        n_layers: number of ``DiTConVBlock`` layers.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels,
    ):
        super().__init__()
        # Hyper-parameters recorded on the instance (gin_channels is not kept).
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.scale = self.hidden_channels ** 0.5

        # Sub-modules are created in a fixed order (embedding, blocks,
        # projection); this order determines RNG consumption at init and
        # parameter ordering, so keep it.
        embedding = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(embedding.weight, 0.0, hidden_channels ** -0.5)
        self.emb = embedding

        blocks = (
            DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels)
            for _ in range(n_layers)
        )
        self.encoder = nn.ModuleList(blocks)

        self.proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1)
        self.initialize_weights()

    def initialize_weights(self):
        """Zero the weight and bias of the last ``adaLN_modulation`` layer in each block.

        NOTE(review): this matches the adaLN-Zero scheme from DiT, presumably
        so each block starts out close to an identity map — confirm in models.dit.
        """
        for block in self.encoder:
            final_layer = block.adaLN_modulation[-1]
            nn.init.zeros_(final_layer.weight)
            nn.init.zeros_(final_layer.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
        """Encode a batch of token-id sequences.

        Args:
            x: token ids, assumed [b, t].
            c: conditioning tensor handed to every encoder block.
            x_lengths: number of valid tokens per sequence, [b].

        Returns:
            ``(hidden, mu_x, mask)``: encoder output [b, h, t] (not masked
            here), projection [b, out_channels, t] multiplied by the mask,
            and the padding mask [b, 1, t] cast to the hidden dtype.
        """
        hidden = self.emb(x) * self.scale   # [b, t, h]
        hidden = hidden.transpose(1, -1)    # [b, h, t]
        n_frames = hidden.size(2)
        mask = sequence_mask(x_lengths, n_frames).unsqueeze(1).to(hidden.dtype)  # [b, 1, t]

        for block in self.encoder:
            hidden = block(hidden, c, mask)

        mu_x = self.proj(hidden) * mask
        return hidden, mu_x, mask