import torch.nn as nn

import module as modules


class ResidualCouplingBlock(nn.Module):
    """Chain of residual coupling layers, each followed by a ``Flip``.

    The constructor builds ``n_flows`` pairs of
    ``modules.ResidualCouplingLayer`` (created with ``mean_only=True``) and
    ``modules.Flip`` and stores them, in order, in ``self.flows``.

    ``forward`` threads ``x`` through that chain front to back, or back to
    front when ``reverse`` is truthy.

    Args:
        channels: Forwarded as the first positional argument of every
            coupling layer.
        hidden_channels: Forwarded to every coupling layer.
        kernel_size: Forwarded to every coupling layer.
        dilation_rate: Forwarded to every coupling layer.
        n_layers: Forwarded to every coupling layer.
        n_flows: Number of (coupling layer, flip) pairs to create.
        gin_channels: Forwarded to every coupling layer as the
            ``gin_channels`` keyword (presumably the width of the optional
            conditioning input ``g`` -- confirm against ``modules``).
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()

        # Keep the hyper-parameters on the instance for later inspection.
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            # NOTE(review): the Flip is registered once per coupling layer,
            # directly after it -- confirm this matches the intended layout.
            self.flows.extend([coupling, modules.Flip()])

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run ``x`` through the flow chain and return the result.

        Args:
            x: Input passed to the first flow.
            x_mask: Passed unchanged to every flow.
            g: Optional value passed to every flow as keyword ``g``.
            reverse: When falsy, flows are applied in registration order
                and only the first element of each flow's return value is
                kept. When truthy, flows are applied in reversed order and
                each flow's return value is used directly as the next ``x``.

        Returns:
            The value of ``x`` after the last flow has been applied.
        """
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x

        for layer in self.flows:
            # The second element of the pair is discarded here (presumably
            # a log-determinant term -- confirm against ``modules``).
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x