Vexa's picture
Duplicate from pivich/sovits-new
d5d7329
raw
history blame contribute delete
No virus
8.18 kB
import warnings
from logging import getLogger
from typing import Any, Literal, Optional, Sequence

import torch
from torch import nn

import so_vits_svc_fork.f0
from so_vits_svc_fork.f0 import f0_to_coarse
from so_vits_svc_fork.modules import commons as commons
from so_vits_svc_fork.modules.decoders.f0 import F0Decoder
from so_vits_svc_fork.modules.decoders.hifigan import NSFHifiGANGenerator
from so_vits_svc_fork.modules.decoders.mb_istft import (
    Multiband_iSTFT_Generator,
    Multistream_iSTFT_Generator,
    iSTFT_Generator,
)
from so_vits_svc_fork.modules.encoders import Encoder, TextEncoder
from so_vits_svc_fork.modules.flows import ResidualCouplingBlock

# Module-level logger, named after this module via ``__name__``.
LOG = getLogger(__name__)
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Bundles a speaker embedding (``emb_g``), a content pre-net (``pre``), an
    unvoiced/voiced embedding (``emb_uv``), the encoders ``enc_p`` and
    ``enc_q``, a residual coupling ``flow``, an F0 decoder and a waveform
    decoder (``dec``) whose implementation is selected by ``type_``.
    """

    # Values accepted for ``type_``; anything else is rejected in ``__init__``.
    _DECODER_TYPES = ("hifi-gan", "istft", "ms-istft", "mb-istft")

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: Sequence[int],
        resblock_dilation_sizes: Sequence[Sequence[int]],
        upsample_rates: Sequence[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: Sequence[int],
        gin_channels: int,
        ssl_dim: Optional[int],
        n_speakers: int,
        sampling_rate: int = 44100,
        type_: Literal["hifi-gan", "istft", "ms-istft", "mb-istft"] = "hifi-gan",
        gen_istft_n_fft: int = 16,
        gen_istft_hop_size: int = 4,
        subbands: int = 4,
        **kwargs: Any,
    ):
        """
        Build every sub-module of the synthesizer.

        Args:
            spec_channels: First argument of the ``Encoder`` stored as ``enc_q``.
            segment_size: Segment length handed to
                ``commons.rand_slice_segments_with_pitch`` in ``forward``.
            inter_channels: Channel count shared by ``enc_p``, ``enc_q``,
                ``flow`` and the waveform decoder.
            hidden_channels: Output channels of the pre-net, width of the
                unvoiced/voiced embedding, and a size argument of the
                encoders, flow and F0 decoder.
            filter_channels, n_heads, n_layers, kernel_size, p_dropout:
                Forwarded unchanged to ``TextEncoder`` and ``F0Decoder``.
            resblock, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes:
                Forwarded unchanged to the waveform decoder.
            gin_channels: Width of the speaker embedding; also forwarded to
                the decoder, ``enc_q``, ``flow`` and the F0 decoder.
            ssl_dim: Input channels of the pre-net convolution. When ``None``
                a ``nn.LazyConv1d`` is used so the input width is inferred
                on first use.
            n_speakers: Number of rows in the speaker embedding table.
            sampling_rate: Forwarded to the decoder for ``"hifi-gan"`` only.
            type_: Which waveform decoder to build.
            gen_istft_n_fft, gen_istft_hop_size: Forwarded to the iSTFT
                decoders (ignored for ``"hifi-gan"``).
            subbands: Forwarded to the ``"ms-istft"`` and ``"mb-istft"``
                decoders only.
            **kwargs: Ignored; a warning listing them is emitted.

        Raises:
            ValueError: If ``type_`` is not one of the supported decoder names.
        """
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.ssl_dim = ssl_dim
        self.n_speakers = n_speakers
        self.sampling_rate = sampling_rate
        self.type_ = type_
        self.gen_istft_n_fft = gen_istft_n_fft
        self.gen_istft_hop_size = gen_istft_hop_size
        self.subbands = subbands
        if kwargs:
            # stacklevel=2 attributes the warning to the code that built the model.
            warnings.warn(f"Unused arguments: {kwargs}", stacklevel=2)

        # Fail fast: reject an unknown decoder before allocating any sub-module.
        if type_ not in self._DECODER_TYPES:
            raise ValueError(f"Unknown type: {type_}")

        # NOTE: sub-modules are created in a fixed order (emb_g, pre, enc_p,
        # dec, enc_q, flow, f0_decoder, emb_uv); keep it stable so parameter
        # registration order does not change.
        self.emb_g = nn.Embedding(n_speakers, gin_channels)

        if ssl_dim is None:
            self.pre = nn.LazyConv1d(hidden_channels, kernel_size=5, padding=2)
        else:
            self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
        )

        LOG.info(f"Decoder type: {type_}")
        # Hyper-parameters every decoder variant receives.
        shared_hps = {
            "resblock": resblock,
            "resblock_kernel_sizes": resblock_kernel_sizes,
            "resblock_dilation_sizes": resblock_dilation_sizes,
            "upsample_rates": upsample_rates,
            "upsample_initial_channel": upsample_initial_channel,
            "upsample_kernel_sizes": upsample_kernel_sizes,
            "gin_channels": gin_channels,
        }
        if type_ == "hifi-gan":
            hps = {
                "sampling_rate": sampling_rate,
                "inter_channels": inter_channels,
                **shared_hps,
            }
            self.dec = NSFHifiGANGenerator(h=hps)
            # ``mb`` records whether ``dec`` returns an ``(o, o_mb)`` pair.
            self.mb = False
        else:
            hps = {
                "initial_channel": inter_channels,
                **shared_hps,
                "gen_istft_n_fft": gen_istft_n_fft,
                "gen_istft_hop_size": gen_istft_hop_size,
                "subbands": subbands,
            }
            if type_ == "istft":
                # The plain iSTFT generator is built without ``subbands``.
                del hps["subbands"]
                self.dec = iSTFT_Generator(**hps)
            elif type_ == "ms-istft":
                self.dec = Multistream_iSTFT_Generator(**hps)
            else:  # "mb-istft": the only value left after validation above
                self.dec = Multiband_iSTFT_Generator(**hps)
            self.mb = True

        self.enc_q = Encoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )
        self.f0_decoder = F0Decoder(
            1,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            spk_channels=gin_channels,
        )
        self.emb_uv = nn.Embedding(2, hidden_channels)

    @staticmethod
    def _f0_to_lf0(f0):
        """Map F0 to ``2595 * log10(1 + f0 / 700) / 500`` with a new dim 1 inserted."""
        return 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500

    def _prenet(self, c, uv, c_lengths):
        """Return ``(x, x_mask)``: masked pre-net output plus the uv embedding, and the mask."""
        x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(
            c.dtype
        )
        x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
        return x, x_mask

    def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None):
        """
        Run the training pass.

        Args:
            c: Content features fed to the pre-net; ``c.size(2)`` is used as
                the mask length.
            f0: Fundamental frequency values.
            uv: Unvoiced/voiced flags; cast to ``long`` and embedded.
            spec: Spectrogram given to ``enc_q``.
            g: Speaker ids. Passed straight to the speaker embedding, so a
                tensor is needed in practice despite the ``None`` default
                (presumably shaped ``(batch, 1)`` -- TODO confirm).
            c_lengths: Lengths used to build the mask over ``c``.
            spec_lengths: Lengths given to ``enc_q`` and to the random slicer.

        Returns:
            The tuple ``(o, o_mb, ids_slice, spec_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0)``.
            ``o_mb`` is ``None`` when the decoder is ``"hifi-gan"``.
        """
        g = self.emb_g(g).transpose(1, 2)

        # ssl prenet
        x, x_mask = self._prenet(c, uv, c_lengths)

        # f0 predict
        lf0 = self._f0_to_lf0(f0)
        norm_lf0 = so_vits_svc_fork.f0.normalize_f0(lf0, x_mask, uv)
        pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)

        # encoder (the first and last outputs of enc_p are not used here)
        _, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0))
        z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)

        # flow
        z_p = self.flow(z, spec_mask, g=g)
        z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(
            z, f0, spec_lengths, self.segment_size
        )

        if self.mb:
            # MB-iSTFT-VITS
            o, o_mb = self.dec(z_slice, g=g)
        else:
            # HiFi-GAN
            o = self.dec(z_slice, g=g, f0=pitch_slice)
            o_mb = None

        return (
            o,
            o_mb,
            ids_slice,
            spec_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            pred_lf0,
            norm_lf0,
            lf0,
        )

    def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False):
        """
        Generate audio from content features.

        Args:
            c: Content features; every item is treated as full length
                (``c.size(-1)``).
            f0: Fundamental frequency values; replaced by the F0 decoder's
                prediction when ``predict_f0`` is true.
            uv: Unvoiced/voiced flags; cast to ``long`` and embedded.
            g: Speaker ids. Passed straight to the speaker embedding, so a
                tensor is needed in practice despite the ``None`` default.
            noice_scale: Forwarded to ``enc_p`` under that (sic) keyword.
            predict_f0: When true, F0 is re-estimated by ``f0_decoder`` from
                the normalised log-F0 (``random_scale=False``).

        Returns:
            The decoder output ``o``; for the iSTFT decoders the second
            output is discarded.
        """
        # Full-length lengths, created directly on the input's device.
        c_lengths = torch.ones(c.size(0), device=c.device) * c.size(-1)
        g = self.emb_g(g).transpose(1, 2)
        x, x_mask = self._prenet(c, uv, c_lengths)

        if predict_f0:
            lf0 = self._f0_to_lf0(f0)
            norm_lf0 = so_vits_svc_fork.f0.normalize_f0(
                lf0, x_mask, uv, random_scale=False
            )
            pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
            # Inverse of ``_f0_to_lf0``: back from scaled log-F0 to F0.
            f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)

        z_p, m_p, logs_p, c_mask = self.enc_p(
            x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale
        )
        z = self.flow(z_p, c_mask, g=g, reverse=True)

        if self.mb:
            # MB-iSTFT-VITS
            o, _ = self.dec(z * c_mask, g=g)
        else:
            o = self.dec(z * c_mask, g=g, f0=f0)
        return o