Spaces:
Running
Running
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import torch.nn.init as init | |
from .modules import InvertibleConv1x1 | |
def _init_modules(net_l, scale, weight_init):
    """Initialise every Conv2d / Linear / BatchNorm2d found inside ``net_l``.

    Shared implementation of ``initialize_weights`` and
    ``initialize_weights_xavier``.

    Args:
        net_l: a single ``nn.Module`` or an iterable (list, tuple, ...) of
            modules; every submodule reachable via ``.modules()`` is visited.
        scale: multiplier applied to Conv2d / Linear weights after
            ``weight_init`` (e.g. 0.1 to damp residual branches, 0 to zero them).
        weight_init: in-place initialiser called on each Conv2d / Linear weight.
    """
    nets = [net_l] if isinstance(net_l, nn.Module) else net_l
    for net in nets:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                weight_init(m.weight)
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)


def initialize_weights(net_l, scale=1):
    """Kaiming-normal (fan_in) initialise Conv2d/Linear weights, scaled by ``scale``.

    Biases are zeroed; BatchNorm2d layers are reset to weight 1, bias 0.

    Args:
        net_l: a module or an iterable of modules.
        scale: multiplier applied to the freshly initialised weights.
    """
    _init_modules(
        net_l, scale, lambda w: init.kaiming_normal_(w, a=0, mode="fan_in")
    )


def initialize_weights_xavier(net_l, scale=1):
    """Xavier-normal initialise Conv2d/Linear weights, scaled by ``scale``.

    Biases are zeroed; BatchNorm2d layers are reset to weight 1, bias 0.

    Args:
        net_l: a module or an iterable of modules.
        scale: multiplier applied to the freshly initialised weights.
    """
    _init_modules(net_l, scale, init.xavier_normal_)
class DenseBlock(nn.Module):
    """Five-layer densely connected 3x3 conv block (used as a coupling subnet).

    Each conv receives the channel-wise concatenation of the block input and
    every earlier activation; the fifth conv maps that stack to
    ``channel_out`` channels without an activation.

    Args:
        channel_in: channels of the input tensor.
        channel_out: channels produced by the final conv.
        init: ``"xavier"`` selects Xavier init for conv1-conv4; any other value
            selects Kaiming init. (Shadows the ``torch.nn.init`` module name,
            but only inside ``__init__``.)
        gc: growth channels emitted by each of conv1-conv4.
        bias: whether the convs carry a bias term.
    """

    def __init__(self, channel_in, channel_out, init="xavier", gc=32, bias=True):
        super(DenseBlock, self).__init__()
        # Conv k sees the input plus the k-1 growth maps produced before it.
        self.conv1 = nn.Conv2d(channel_in, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv2 = nn.Conv2d(channel_in + gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(
            channel_in + 4 * gc, channel_out, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        hidden_convs = [self.conv1, self.conv2, self.conv3, self.conv4]
        init_hidden = (
            initialize_weights_xavier if init == "xavier" else initialize_weights
        )
        # Damp the hidden convs and zero the output conv, so a freshly built
        # block outputs zeros (unless a caller re-initialises it afterwards).
        init_hidden(hidden_convs, 0.1)
        initialize_weights(self.conv5, 0)

    def forward(self, x):
        first = self.lrelu(self.conv1(x))
        features = [x, first]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        return self.conv5(torch.cat(features, 1))
def subnet(net_structure, init="xavier"):
    """Return a factory that builds the coupling subnets used by InvBlock.

    Args:
        net_structure: subnet architecture; only ``"DBNet"`` (DenseBlock) is
            supported.
        init: weight-init scheme forwarded to DenseBlock (``"xavier"``, or any
            other value for Kaiming). Shadows ``torch.nn.init`` locally.

    Returns:
        ``constructor(channel_in, channel_out)``. It returns a DenseBlock for
        ``"DBNet"`` and ``None`` for any other ``net_structure``.
    """

    def constructor(channel_in, channel_out):
        if net_structure == "DBNet":
            # Always forward ``init``. Previously non-xavier schemes were
            # dropped here, so DenseBlock silently fell back to xavier.
            return DenseBlock(channel_in, channel_out, init)
        # return UNetBlock(channel_in, channel_out)
        return None

    return constructor
class InvBlock(nn.Module):
    """One flow step: an invertible 1x1 conv followed by an affine coupling.

    After the 1x1 conv, the channels are split into ``x1`` (the first
    ``channel_split_num`` channels) and ``x2`` (the remaining ones).

    Forward pass::

        y1 = x1 + F(x2)
        s  = clamp * (2 * sigmoid(H(y1)) - 1)
        y2 = x2 * exp(s) + G(y1)

    ``rev=True`` undoes these equations in reverse order and then applies the
    inverse 1x1 conv.

    Args:
        subnet_constructor: callable ``(channels_in, channels_out) -> nn.Module``
            used to build the F, G and H subnets.
        channel_num: total number of input channels.
        channel_split_num: number of channels in the first split ``x1``.
        clamp: bounds the log-scale; for clamp > 0, ``s`` lies in
            ``(-clamp, clamp)``.
    """

    def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8):
        super(InvBlock, self).__init__()
        self.split_len1 = channel_split_num
        self.split_len2 = channel_num - channel_split_num
        self.clamp = clamp

        self.F = subnet_constructor(self.split_len2, self.split_len1)
        self.G = subnet_constructor(self.split_len1, self.split_len2)
        self.H = subnet_constructor(self.split_len1, self.split_len2)

        # The 1x1 conv mixes all input channels, so it must be sized by
        # channel_num. (It was previously hard-coded to 3.)
        self.invconv = InvertibleConv1x1(channel_num, LU_decomposed=True)

    def flow_permutation(self, z, logdet, rev):
        """Apply (``rev=False``) or invert (``rev=True``) the 1x1 channel mix.

        Delegates to ``self.invconv``; callers unpack the result as
        ``(z, logdet)``. This is a method rather than a stored lambda so the
        module stays picklable.
        """
        return self.invconv(z, logdet, rev)

    def forward(self, x, rev=False):
        if not rev:
            x, _ = self.flow_permutation(x, logdet=0, rev=False)
            x1 = x.narrow(1, 0, self.split_len1)
            x2 = x.narrow(1, self.split_len1, self.split_len2)

            y1 = x1 + self.F(x2)
            # The log-scale is kept on the module as ``self.s``, presumably
            # for inspection by callers; verify before removing.
            self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
            y2 = x2.mul(torch.exp(self.s)) + self.G(y1)
            return torch.cat((y1, y2), 1)

        x1 = x.narrow(1, 0, self.split_len1)
        x2 = x.narrow(1, self.split_len1, self.split_len2)
        # Invert the coupling: recompute s from y1 (= x1 here), then undo
        # y2 and y1 in reverse order.
        self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
        y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
        y1 = x1 - self.F(y2)
        out, _ = self.flow_permutation(torch.cat((y1, y2), 1), logdet=0, rev=True)
        return out
class InvISPNet(nn.Module):
    """Invertible network made of ``block_num`` stacked InvBlock flow steps.

    ``forward(x)`` runs the steps in order. ``forward(x, rev=True)`` runs them
    in reverse order, each with ``rev=True``.

    Args:
        channel_in: number of input channels; every block is built with it.
        channel_out: accepted for interface compatibility but not used by
            this class.
        subnet_constructor: factory for the coupling subnets of every block.
        block_num: number of InvBlock flow steps.
    """

    def __init__(
        self,
        channel_in=3,
        channel_out=3,
        subnet_constructor=subnet("DBNet"),
        block_num=8,
    ):
        super(InvISPNet, self).__init__()
        channel_split_num = 1
        # One InvBlock per flow step.
        self.operations = nn.ModuleList(
            InvBlock(subnet_constructor, channel_in, channel_split_num)
            for _ in range(block_num)
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialise every Conv2d/Linear and reset every BatchNorm2d.

        Biases are zeroed; BatchNorm2d weight/bias become 1/0 when present.

        NOTE: this walks ``self.modules()``, so it also re-initialises the convs
        inside the coupling subnets (e.g. the DenseBlock convs built by the
        default subnet constructor). That overrides the 0.1-scaled and
        zero-output initialisation those subnets apply to themselves.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

    def forward(self, x, rev=False):
        steps = reversed(self.operations) if rev else self.operations
        out = x
        for op in steps:
            # Call the module rather than ``.forward`` so that registered
            # forward hooks and pre-hooks run.
            out = op(out, rev)
        return out