import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.nn.init as init
from .modules import InvertibleConv1x1


def _init_module_weights(net_l, weight_init_fn, scale):
    """Initialise every Conv2d / Linear / BatchNorm2d found inside ``net_l``.

    Shared implementation behind ``initialize_weights`` and
    ``initialize_weights_xavier``; only the weight initialiser differs.

    Args:
        net_l: A module, or a list/tuple of modules. Every submodule yielded
            by ``nn.Module.modules()`` is visited.
        weight_init_fn: In-place initialiser applied to Conv2d/Linear weights.
        scale: Factor the freshly initialised Conv2d/Linear weights are
            multiplied by (e.g. 0.1 for residual branches, 0 to zero them).
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                weight_init_fn(m.weight)
                m.weight.data *= scale  # small scale for residual blocks
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has no weight/bias to initialise.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)


def initialize_weights(net_l, scale=1):
    """Kaiming-normal (a=0, fan_in) init for Conv2d/Linear weights.

    Weights are multiplied by ``scale`` and biases zeroed; BatchNorm2d is
    set to weight 1 / bias 0.

    Args:
        net_l: A module or a list/tuple of modules.
        scale: Multiplier applied to the initialised Conv2d/Linear weights.
    """
    _init_module_weights(
        net_l, lambda w: init.kaiming_normal_(w, a=0, mode="fan_in"), scale
    )


def initialize_weights_xavier(net_l, scale=1):
    """Xavier-normal init for Conv2d/Linear weights.

    Weights are multiplied by ``scale`` and biases zeroed; BatchNorm2d is
    set to weight 1 / bias 0.

    Args:
        net_l: A module or a list/tuple of modules.
        scale: Multiplier applied to the initialised Conv2d/Linear weights.
    """
    _init_module_weights(net_l, init.xavier_normal_, scale)


class DenseBlock(nn.Module):
    """Five-layer densely connected conv block mapping channel_in -> channel_out.

    conv1..conv4 each emit ``gc`` feature maps and see the block input
    concatenated with every previous intermediate output; conv5 fuses the
    full concatenation into ``channel_out`` maps. All convs are 3x3,
    stride 1, padding 1, so the spatial size is preserved.

    Args:
        channel_in: Number of input channels.
        channel_out: Number of output channels.
        init: "xavier" selects Xavier-normal init for conv1..conv4; any other
            value selects Kaiming-normal. (Inside ``__init__`` this name
            shadows the ``torch.nn.init`` module.)
        gc: Growth channels produced by each intermediate conv.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, channel_in, channel_out, init="xavier", gc=32, bias=True):
        super(DenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        hidden_convs = [self.conv1, self.conv2, self.conv3, self.conv4]
        if init == "xavier":
            initialize_weights_xavier(hidden_convs, 0.1)
        else:
            initialize_weights(hidden_convs, 0.1)
        # Scale 0 zeroes conv5's weights (and its bias), so a freshly built
        # block outputs zeros.
        initialize_weights(self.conv5, 0)

    def forward(self, x):
        """Run the dense block on ``x`` (assumed N, C, H, W with C == channel_in)."""
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5


def subnet(net_structure, init="xavier"):
    """Return a factory ``constructor(channel_in, channel_out)`` for coupling subnets.

    Args:
        net_structure: Sub-network type; only "DBNet" (a DenseBlock) exists.
        init: Weight-init scheme forwarded to DenseBlock ("xavier", or any
            other value for Kaiming-normal).

    The returned constructor raises ValueError for an unknown
    ``net_structure`` instead of silently returning None.
    """

    def constructor(channel_in, channel_out):
        if net_structure == "DBNet":
            # Forward ``init`` explicitly: without it, any non-"xavier" choice
            # would silently fall back to DenseBlock's "xavier" default.
            return DenseBlock(channel_in, channel_out, init)
        raise ValueError(
            "Unsupported net_structure {!r}; only 'DBNet' is available".format(
                net_structure
            )
        )

    return constructor


class InvBlock(nn.Module):
    """One invertible flow step: invertible 1x1 conv, then an affine coupling.

    After the 1x1 conv, the channels are split into the first
    ``channel_split_num`` (x1) and the remaining ones (x2). The forward
    pass computes:

        y1 = x1 + F(x2)
        s  = clamp * (2 * sigmoid(H(y1)) - 1)
        y2 = x2 * exp(s) + G(y1)

    The reverse pass undoes these steps in the opposite order and then
    inverts the 1x1 conv.

    Args:
        subnet_constructor: Callable ``(channel_in, channel_out) -> nn.Module``
            used to build F, G and H.
        channel_num: Total number of channels.
        channel_split_num: Number of channels in the first split (x1).
        clamp: Bound on ``|s|``, since ``2 * sigmoid - 1`` lies in (-1, 1).
    """

    def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8):
        super(InvBlock, self).__init__()
        self.split_len1 = channel_split_num
        self.split_len2 = channel_num - channel_split_num
        self.clamp = clamp

        self.F = subnet_constructor(self.split_len2, self.split_len1)
        self.G = subnet_constructor(self.split_len1, self.split_len2)
        self.H = subnet_constructor(self.split_len1, self.split_len2)

        # The 1x1 conv must mix all ``channel_num`` channels; this used to be
        # hard-coded to 3, which broke any other channel count.
        self.invconv = InvertibleConv1x1(channel_num, LU_decomposed=True)

    def flow_permutation(self, z, logdet, rev):
        """Apply (rev=False) or invert (rev=True) the invertible 1x1 conv.

        A regular method rather than an instance lambda, so the module stays
        picklable and holds no self-referencing closure.
        """
        return self.invconv(z, logdet, rev)

    def _split(self, x):
        """Split ``x`` along dim 1 into (first split_len1, remaining split_len2) channels."""
        return (
            x.narrow(1, 0, self.split_len1),
            x.narrow(1, self.split_len1, self.split_len2),
        )

    def forward(self, x, rev=False):
        """Apply the flow step (``rev=False``) or its inverse (``rev=True``).

        The log-determinant returned by the 1x1 conv is discarded. The coupling
        scale is cached on ``self.s``; nothing in this file reads it back.
        """
        if not rev:
            x, _ = self.flow_permutation(x, logdet=0, rev=False)
            x1, x2 = self._split(x)
            y1 = x1 + self.F(x2)
            self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
            y2 = x2.mul(torch.exp(self.s)) + self.G(y1)
            out = torch.cat((y1, y2), 1)
        else:
            x1, x2 = self._split(x)
            self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
            y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
            y1 = x1 - self.F(y2)
            x = torch.cat((y1, y2), 1)
            out, _ = self.flow_permutation(x, logdet=0, rev=True)
        return out


class InvISPNet(nn.Module):
    """Stack of ``block_num`` InvBlocks forming an invertible image mapping.

    ``forward(x)`` applies the blocks in order; ``forward(x, rev=True)``
    applies their inverses in reverse order.

    Args:
        channel_in: Number of image channels; it is also the channel count
            used by every InvBlock.
        channel_out: Accepted for API compatibility but unused. The stacked
            InvBlocks keep the channel count equal to ``channel_in``.
        subnet_constructor: Factory for the coupling subnets (see ``subnet``).
        block_num: Number of flow steps.
    """

    def __init__(
        self,
        channel_in=3,
        channel_out=3,
        subnet_constructor=subnet("DBNet"),
        block_num=8,
    ):
        super(InvISPNet, self).__init__()
        channel_split_num = 1
        # One InvBlock is one flow step.
        self.operations = nn.ModuleList(
            InvBlock(subnet_constructor, channel_in, channel_split_num)
            for _ in range(block_num)
        )
        self.initialize()

    def initialize(self):
        """Xavier-normal init (scale 1.0, zero bias) for every Conv2d/Linear; BN to identity.

        NOTE: this runs after the InvBlocks are built, so it overrides
        DenseBlock's own initialisation (its 0.1 scaling and zeroed conv5).
        """
        initialize_weights_xavier(self, 1.0)

    def forward(self, x, rev=False):
        """Map ``x`` through all blocks, or through their inverses if ``rev``.

        ``x`` is assumed to be (N, C, H, W) with C == channel_in.
        """
        # Call the modules rather than ``.forward`` so registered hooks run.
        blocks = reversed(self.operations) if rev else self.operations
        out = x
        for op in blocks:
            out = op(out, rev)
        return out