import warnings from logging import getLogger from typing import Any, Literal, Sequence import torch from torch import nn import so_vits_svc_fork.f0 from so_vits_svc_fork.f0 import f0_to_coarse from so_vits_svc_fork.modules import commons as commons from so_vits_svc_fork.modules.decoders.f0 import F0Decoder from so_vits_svc_fork.modules.decoders.hifigan import NSFHifiGANGenerator from so_vits_svc_fork.modules.decoders.mb_istft import ( Multiband_iSTFT_Generator, Multistream_iSTFT_Generator, iSTFT_Generator, ) from so_vits_svc_fork.modules.encoders import Encoder, TextEncoder from so_vits_svc_fork.modules.flows import ResidualCouplingBlock LOG = getLogger(__name__) class SynthesizerTrn(nn.Module): """ Synthesizer for Training """ def __init__( self, spec_channels: int, segment_size: int, inter_channels: int, hidden_channels: int, filter_channels: int, n_heads: int, n_layers: int, kernel_size: int, p_dropout: int, resblock: str, resblock_kernel_sizes: Sequence[int], resblock_dilation_sizes: Sequence[Sequence[int]], upsample_rates: Sequence[int], upsample_initial_channel: int, upsample_kernel_sizes: Sequence[int], gin_channels: int, ssl_dim: int, n_speakers: int, sampling_rate: int = 44100, type_: Literal["hifi-gan", "istft", "ms-istft", "mb-istft"] = "hifi-gan", gen_istft_n_fft: int = 16, gen_istft_hop_size: int = 4, subbands: int = 4, **kwargs: Any, ): super().__init__() self.spec_channels = spec_channels self.inter_channels = inter_channels self.hidden_channels = hidden_channels self.filter_channels = filter_channels self.n_heads = n_heads self.n_layers = n_layers self.kernel_size = kernel_size self.p_dropout = p_dropout self.resblock = resblock self.resblock_kernel_sizes = resblock_kernel_sizes self.resblock_dilation_sizes = resblock_dilation_sizes self.upsample_rates = upsample_rates self.upsample_initial_channel = upsample_initial_channel self.upsample_kernel_sizes = upsample_kernel_sizes self.segment_size = segment_size self.gin_channels = gin_channels self.ssl_dim = ssl_dim self.n_speakers = n_speakers self.sampling_rate = sampling_rate self.type_ = type_ self.gen_istft_n_fft = gen_istft_n_fft self.gen_istft_hop_size = gen_istft_hop_size self.subbands = subbands if kwargs: warnings.warn(f"Unused arguments: {kwargs}") self.emb_g = nn.Embedding(n_speakers, gin_channels) if ssl_dim is None: self.pre = nn.LazyConv1d(hidden_channels, kernel_size=5, padding=2) else: self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2) self.enc_p = TextEncoder( inter_channels, hidden_channels, filter_channels=filter_channels, n_heads=n_heads, n_layers=n_layers, kernel_size=kernel_size, p_dropout=p_dropout, ) LOG.info(f"Decoder type: {type_}") if type_ == "hifi-gan": hps = { "sampling_rate": sampling_rate, "inter_channels": inter_channels, "resblock": resblock, "resblock_kernel_sizes": resblock_kernel_sizes, "resblock_dilation_sizes": resblock_dilation_sizes, "upsample_rates": upsample_rates, "upsample_initial_channel": upsample_initial_channel, "upsample_kernel_sizes": upsample_kernel_sizes, "gin_channels": gin_channels, } self.dec = NSFHifiGANGenerator(h=hps) self.mb = False else: hps = { "initial_channel": inter_channels, "resblock": resblock, "resblock_kernel_sizes": resblock_kernel_sizes, "resblock_dilation_sizes": resblock_dilation_sizes, "upsample_rates": upsample_rates, "upsample_initial_channel": upsample_initial_channel, "upsample_kernel_sizes": upsample_kernel_sizes, "gin_channels": gin_channels, "gen_istft_n_fft": gen_istft_n_fft, "gen_istft_hop_size": gen_istft_hop_size, "subbands": subbands, } # gen_istft_n_fft, gen_istft_hop_size, subbands if type_ == "istft": del hps["subbands"] self.dec = iSTFT_Generator(**hps) elif type_ == "ms-istft": self.dec = Multistream_iSTFT_Generator(**hps) elif type_ == "mb-istft": self.dec = Multiband_iSTFT_Generator(**hps) else: raise ValueError(f"Unknown type: {type_}") self.mb = True self.enc_q = Encoder( spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels, ) self.flow = ResidualCouplingBlock( inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels ) self.f0_decoder = F0Decoder( 1, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, spk_channels=gin_channels, ) self.emb_uv = nn.Embedding(2, hidden_channels) def forward(self, c, f0, uv, spec, g=None, c_lengths=None, spec_lengths=None): g = self.emb_g(g).transpose(1, 2) # ssl prenet x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to( c.dtype ) x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) # f0 predict lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500 norm_lf0 = so_vits_svc_fork.f0.normalize_f0(lf0, x_mask, uv) pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) # encoder z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask, f0=f0_to_coarse(f0)) z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) # flow z_p = self.flow(z, spec_mask, g=g) z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch( z, f0, spec_lengths, self.segment_size ) # MB-iSTFT-VITS if self.mb: o, o_mb = self.dec(z_slice, g=g) # HiFi-GAN else: o = self.dec(z_slice, g=g, f0=pitch_slice) o_mb = None return ( o, o_mb, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q), pred_lf0, norm_lf0, lf0, ) def infer(self, c, f0, uv, g=None, noice_scale=0.35, predict_f0=False): c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) g = self.emb_g(g).transpose(1, 2) x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to( c.dtype ) x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) if predict_f0: lf0 = 2595.0 * torch.log10(1.0 + f0.unsqueeze(1) / 700.0) / 500 norm_lf0 = so_vits_svc_fork.f0.normalize_f0( lf0, x_mask, uv, random_scale=False ) pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g) f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1) z_p, m_p, logs_p, c_mask = self.enc_p( x, x_mask, f0=f0_to_coarse(f0), noice_scale=noice_scale ) z = self.flow(z_p, c_mask, g=g, reverse=True) # MB-iSTFT-VITS if self.mb: o, o_mb = self.dec(z * c_mask, g=g) else: o = self.dec(z * c_mask, g=g, f0=f0) return o