Spaces:
Running
on
Zero
Running
on
Zero
"""
Copied from https://github.com/KdaiP/StableTTS (author: https://github.com/KdaiP).
Source revision: https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py
The original code is released under the MIT License.
"""
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class FFN(nn.Module):
    """Feed-forward block built from two 1-D convolutions with a tanh-GELU in between.

    The mask is multiplied in before each convolution and once more on the output,
    so masked frames are zeroed at every stage (assumes ``x_mask`` holds 1 on valid
    frames and 0 on padding — TODO confirm against callers).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        filter_channels: hidden width between the two convolutions.
        kernel_size: kernel size of both convolutions; padding is ``kernel_size // 2``.
        p_dropout: dropout probability applied after the activation.
        gin_channels: stored on the module; not used inside this block.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # "same"-style padding for odd kernel sizes keeps the time length unchanged.
        pad = kernel_size // 2
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=pad)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=pad)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        """Return ``conv_2(drop(gelu(conv_1(x * mask))) * mask) * mask``."""
        hidden = self.drop(self.act1(self.conv_1(x * x_mask)))
        return self.conv_2(hidden * x_mask) * x_mask
class MultiHeadAttention(nn.Module):
    """Multi-head attention over channel-first sequences with rotary position encoding.

    Query/key/value/output projections are 1x1 convolutions, so tensors are laid out as
    ``[batch, channels, time]``. Rotary embeddings are applied to the first
    ``int(k_channels * 0.5)`` features of every query and key head.

    Args:
        channels: model width; must be divisible by ``n_heads``.
        out_channels: width produced by the output projection.
        n_heads: number of attention heads.
        p_dropout: attention-probability dropout, active only in training mode.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.k_channels = channels // n_heads

        # 1x1 convolution projections for query / key / value.
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        rope_dim = self.k_channels * 0.5
        self.query_rotary_pe = RotaryPositionalEmbeddings(rope_dim)
        self.key_rotary_pe = RotaryPositionalEmbeddings(rope_dim)

        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        # Only the q/k/v weights get Xavier init; conv_o keeps PyTorch's default.
        for proj in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, x, attn_mask=None):
        """Self-attention: project ``x`` to q/k/v, attend, then project the result."""
        q, k, v = (proj(x) for proj in (self.conv_q, self.conv_k, self.conv_v))
        return self.conv_o(self.attention(q, k, v, mask=attn_mask))

    def attention(self, query, key, value, mask=None):
        """Split into heads, rotate q/k, run scaled dot-product attention, merge heads.

        ``query`` is ``[b, d, t_t]``; ``key``/``value`` are ``[b, d, t_s]``.
        ``mask`` is passed straight to ``F.scaled_dot_product_attention`` as ``attn_mask``.
        Returns ``[b, d, t_t]``.
        """
        batch, channels, t_src = key.shape
        t_tgt = query.shape[2]

        def split_heads(tensor, length):
            # [b, d, t] -> [b, n_heads, t, d // n_heads]
            return tensor.view(batch, self.n_heads, self.k_channels, length).transpose(2, 3)

        q = split_heads(query, t_tgt)
        k = split_heads(key, t_src)
        v = split_heads(value, t_src)
        q = self.query_rotary_pe(q)
        k = self.key_rotary_pe(k)

        dropout_p = self.p_dropout if self.training else 0
        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
        # [b, n_heads, t_t, d_k] -> [b, d, t_t]
        return attended.transpose(2, 3).contiguous().view(batch, channels, t_tgt)
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    The block works on ``hidden_channels + out_channels`` channels. When a condition
    ``c`` is given, six per-channel vectors (shift/scale/gate for attention and for the
    MLP) are produced by ``adaLN_modulation`` and used to modulate the normalised
    inputs and gate the residual branches. When ``c`` is None, plain pre-norm residual
    attention + MLP is used.

    NOTE(review): this block does not zero-initialise ``adaLN_modulation`` itself;
    presumably the enclosing model does so — confirm.
    """

    def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        channels = hidden_channels + out_channels
        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(channels, channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(channels, channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            # Project the condition to the block width unless it already matches.
            nn.Linear(gin_channels, channels) if gin_channels != channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(channels, 6 * channels, bias=True),
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel] conditioning vector, or None for the unconditioned path
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        # [b, 1, t] -> [b, 1, 1, t] * [b, 1, t, 1] -> [b, 1, t, t]
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)
        # NOTE(review): if x_mask is floating point, scaled_dot_product_attention treats
        # this mask as an additive bias (adds 1 to valid pairs, 0 elsewhere) instead of
        # masking padded keys. Converting to bool would change outputs of pretrained
        # checkpoints and produce NaN rows for fully padded queries, so it is left as is.
        if c is not None:
            # [b, 6 * channel] -> six tensors of shape [b, channel, 1]
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)
            x = x + gate_msa * self.attn(self.modulate(self._channel_norm(self.norm1, x), shift_msa, scale_msa), attn_mask) * x_mask
            x = x + gate_mlp * self.mlp(self.modulate(self._channel_norm(self.norm2, x), shift_mlp, scale_mlp), x_mask) * x_mask
        else:
            # no condition version
            x = x + self.attn(self._channel_norm(self.norm1, x), attn_mask)
            # norm2 is the MLP-side norm (identical to norm1 here: both are non-affine LayerNorms).
            x = x + self.mlp(self._channel_norm(self.norm2, x), x_mask)
        return x

    @staticmethod
    def _channel_norm(norm, x):
        """Apply a LayerNorm over the channel axis of a ``[batch, channel, time]`` tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    @staticmethod
    def modulate(x, shift, scale):
        """adaLN modulation: ``x * (1 + scale) + shift``.

        A staticmethod so that ``self.modulate(x, shift, scale)`` binds no ``self``.
        """
        return x * (1 + scale) + shift
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module
    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.

    Only the first ``d`` features of the last axis are rotated; the remaining features
    are passed through unchanged.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ (cast with ``int``, so fractional values are truncated)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        self.base = base
        self.d = int(d)
        # Lazily built tables of shape [seq_len, 1, 1, d]. They are plain attributes
        # (not registered buffers), so `Module.to()` does not move them; `_build_cache`
        # rebuilds them when the input lives on a different device.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for positions ``0 .. x.shape[0] - 1``.

        The existing cache is reused only when it is long enough AND on ``x``'s device;
        otherwise it is rebuilt on ``x.device``. Tables are computed in float32.
        """
        seq_len = x.shape[0]
        cached = self.cos_cached
        if cached is not None and seq_len <= cached.shape[0] and cached.device == x.device:
            return
        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
        # Row $m$ is $[m \theta_0, ..., m \theta_{\frac{d}{2}}, m \theta_0, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        r"""Return $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$ along the last axis of a 4-D tensor."""
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is a query or key tensor of shape `[batch_size, n_heads, seq_len, head_dim]`

        Returns a tensor of the same shape and dtype with the rotary encoding applied
        to the first `d` features of the last axis.
        """
        # b h t d -> t b h d, so positions sit on the leading axis used to index the cache
        x = x.permute(2, 0, 1, 3)
        self._build_cache(x)
        seq_len = x.shape[0]
        # Split the features; rotary embeddings are applied only to the first `d` of them.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]
        neg_half_x = self._neg_half(x_rope)
        # Multiply against the float32 tables (standard type promotion), then cast back so
        # half/bfloat16 inputs keep their dtype and stay compatible with v in attention.
        x_rope = (x_rope * self.cos_cached[:seq_len] + neg_half_x * self.sin_cached[:seq_len]).to(x.dtype)
        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)  # t b h d -> b h t d
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2, e.g. (N, T, D) -> (N, D, T).

    Inherits from ``nn.Identity`` (a parameter-free module) and overrides ``forward``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.transpose(input, 1, 2)