import image_dehazer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
class _DPFF(nn.Module):
    def __init__(self, in_channels) -> None:
        super(_DPFF, self).__init__()
        # 1x1 fusion convolutions (no batch norm or activation despite the 'cbr' names).
        self.cbr1 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
        self.cbr2 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
        self.cbr3 = nn.Conv2d(in_channels, in_channels, 1, 1, bias=False)
        self.cbr4 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)

    def forward(self, feature1, feature2):
        # d combines the absolute difference of the two inputs with a learned
        # mixture of them, and is then used to reweight both feature maps.
        d1 = torch.abs(feature1 - feature2)
        d2 = self.cbr1(torch.cat([feature1, feature2], dim=1))
        d = torch.cat([d1, d2], dim=1)
        d = self.cbr2(d)

        v1, v2 = self.cbr3(feature1), self.cbr3(feature2)
        v1, v2 = v1 * d, v2 * d
        features = torch.cat([v1, v2], dim=1)
        features = self.cbr4(features)

        return features
|
|
class DPFF(nn.Module):
    def __init__(self, layer_channels) -> None:
        super(DPFF, self).__init__()
        # One _DPFF fusion cell per pyramid level.
        self.cfes = nn.ModuleList()
        for layer_channel in layer_channels:
            self.cfes.append(_DPFF(layer_channel))

    def forward(self, features1, features2):
        outputs = []
        for feature1, feature2, cfe in zip(features1, features2, self.cfes):
            outputs.append(cfe(feature1, feature2))
        return outputs
|
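# Usage sketch for DPFF (illustrative only; the channel counts and tensor shapes
# below are assumptions, not values fixed by the module):
#   dpff = DPFF([32, 64])
#   feats_a = [torch.rand(1, 32, 64, 64), torch.rand(1, 64, 32, 32)]
#   feats_b = [torch.rand(1, 32, 64, 64), torch.rand(1, 64, 32, 32)]
#   fused = dpff(feats_a, feats_b)  # list of two tensors with 32 and 64 channels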
|
class DirectDPFF(nn.Module):
    def __init__(self, layer_channels) -> None:
        super(DirectDPFF, self).__init__()
        # Simpler fusion variant: concatenate the two features at each level and
        # merge them with a single 1x1 convolution.
        self.fusions = nn.ModuleList(
            [nn.Conv2d(layer_channel * 2, layer_channel, 1, 1) for layer_channel in layer_channels]
        )

    def forward(self, features1, features2):
        outputs = []
        for feature1, feature2, fusion in zip(features1, features2, self.fusions):
            feature = torch.cat([feature1, feature2], dim=1)
            outputs.append(fusion(feature))
        return outputs
|
|
class ConvBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 bn=False, activation=True, maxpool=True):
        super(ConvBlock, self).__init__()
        self.module = []
        if maxpool:
            down = nn.Sequential(
                nn.MaxPool2d(2),
                nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
            )
        else:
            down = nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
        self.module.append(down)
        if bn:
            self.module.append(nn.BatchNorm2d(output_size))
        if activation:
            self.module.append(nn.PReLU())
        self.module = nn.Sequential(*self.module)

    def forward(self, x):
        out = self.module(x)
        return out
|
|
class DeconvBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 bn=False, activation=True, bilinear=True):
        super(DeconvBlock, self).__init__()
        self.module = []
        if bilinear:
            deconv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
            )
        else:
            deconv = nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
        self.module.append(deconv)
        if bn:
            self.module.append(nn.BatchNorm2d(output_size))
        if activation:
            self.module.append(nn.PReLU())
        self.module = nn.Sequential(*self.module)

    def forward(self, x):
        out = self.module(x)
        return out
|
|
class FusionBlock(torch.nn.Module):
    def __init__(self, num_filter, num_ft, kernel_size=4, stride=2, padding=1, bias=True, maxpool=False,
                 bilinear=False):
        super(FusionBlock, self).__init__()
        self.num_ft = num_ft
        self.up_convs = nn.ModuleList()
        self.down_convs = nn.ModuleList()
        for i in range(self.num_ft):
            self.up_convs.append(
                DeconvBlock(num_filter // (2 ** i), num_filter // (2 ** (i + 1)), kernel_size, stride, padding,
                            bias=bias, bilinear=bilinear)
            )
            self.down_convs.append(
                ConvBlock(num_filter // (2 ** (i + 1)), num_filter // (2 ** i), kernel_size, stride, padding,
                          bias=bias, maxpool=maxpool)
            )

    def forward(self, ft_l, ft_h_list):
        # Back-projection style fusion: repeatedly upsample the low-resolution
        # feature, subtract the matching high-resolution feature, project the
        # residual back down, and accumulate it onto the low-resolution feature.
        ft_fusion = ft_l
        for i in range(len(ft_h_list)):
            ft = ft_fusion
            for j in range(self.num_ft - i):
                ft = self.up_convs[j](ft)
            ft = ft - ft_h_list[i]
            for j in range(self.num_ft - i):
                ft = self.down_convs[self.num_ft - i - j - 1](ft)
            ft_fusion = ft_fusion + ft

        return ft_fusion
|
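# Usage sketch for FusionBlock (assumed shapes; the high-resolution features must
# be ordered finest first and differ from ft_l by exact powers of two in size):
#   fusion = FusionBlock(num_filter=256, num_ft=3)
#   low = torch.rand(1, 256, 8, 8)
#   highs = [torch.rand(1, 32, 64, 64), torch.rand(1, 64, 32, 32), torch.rand(1, 128, 16, 16)]
#   out = fusion(low, highs)  # same shape as `low`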
|
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out
|
|
class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(UpsampleConvLayer, self).__init__()
        self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.conv2d(x)
        return out
|
|
class AddRelu(nn.Module):
    """Sums three tensors (the upsampled feature and two feed-forward/skip
    connections in the expanding path) and applies a PReLU activation.
    """

    def __init__(self) -> None:
        super(AddRelu, self).__init__()
        self.relu = nn.PReLU()

    def forward(self, input_tensor1, input_tensor2, input_tensor3):
        x = input_tensor1 + input_tensor2 + input_tensor3
        return self.relu(x)
|
|
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(BasicBlock, self).__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.conv1 = ConvLayer(in_channels, mid_channels, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(mid_channels, momentum=0.1)
        self.relu = nn.PReLU()

        self.conv2 = ConvLayer(mid_channels, out_channels, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)

        self.conv3 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        residual = self.conv3(x)

        out = out + residual
        out = self.relu(out)

        return out
|
|
class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Bottleneck, self).__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)

        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)

        self.conv3 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)

        self.conv4 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)

        self.relu = nn.PReLU()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.conv4(x)

        out = out + residual
        out = self.relu(out)

        return out
|
|
class PPM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PPM, self).__init__()

        self.pool_sizes = [1, 2, 3, 6]
        self.num_levels = len(self.pool_sizes)

        self.conv_layers = nn.ModuleList()
        for i in range(self.num_levels):
            self.conv_layers.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=self.pool_sizes[i]),
                nn.Conv2d(in_channels, in_channels // self.num_levels, kernel_size=1),
                nn.BatchNorm2d(in_channels // self.num_levels),
                nn.ReLU(inplace=True)
            ))
        # Input (in_channels) + four pooled branches (in_channels // 4 each) = 2 * in_channels.
        self.out_conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        input_size = x.size()[2:]
        output = [x]

        # Pool at each pyramid scale, then upsample back to the input resolution.
        for i in range(self.num_levels):
            out = self.conv_layers[i](x)
            out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
            output.append(out)

        output = torch.cat(output, dim=1)
        output = self.out_conv(output)

        return output
|
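# Usage sketch for PPM (illustrative shapes; in_channels should be divisible by 4
# so the four pooled branches concatenate back to 2 * in_channels):
#   ppm = PPM(512, 512)
#   y = ppm(torch.rand(1, 512, 8, 8))  # -> (1, 512, 8, 8)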
|
class MCDNet(nn.Module):
    """Dual-input encoder-decoder for per-pixel classification: the raw image and
    a dehazed copy (generated on the fly with ``image_dehazer``) are encoded by
    shared layers, fused level by level with DPFF, and decoded with
    FusionBlock/Bottleneck stages.
    """

    def __init__(self, in_channels=4, num_classes=4, maxpool=False, bilinear=False) -> None:
        super().__init__()
        level = 1

        self.conv_input = ConvLayer(in_channels, 32 * level, kernel_size=3, stride=2)

        self.dense0 = BasicBlock(32 * level, 32 * level)
        self.conv2x = ConvLayer(32 * level, 64 * level, kernel_size=3, stride=2)

        self.dense1 = BasicBlock(64 * level, 64 * level)
        self.conv4x = ConvLayer(64 * level, 128 * level, kernel_size=3, stride=2)

        self.dense2 = BasicBlock(128 * level, 128 * level)
        self.conv8x = ConvLayer(128 * level, 256 * level, kernel_size=3, stride=2)

        self.dense3 = BasicBlock(256 * level, 256 * level)
        self.conv16x = ConvLayer(256 * level, 512 * level, kernel_size=3, stride=2)

        self.dense4 = PPM(512 * level, 512 * level)

        self.dpffm = DPFF([32, 64, 128, 256, 512])

        self.convd16x = UpsampleConvLayer(512 * level, 256 * level, kernel_size=3, stride=2)
        self.fusion4 = FusionBlock(256 * level, 3, maxpool=maxpool, bilinear=bilinear)
        self.dense_4 = Bottleneck(512 * level, 256 * level)
        self.add_block4 = AddRelu()

        self.convd8x = UpsampleConvLayer(256 * level, 128 * level, kernel_size=3, stride=2)
        self.fusion3 = FusionBlock(128 * level, 2, maxpool=maxpool, bilinear=bilinear)
        self.dense_3 = Bottleneck(256 * level, 128 * level)
        self.add_block3 = AddRelu()

        self.convd4x = UpsampleConvLayer(128 * level, 64 * level, kernel_size=3, stride=2)
        self.fusion2 = FusionBlock(64 * level, 1, maxpool=maxpool, bilinear=bilinear)
        self.dense_2 = Bottleneck(128 * level, 64 * level)
        self.add_block2 = AddRelu()

        self.convd2x = UpsampleConvLayer(64 * level, 32 * level, kernel_size=3, stride=2)
        self.dense_1 = Bottleneck(64 * level, 32 * level)
        self.add_block1 = AddRelu()

        self.head = UpsampleConvLayer(32 * level, num_classes, kernel_size=3, stride=2)
        self.apply(self._weights_init)
|
    @torch.no_grad()
    def get_lr_data(self, x: torch.Tensor) -> torch.Tensor:
        # Dehaze each image in the batch with image_dehazer (runs on the CPU) and
        # return the result as a float tensor on the same device as the input.
        images = x.cpu().permute(0, 2, 3, 1).numpy()
        batch_size = images.shape[0]
        lr = []
        for i in range(batch_size):
            lr_image = image_dehazer.remove_haze((images[i] * 255).astype(np.uint8), showHazeTransmissionMap=False)[0]
            lr_tensor = torch.from_numpy(lr_image).permute(2, 0, 1) / 255.
            lr.append(lr_tensor)
        return torch.stack(lr, dim=0).to(x.device)
|
    def _weights_init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
|
    def forward(self, x1):
        # Second input: a dehazed copy of x1, computed without gradients.
        x2 = self.get_lr_data(x1)

        # Encoder pass over the raw input.
        res1x_1 = self.conv_input(x1)
        res1x_1 = self.dense0(res1x_1)

        res2x_1 = self.conv2x(res1x_1)
        res2x_1 = self.dense1(res2x_1)

        res4x_1 = self.conv4x(res2x_1)
        res4x_1 = self.dense2(res4x_1)

        res8x_1 = self.conv8x(res4x_1)
        res8x_1 = self.dense3(res8x_1)

        res16x_1 = self.conv16x(res8x_1)
        res16x_1 = self.dense4(res16x_1)

        # Encoder pass over the dehazed input (the encoder weights are shared).
        res1x_2 = self.conv_input(x2)
        res1x_2 = self.dense0(res1x_2)

        res2x_2 = self.conv2x(res1x_2)
        res2x_2 = self.dense1(res2x_2)

        res4x_2 = self.conv4x(res2x_2)
        res4x_2 = self.dense2(res4x_2)

        res8x_2 = self.conv8x(res4x_2)
        res8x_2 = self.dense3(res8x_2)

        res16x_2 = self.conv16x(res8x_2)
        res16x_2 = self.dense4(res16x_2)

        # Fuse the two feature pyramids level by level.
        res1x, res2x, res4x, res8x, res16x = self.dpffm(
            [res1x_1, res2x_1, res4x_1, res8x_1, res16x_1],
            [res1x_2, res2x_2, res4x_2, res8x_2, res16x_2]
        )

        # Decoder: upsample, fuse multi-scale features, and refine at each level.
        res8x1 = self.convd16x(res16x)
        res8x1 = F.interpolate(res8x1, res8x.size()[2:], mode='bilinear')
        res8x2 = self.fusion4(res8x, [res1x, res2x, res4x])
        res8x2 = torch.cat([res8x1, res8x2], dim=1)
        res8x2 = self.dense_4(res8x2)
        res8x2 = self.add_block4(res8x1, res8x, res8x2)

        res4x1 = self.convd8x(res8x2)
        res4x1 = F.interpolate(res4x1, res4x.size()[2:], mode='bilinear')
        res4x2 = self.fusion3(res4x, [res1x, res2x])
        res4x2 = torch.cat([res4x1, res4x2], dim=1)
        res4x2 = self.dense_3(res4x2)
        res4x2 = self.add_block3(res4x1, res4x, res4x2)

        res2x1 = self.convd4x(res4x2)
        res2x1 = F.interpolate(res2x1, res2x.size()[2:], mode='bilinear')
        res2x2 = self.fusion2(res2x, [res1x])
        res2x2 = torch.cat([res2x1, res2x2], dim=1)
        res2x2 = self.dense_2(res2x2)
        res2x2 = self.add_block2(res2x1, res2x, res2x2)

        res1x1 = self.convd2x(res2x2)
        res1x1 = F.interpolate(res1x1, res1x.size()[2:], mode='bilinear')
        res1x2 = torch.cat([res1x1, res1x], dim=1)
        res1x2 = self.dense_1(res1x2)
        res1x2 = self.add_block1(res1x1, res1x, res1x2)

        # Project to class logits and restore the input resolution.
        out = self.head(res1x2)
        out = F.interpolate(out, x1.size()[2:], mode='bilinear')

        return out
|
|
if __name__ == "__main__":
    num_classes = 2
    model = MCDNet(num_classes=num_classes)
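    # Minimal smoke test (a sketch, not part of the original script). It assumes a
    # 4-band 256x256 input and that image_dehazer can process the resulting
    # 4-channel arrays; adjust the dummy tensor to match your data.
    dummy = torch.rand(1, 4, 256, 256)
    with torch.no_grad():
        prediction = model(dummy)
    print(prediction.shape)  # expected: (1, num_classes, 256, 256)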
|