from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CSDN_Tem(nn.Module):
    """Depthwise-separable 2-D convolution block.

    A 3x3 depthwise convolution (one filter per input channel; padding 1 keeps
    the spatial size unchanged) followed by a 1x1 pointwise convolution that
    maps ``in_ch`` channels to ``out_ch`` channels.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super(CSDN_Tem, self).__init__()
        self.depth_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=3,
            padding=1,
            groups=in_ch,  # groups == channels -> per-channel (depthwise) filters
        )
        self.point_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise convolution, then the pointwise convolution."""
        out = self.depth_conv(input)
        out = self.point_conv(out)
        return out


class enhance_net_nopool(nn.Module):
    """Curve-estimation network that enhances an image by iterative curve mapping.

    A stack of seven ``CSDN_Tem`` layers (32 feature channels) predicts a
    3-channel curve map ``x_r`` in [-1, 1] (``tanh`` output). Layers 5-7 use
    symmetric skip connections: they consume the channel-wise concatenation of
    (x3, x4), (x2, x5) and (x1, x6) respectively. The curve map is then applied
    to the input image by ``enhance``.

    When ``scale_factor != 1`` the network runs on a bilinearly downsampled copy
    of the input (size scaled by ``1 / scale_factor``). The predicted curve map
    is resized back to the input's exact spatial size before being applied.

    Args:
        scale_factor: Downsampling factor for the curve-estimation branch;
            ``1`` disables resampling.
    """

    def __init__(self, scale_factor):
        super(enhance_net_nopool, self).__init__()

        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Retained for backward compatibility (external code may reference it).
        # forward() resizes to the exact input size instead: a fixed scale
        # factor yields a mismatched size when H/W are not divisible by it.
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32

        # zerodce DWC + p-shared
        self.e_conv1 = CSDN_Tem(3, number_f)
        self.e_conv2 = CSDN_Tem(number_f, number_f)
        self.e_conv3 = CSDN_Tem(number_f, number_f)
        self.e_conv4 = CSDN_Tem(number_f, number_f)
        # Layers 5-7 take a concatenation of two feature maps, hence 2x inputs.
        self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv7 = CSDN_Tem(number_f * 2, 3)

    def enhance(
        self,
        x: torch.Tensor,
        x_r: torch.Tensor,
        num_iterations: int = 8,
    ) -> torch.Tensor:
        """Iteratively apply the curve ``x <- x + x_r * (x**2 - x)``.

        Args:
            x: Image tensor; presumably values in [0, 1] (TODO confirm with
                callers). Must broadcast with ``x_r``.
            x_r: Per-pixel curve parameters (the ``tanh`` output of ``forward``).
            num_iterations: Number of curve applications (default 8, the
                original hard-coded value).

        Returns:
            The enhanced tensor. With ``num_iterations=0``, ``x`` is returned
            unchanged.
        """
        for _ in range(num_iterations):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate a curve map for ``x`` and apply it.

        Args:
            x: Batched 3-channel image tensor of shape (N, 3, H, W).

        Returns:
            ``(enhance_image, x_r)``: the enhanced image and the curve map,
            both at the input's spatial resolution (H, W).
        """
        if self.scale_factor == 1:
            x_down = x
        else:
            # align_corners=False is the implicit default; spelled out here.
            x_down = F.interpolate(
                x,
                scale_factor=1 / self.scale_factor,
                mode='bilinear',
                align_corners=False,
            )

        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        if self.scale_factor != 1:
            # Resize to the exact input size rather than multiplying by
            # scale_factor. The flooring downsample followed by a fixed-factor
            # upsample gives e.g. 7 -> 3 -> 6, which breaks enhance() below.
            # align_corners=True matches nn.UpsamplingBilinear2d, so results are
            # identical whenever the sizes divide evenly.
            x_r = F.interpolate(
                x_r, size=x.shape[-2:], mode='bilinear', align_corners=True
            )

        enhance_image = self.enhance(x, x_r)
        return enhance_image, x_r