English
XavierJiezou's picture
Add files using upload-large-folder tool
63a9590 verified
raw
history blame
15.4 kB
# -*- coding: utf-8 -*-
# @Time : 2024/7/21 下午3:51
# @Author : xiaoshun
# @Email : 3038523973@qq.com
# @File : mcdnet.py
# @Software: PyCharm
import image_dehazer
import numpy as np
# 论文地址:https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub
import torch
import torch.nn as nn
import torch.nn.functional as F
class _DPFF(nn.Module):
def __init__(self, in_channels) -> None:
super(_DPFF, self).__init__()
self.cbr1 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
self.cbr2 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
# self.sigmoid = nn.Sigmoid()
self.cbr3 = nn.Conv2d(in_channels, in_channels, 1, 1, bias=False)
self.cbr4 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
def forward(self, feature1, feature2):
d1 = torch.abs(feature1 - feature2)
d2 = self.cbr1(torch.cat([feature1, feature2], dim=1))
d = torch.cat([d1, d2], dim=1)
d = self.cbr2(d)
# d = self.sigmoid(d)
v1, v2 = self.cbr3(feature1), self.cbr3(feature2)
v1, v2 = v1 * d, v2 * d
features = torch.cat([v1, v2], dim=1)
features = self.cbr4(features)
return features
class DPFF(nn.Module):
def __init__(self, layer_channels) -> None:
super(DPFF, self).__init__()
self.cfes = nn.ModuleList()
for layer_channel in layer_channels:
self.cfes.append(_DPFF(layer_channel))
def forward(self, features1, features2):
outputs = []
for feature1, feature2, cfe in zip(features1, features2, self.cfes):
outputs.append(cfe(feature1, feature2))
return outputs
class DirectDPFF(nn.Module):
def __init__(self, layer_channels) -> None:
super(DirectDPFF, self).__init__()
self.fusions = nn.ModuleList(
[nn.Conv2d(layer_channel * 2, layer_channel, 1, 1) for layer_channel in layer_channels]
)
def forward(self, features1, features2):
outputs = []
for feature1, feature2, fusion in zip(features1, features2, self.fusions):
feature = torch.cat([feature1, feature2], dim=1)
outputs.append(fusion(feature))
return outputs
class ConvBlock(nn.Module):
def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
bn=False, activation=True, maxpool=True):
super(ConvBlock, self).__init__()
self.module = []
if maxpool:
down = nn.Sequential(
*[
nn.MaxPool2d(2),
nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
]
)
else:
down = nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
self.module.append(down)
if bn:
self.module.append(nn.BatchNorm2d(output_size))
if activation:
self.module.append(nn.PReLU())
self.module = nn.Sequential(*self.module)
def forward(self, x):
out = self.module(x)
return out
class DeconvBlock(nn.Module):
def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
bn=False, activation=True, bilinear=True):
super(DeconvBlock, self).__init__()
self.module = []
if bilinear:
deconv = nn.Sequential(
*[
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
]
)
else:
deconv = nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
self.module.append(deconv)
if bn:
self.module.append(nn.BatchNorm2d(output_size))
if activation:
self.module.append(nn.PReLU())
self.module = nn.Sequential(*self.module)
def forward(self, x):
out = self.module(x)
return out
class FusionBlock(torch.nn.Module):
def __init__(self, num_filter, num_ft, kernel_size=4, stride=2, padding=1, bias=True, maxpool=False,
bilinear=False):
super(FusionBlock, self).__init__()
self.num_ft = num_ft
self.up_convs = nn.ModuleList()
self.down_convs = nn.ModuleList()
for i in range(self.num_ft):
self.up_convs.append(
DeconvBlock(num_filter // (2 ** i), num_filter // (2 ** (i + 1)), kernel_size, stride, padding,
bias=bias, bilinear=bilinear)
)
self.down_convs.append(
ConvBlock(num_filter // (2 ** (i + 1)), num_filter // (2 ** i), kernel_size, stride, padding, bias=bias,
maxpool=maxpool)
)
def forward(self, ft_l, ft_h_list):
ft_fusion = ft_l
for i in range(len(ft_h_list)):
ft = ft_fusion
for j in range(self.num_ft - i):
ft = self.up_convs[j](ft)
ft = ft - ft_h_list[i]
for j in range(self.num_ft - i):
ft = self.down_convs[self.num_ft - i - j - 1](ft)
ft_fusion = ft_fusion + ft
return ft_fusion
class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class UpsampleConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(UpsampleConvLayer, self).__init__()
self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
def forward(self, x):
out = self.conv2d(x)
return out
class AddRelu(nn.Module):
"""It is for adding two feed forwards to the output of the two following conv layers in expanding path
"""
def __init__(self) -> None:
super(AddRelu, self).__init__()
self.relu = nn.PReLU()
def forward(self, input_tensor1, input_tensor2, input_tensor3):
x = input_tensor1 + input_tensor2 + input_tensor3
return self.relu(x)
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
super(BasicBlock, self).__init__()
if not mid_channels:
mid_channels = out_channels
self.conv1 = ConvLayer(in_channels, mid_channels, kernel_size=3, stride=1)
self.bn1 = nn.BatchNorm2d(mid_channels, momentum=0.1)
self.relu = nn.PReLU()
self.conv2 = ConvLayer(mid_channels, out_channels, kernel_size=3, stride=1)
self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
self.conv3 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
residual = self.conv3(x)
out = out + residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels):
super(Bottleneck, self).__init__()
self.conv1 = ConvLayer(in_channels, out_channels, kernel_size=3, stride=1)
self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)
self.conv2 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
self.conv3 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)
self.conv4 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
self.relu = nn.PReLU()
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
residual = self.conv4(x)
out = out + residual
out = self.relu(out)
return out
class PPM(nn.Module):
def __init__(self, in_channels, out_channels):
super(PPM, self).__init__()
self.pool_sizes = [1, 2, 3, 6] # subregion size in each level
self.num_levels = len(self.pool_sizes) # number of pyramid levels
self.conv_layers = nn.ModuleList()
for i in range(self.num_levels):
self.conv_layers.append(nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=self.pool_sizes[i]),
nn.Conv2d(in_channels, in_channels // self.num_levels, kernel_size=1),
nn.BatchNorm2d(in_channels // self.num_levels),
nn.ReLU(inplace=True)
))
self.out_conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, stride=1)
def forward(self, x):
input_size = x.size()[2:] # get input size
output = [x]
# pyramid pooling
for i in range(self.num_levels):
out = self.conv_layers[i](x)
out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
output.append(out)
# concatenate features from different levels
output = torch.cat(output, dim=1)
output = self.out_conv(output)
return output
class MCDNet(nn.Module):
def __init__(self, in_channels=4, num_classes=4, maxpool=False, bilinear=False) -> None:
super().__init__()
level = 1
# encoder
self.conv_input = ConvLayer(in_channels, 32 * level, kernel_size=3, stride=2)
self.dense0 = BasicBlock(32 * level, 32 * level)
self.conv2x = ConvLayer(32 * level, 64 * level, kernel_size=3, stride=2)
self.dense1 = BasicBlock(64 * level, 64 * level)
self.conv4x = ConvLayer(64 * level, 128 * level, kernel_size=3, stride=2)
self.dense2 = BasicBlock(128 * level, 128 * level)
self.conv8x = ConvLayer(128 * level, 256 * level, kernel_size=3, stride=2)
self.dense3 = BasicBlock(256 * level, 256 * level)
self.conv16x = ConvLayer(256 * level, 512 * level, kernel_size=3, stride=2)
self.dense4 = PPM(512 * level, 512 * level)
# dpff
self.dpffm = DPFF([32, 64, 128, 256, 512])
# decoder
self.convd16x = UpsampleConvLayer(512 * level, 256 * level, kernel_size=3, stride=2)
self.fusion4 = FusionBlock(256 * level, 3, maxpool=maxpool, bilinear=bilinear)
self.dense_4 = Bottleneck(512 * level, 256 * level)
self.add_block4 = AddRelu()
self.convd8x = UpsampleConvLayer(256 * level, 128 * level, kernel_size=3, stride=2)
self.fusion3 = FusionBlock(128 * level, 2, maxpool=maxpool, bilinear=bilinear)
self.dense_3 = Bottleneck(256 * level, 128 * level)
self.add_block3 = AddRelu()
self.convd4x = UpsampleConvLayer(128 * level, 64 * level, kernel_size=3, stride=2)
self.fusion2 = FusionBlock(64 * level, 1, maxpool=maxpool, bilinear=bilinear)
self.dense_2 = Bottleneck(128 * level, 64 * level)
self.add_block2 = AddRelu()
self.convd2x = UpsampleConvLayer(64 * level, 32 * level, kernel_size=3, stride=2)
self.dense_1 = Bottleneck(64 * level, 32 * level)
self.add_block1 = AddRelu()
self.head = UpsampleConvLayer(32 * level, num_classes, kernel_size=3, stride=2)
self.apply(self._weights_init)
@torch.no_grad()
def get_lr_data(self, x: torch.Tensor) -> torch.Tensor:
images = x.cpu().permute(0, 2, 3, 1).numpy() # b, h, w, c
batch_size = images.shape[0]
lr = []
for i in range(batch_size):
lr_image = image_dehazer.remove_haze((images[i]*255).astype(np.uint8), showHazeTransmissionMap=False)[0] # h, w, c, numpy.array
lr_tensor = torch.from_numpy(lr_image).permute(2, 0, 1)/255. # c, h, w
lr.append(lr_tensor)
return torch.stack(lr, dim=0).to(x.device) # b, c, h, w
def _weights_init(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x1):
x2 = self.get_lr_data(x1)
# encoder1
res1x_1 = self.conv_input(x1)
res1x_1 = self.dense0(res1x_1)
res2x_1 = self.conv2x(res1x_1)
res2x_1 = self.dense1(res2x_1)
res4x_1 = self.conv4x(res2x_1)
res4x_1 = self.dense2(res4x_1)
res8x_1 = self.conv8x(res4x_1)
res8x_1 = self.dense3(res8x_1)
res16x_1 = self.conv16x(res8x_1)
res16x_1 = self.dense4(res16x_1)
# encoder2
res1x_2 = self.conv_input(x2)
res1x_2 = self.dense0(res1x_2)
res2x_2 = self.conv2x(res1x_2)
res2x_2 = self.dense1(res2x_2)
res4x_2 = self.conv4x(res2x_2)
res4x_2 = self.dense2(res4x_2)
res8x_2 = self.conv8x(res4x_2)
res8x_2 = self.dense3(res8x_2)
res16x_2 = self.conv16x(res8x_2)
res16x_2 = self.dense4(res16x_2)
# dual-perspective feature fusion
res1x, res2x, res4x, res8x, res16x = self.dpffm(
[res1x_1, res2x_1, res4x_1, res8x_1, res16x_1],
[res1x_2, res2x_2, res4x_2, res8x_2, res16x_2]
)
# decoder
res8x1 = self.convd16x(res16x)
res8x1 = F.interpolate(res8x1, res8x.size()[2:], mode='bilinear')
res8x2 = self.fusion4(res8x, [res1x, res2x, res4x])
res8x2 = torch.cat([res8x1, res8x2], dim=1)
res8x2 = self.dense_4(res8x2)
res8x2 = self.add_block4(res8x1, res8x, res8x2)
res4x1 = self.convd8x(res8x2)
res4x1 = F.interpolate(res4x1, res4x.size()[2:], mode='bilinear')
res4x2 = self.fusion3(res4x, [res1x, res2x])
res4x2 = torch.cat([res4x1, res4x2], dim=1)
res4x2 = self.dense_3(res4x2)
res4x2 = self.add_block3(res4x1, res4x, res4x2)
res2x1 = self.convd4x(res4x2)
res2x1 = F.interpolate(res2x1, res2x.size()[2:], mode='bilinear')
res2x2 = self.fusion2(res2x, [res1x])
res2x2 = torch.cat([res2x1, res2x2], dim=1)
res2x2 = self.dense_2(res2x2)
res2x2 = self.add_block2(res2x1, res2x, res2x2)
res1x1 = self.convd2x(res2x2)
res1x1 = F.interpolate(res1x1, res1x.size()[2:], mode='bilinear')
res1x2 = torch.cat([res1x1, res1x], dim=1)
res1x2 = self.dense_1(res1x2)
res1x2 = self.add_block1(res1x1, res1x, res1x2)
out = self.head(res1x2)
out = F.interpolate(out, x1.size()[2:], mode='bilinear')
return out
if __name__ == "__main__":
num_classes = 2
model = MCDNet()
# inp = torch.randn(size=(2, 3, 256, 256))
# assert model(input).shape == (2, 2, 256, 256)