# Page-header residue captured together with this file (not code):
# Flux9665 -- "update to current version" -- 6a79837
"""
Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP
https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py
Code is under MIT License
"""
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class FFN(nn.Module):
    """Feed-forward network made of two masked 1D convolutions.

    The input is masked, convolved to ``filter_channels``, passed through a
    tanh-approximated GELU and dropout, masked again, convolved to
    ``out_channels`` and masked one final time.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        filter_channels: Channel count between the two convolutions.
        kernel_size: Kernel width shared by both convolutions.
        p_dropout: Dropout probability applied after the activation.
        gin_channels: Stored on the module; not used by this block.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # "same" padding is exactly ``kernel_size // 2`` on both sides for odd
        # kernels. For even kernels it still keeps the time axis unchanged,
        # whereas symmetric ``kernel_size // 2`` padding would add one frame per
        # convolution and break the mask multiplication in ``forward``.
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding="same")
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding="same")
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        """Apply the feed-forward network.

        Args:
            x: Input that ``conv_1`` accepts, i.e. ``[batch, in_channels, time]``.
            x_mask: Mask that broadcasts against ``x`` and against the
                convolution outputs (e.g. ``[batch, 1, time]``).

        Returns:
            Tensor of shape ``[batch, out_channels, time]``, multiplied by ``x_mask``.
        """
        # Mask before every convolution so masked frames cannot leak into
        # their neighbours through the kernel.
        hidden = self.conv_1(x * x_mask)
        hidden = self.drop(self.act1(hidden))
        hidden = self.conv_2(hidden * x_mask)
        return hidden * x_mask
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over ``[batch, channels, time]`` tensors.

    Queries, keys and values come from 1x1 convolutions of the same input.
    Rotary positional embeddings are configured for half of each head's
    channels and applied to queries and keys before scaled dot-product
    attention. A final 1x1 convolution maps to ``out_channels``.

    Args:
        channels: Input channel count; must be divisible by ``n_heads``.
        out_channels: Channel count produced by the output projection.
        n_heads: Number of attention heads.
        p_dropout: Attention dropout probability, used only in training mode.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.k_channels = channels // n_heads

        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        # Rotate only the first half of each head's channels.
        rotary_dim = self.k_channels // 2
        self.query_rotary_pe = RotaryPositionalEmbeddings(rotary_dim)
        self.key_rotary_pe = RotaryPositionalEmbeddings(rotary_dim)

        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        # Kept as a submodule; dropout is actually applied inside
        # ``scaled_dot_product_attention`` via ``p_dropout``.
        self.drop = nn.Dropout(p_dropout)

        for projection in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x, attn_mask=None):
        """Self-attend over ``x`` and project the result to ``out_channels``.

        Args:
            x: Tensor of shape ``[batch, channels, time]``.
            attn_mask: Optional mask forwarded unchanged to
                ``F.scaled_dot_product_attention``.
                NOTE(review): PyTorch treats boolean masks as keep/drop and
                floating-point masks as an additive bias; confirm that
                callers pass the dtype they intend.

        Returns:
            Tensor of shape ``[batch, out_channels, time]``.
        """
        attended = self.attention(self.conv_q(x), self.conv_k(x), self.conv_v(x), mask=attn_mask)
        return self.conv_o(attended)

    def attention(self, query, key, value, mask=None):
        """Run rotary scaled dot-product attention on projected tensors.

        Args:
            query: ``[batch, channels, t_t]``.
            key: ``[batch, channels, t_s]``.
            value: ``[batch, channels, t_s]``.
            mask: Optional attention mask passed through as ``attn_mask``.

        Returns:
            Tensor of shape ``[batch, channels, t_t]``.
        """
        batch, dim, src_len = key.size()
        tgt_len = query.size(2)
        heads, head_dim = self.n_heads, self.k_channels

        # [b, d, t] -> [b, n_head, t, d // n_head]
        query = query.view(batch, heads, head_dim, tgt_len).transpose(2, 3)
        key = key.view(batch, heads, head_dim, src_len).transpose(2, 3)
        value = value.view(batch, heads, head_dim, src_len).transpose(2, 3)

        query = self.query_rotary_pe(query)
        key = self.key_rotary_pe(key)

        dropout_p = self.p_dropout if self.training else 0
        attended = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=dropout_p)

        # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return attended.transpose(2, 3).contiguous().view(batch, dim, tgt_len)
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    The block works on ``hidden_channels + out_channels`` channels. It applies
    a self-attention sub-layer followed by a convolutional feed-forward
    sub-layer, each behind a parameter-free LayerNorm and a residual
    connection. When a conditioning vector is supplied, it is projected to six
    per-channel tensors (shift, scale and gate for each sub-layer).

    Args:
        hidden_channels: Added to ``out_channels`` to give the working width.
        out_channels: Added to ``hidden_channels`` to give the working width.
        filter_channels: Inner width of the feed-forward network.
        num_heads: Number of attention heads.
        kernel_size: Kernel width of the feed-forward convolutions.
        p_dropout: Dropout probability for attention and feed-forward.
        gin_channels: Width of the conditioning vector. When it differs from
            the working width, a linear layer maps it to that width first.
    """

    def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        channels = hidden_channels + out_channels
        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(channels, channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(channels, channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, channels) if gin_channels != channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(channels, 6 * channels, bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel] or None for the unconditioned path
            x_mask : [batch_size, 1, time]
        return the same shape as x, with masked frames kept at zero
        """
        x = x * x_mask
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)  # shape: [batch_size, 1, time, time]

        if c is not None:
            # Each chunk has shape [batch_size, channel, 1].
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)
            )
            attn_in = self.modulate(self._channel_norm(self.norm1, x), shift_msa, scale_msa)
            x = x + gate_msa * self.attn(attn_in, attn_mask) * x_mask
            mlp_in = self.modulate(self._channel_norm(self.norm2, x), shift_mlp, scale_mlp)
            x = x + gate_mlp * self.mlp(mlp_in, x_mask) * x_mask
        else:
            # No condition version. The attention residual is masked like in
            # the conditioned branch so padded frames stay zero, and the MLP
            # sub-layer uses its own norm (norm2) rather than reusing norm1.
            x = x + self.attn(self._channel_norm(self.norm1, x), attn_mask) * x_mask
            x = x + self.mlp(self._channel_norm(self.norm2, x), x_mask)
        return x

    @staticmethod
    def _channel_norm(norm, x):
        """Apply ``norm`` over the channel axis of a [batch, channel, time] tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    @staticmethod
    def modulate(x, shift, scale):
        """Return ``x * (1 + scale) + shift``."""
        return x * (1 + scale) + shift
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module
    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.

    Only the first `d` features of the last axis are rotated; any remaining
    features are passed through unchanged.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ to rotate (coerced with `int`)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        self.base = base
        self.d = int(d)
        # Plain attributes rather than buffers: they are created lazily in
        # `_build_cache` and therefore do not follow `Module.to(...)`.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for `x` of shape `[seq_len, batch_size, n_heads, d]`.

        The cache is rebuilt when it is missing, too short for `x`, or lives on
        a different device than `x`.
        """
        seq_len = x.shape[0]
        cached = self.cos_cached
        # Reuse the cache only if it is long enough AND on the right device;
        # otherwise the multiplication in `forward` would mix devices.
        if cached is not None and seq_len <= cached.shape[0] and cached.device == x.device:
            return

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Shape `[seq_len, 1, 1, d]` so it broadcasts over batch and heads.
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        r"""
        Return $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        along the last axis of the 4D tensor `x`.
        """
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the query or key tensor with shape `[batch_size, n_heads, seq_len, features]`

        Returns a tensor of the same shape in which the first `d` features
        have been rotated according to their position along `seq_len`.
        """
        # b h t d -> t b h d, the layout the cache is built for
        x = x.permute(2, 0, 1, 3)
        self._build_cache(x)
        seq_len = x.shape[0]

        # Rotary embeddings are applied only to the first `d` features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d:]
        neg_half_x = self._neg_half(x_rope)
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])

        # t b h d -> b h t d
        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)
class Transpose(nn.Identity):
    """Module that swaps axes 1 and 2 of its input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        swapped = torch.transpose(input, 1, 2)
        return swapped