from typing import Optional, List, Tuple

import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

from .residuals import ResBlock1, ResBlock2, LRELU_SLOPE
from .utils import call_weight_data_normal_if_Conv


def _strip_weight_norm(module: nn.Module) -> None:
    """Remove the weight-norm hook registered directly on ``module``, if any.

    Modules without such a hook are left untouched.
    """
    # Iterate over a snapshot: remove_weight_norm deletes the hook from
    # module._forward_pre_hooks, and mutating a dict while iterating over it
    # is an error.
    for hook in list(module._forward_pre_hooks.values()):
        # The hook we want to remove is an instance of the WeightNorm class, so
        # normally we would do `if isinstance(...)`, but this class is not
        # accessible because of shadowing, so we check the module name directly.
        # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
        if (
            hook.__module__ == "torch.nn.utils.weight_norm"
            and hook.__class__.__name__ == "WeightNorm"
        ):
            torch.nn.utils.remove_weight_norm(module)
            # A single call removes the (default-named) weight normalisation;
            # calling it again for the same module would raise.
            break


class Generator(torch.nn.Module):
    """Upsampling convolutional generator producing a single-channel output.

    The network is: ``conv_pre`` -> for every upsampling stage a weight-normed
    ``ConvTranspose1d`` followed by the average of ``num_kernels`` residual
    blocks -> ``conv_post`` -> ``tanh``.

    Args:
        initial_channel: Number of input channels of ``conv_pre``.
        resblock: ``"1"`` selects ``ResBlock1``; any other value selects
            ``ResBlock2``.
        resblock_kernel_sizes: Kernel size of each residual block in a stage.
        resblock_dilation_sizes: Dilations of each residual block in a stage
            (zipped with ``resblock_kernel_sizes``).
        upsample_rates: Stride of each transposed convolution.
        upsample_initial_channel: Output channels of ``conv_pre``; halved
            (integer division) after every upsampling stage.
        upsample_kernel_sizes: Kernel size of each transposed convolution
            (zipped with ``upsample_rates``).
        gin_channels: Channels of the optional conditioning input ``g``.
            When 0, no conditioning layer is created.
    """

    def __init__(
        self,
        initial_channel: int,
        resblock: str,
        resblock_kernel_sizes: List[int],
        resblock_dilation_sizes: List[List[int]],
        upsample_rates: List[int],
        upsample_initial_channel: int,
        upsample_kernel_sizes: List[int],
        gin_channels: int = 0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        resblock_module = ResBlock1 if resblock == "1" else ResBlock2
        # With no upsampling stage the loop below never runs; conv_post then
        # consumes the output of conv_pre directly, so start from its width.
        ch = upsample_initial_channel
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock_module(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        # NOTE(review): helper is defined in .utils; presumably re-initialises
        # the conv weights from a normal distribution -- confirm there.
        self.ups.apply(call_weight_data_normal_if_Conv)

        # The conditioning projection only exists when gin_channels is non-zero.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def __call__(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ) -> torch.Tensor:
        """Typed pass-through to ``nn.Module.__call__`` (see ``forward``)."""
        return super().__call__(x, g=g, n_res=n_res)

    def forward(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        n_res: Optional[int] = None,
    ):
        """Run the generator.

        Args:
            x: Input features; presumably ``(batch, initial_channel, length)``
                as expected by ``Conv1d`` -- TODO confirm against callers.
            g: Optional conditioning tensor, projected by ``self.cond`` and
                added to the output of ``conv_pre``. ``self.cond`` is only
                created when ``gin_channels != 0``.
            n_res: Optional target length; when given and different from
                ``x.shape[-1]``, ``x`` is linearly interpolated to that length
                before ``conv_pre``.

        Returns:
            The ``tanh``-squashed single-channel output of ``conv_post``.
        """
        if n_res is not None:
            n = int(n_res)
            if n != x.shape[-1]:
                x = F.interpolate(x, size=n, mode="linear")

        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Residual blocks are stored flat: stage i owns the slice
            # [i * num_kernels, (i + 1) * num_kernels).
            n = i * self.num_kernels
            xs = self.resblocks[n](x)
            for j in range(1, self.num_kernels):
                xs += self.resblocks[n + j](x)
            # Average the outputs of the stage's residual blocks.
            x = xs / self.num_kernels

        # No explicit slope here, so PyTorch's default negative slope is used
        # (unlike the LRELU_SLOPE used inside the loop).
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def __prepare_scriptable__(self):
        """Strip weight-norm hooks registered on ``ups`` / ``resblocks`` entries.

        Only hooks attached directly to each entry are inspected. Returns
        ``self``.
        """
        for l in self.ups:
            _strip_weight_norm(l)
        for l in self.resblocks:
            _strip_weight_norm(l)
        return self

    def remove_weight_norm(self):
        """Remove weight normalisation from all upsampling layers and resblocks."""
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


class SineGenerator(torch.nn.Module):
    """Definition of sine generator

    SineGenerator(samp_rate, harmonic_num = 0,
                  sine_amp = 0.1, noise_std = 0.003,
                  voiced_threshold = 0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(
        self,
        samp_rate: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        noise_std: float = 0.003,
        voiced_threshold: int = 0,
    ):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # One channel for the fundamental plus one per overtone.
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def __call__(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Typed pass-through to ``nn.Module.__call__`` (see ``forward``)."""
        return super().__call__(f0, upp)

    def forward(
        self, f0: torch.Tensor, upp: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """sine_waves, uv, noise = forward(f0, upp)

        input f0: frame-level F0; presumably tensor(batchsize, length) -- a
                  trailing axis is added inside this method (TODO confirm
                  against callers). f0 for unvoiced steps should be 0
        input upp: upsampling factor applied along the length axis
        output sine_waves: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: same shape as sine_waves
        Everything is computed under torch.no_grad().
        """
        with torch.no_grad():
            # Insert an axis and swap it to the end so that the length axis
            # is axis 1 for the indexing below.
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in range(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
                    idx + 2
                )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
            # Per-step phase increment in cycles. The "% 1" here means the
            # product with the harmonic number cannot be optimised afterwards.
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase for every channel except the fundamental.
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # Deliberately no "% 1" on this cumsum: applying it here would
            # prevent optimising the later cumsum.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one: torch.Tensor = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=float(upp),
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values: torch.Tensor = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
            ).transpose(2, 1)
            tmp_over_one %= 1
            # True wherever the wrapped phase decreases between two
            # consecutive steps.
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            # Subtract one full cycle at each of those steps.
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
            )
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv: torch.Tensor = F.interpolate(
                uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
            ).transpose(2, 1)
            # Noise std is noise_std where uv == 1 and sine_amp / 3 where
            # uv == 0.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            # Zero the sine where uv == 0, then add the noise everywhere.
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

    def _f02uv(self, f0):
        """Return a mask like ``f0``: 1 where ``f0 > voiced_threshold``, else 0."""
        # generate uv signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        if uv.device.type == "privateuseone":  # for DirectML
            uv = uv.float()
        return uv