Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class CSDN_Tem(nn.Module):
    """Depthwise-separable 2-D convolution block.

    A 3x3 depthwise convolution (``groups=in_ch``, padding 1) is followed by a
    1x1 pointwise convolution that maps ``in_ch`` channels to ``out_ch``.
    With the default stride of 1 the spatial size of the input is preserved.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by the pointwise convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Depthwise stage: one 3x3 filter per input channel (groups == in_ch).
        self.depth_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=3,
            padding=1,
            groups=in_ch,
        )
        # Pointwise stage: 1x1 convolution that mixes channels.
        self.point_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
        )

    def forward(self, input):
        """Apply the depthwise then the pointwise convolution to ``input``.

        The parameter keeps its original name ``input`` so that keyword
        callers continue to work.
        """
        depthwise_out = self.depth_conv(input)
        return self.point_conv(depthwise_out)
class enhance_net_nopool(nn.Module):
    """Curve-estimation network that enhances an image without pooling.

    Seven depthwise-separable convolution layers (``CSDN_Tem``) with skip
    concatenations estimate a 3-channel curve map ``x_r`` (tanh output, so
    values lie in (-1, 1)). The map is estimated on a copy of the input
    downsampled by ``scale_factor`` and resized back to the input resolution,
    then applied to the full-resolution input by :meth:`enhance`.

    Args:
        scale_factor: Factor by which the input is downsampled before curve
            estimation. ``1`` disables resampling.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        # Kept for backward compatibility with code that accesses
        # ``self.upsample``; ``forward`` resizes to the exact input size
        # instead (see the comment there).
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32

        # zerodce DWC + p-shared
        self.e_conv1 = CSDN_Tem(3, number_f)
        self.e_conv2 = CSDN_Tem(number_f, number_f)
        self.e_conv3 = CSDN_Tem(number_f, number_f)
        self.e_conv4 = CSDN_Tem(number_f, number_f)
        # Layers 5-7 consume the concatenation of two feature maps,
        # hence twice the channel count on their input.
        self.e_conv5 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv6 = CSDN_Tem(number_f * 2, number_f)
        self.e_conv7 = CSDN_Tem(number_f * 2, 3)

    def enhance(self, x, x_r):
        """Apply the curve ``x + x_r * (x**2 - x)`` eight times.

        The same curve map ``x_r`` is reused in every iteration.

        Args:
            x: Image tensor to enhance.
            x_r: Curve map, broadcastable against ``x``.

        Returns:
            The enhanced tensor (a new tensor; ``x`` is not modified in place).
        """
        for _ in range(8):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x):
        """Estimate the curve map for ``x`` and return the enhanced image.

        Args:
            x: Input batch with 3 channels (``e_conv1`` expects 3 input
                channels); presumably laid out as (N, 3, H, W).

        Returns:
            Tuple ``(enhance_image, x_r)`` where ``x_r`` is the curve map at
            the spatial resolution of ``x``.
        """
        if self.scale_factor == 1:
            x_down = x
        else:
            x_down = F.interpolate(
                x, scale_factor=1 / self.scale_factor, mode='bilinear'
            )

        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        if self.scale_factor != 1:
            # Resize to the exact spatial size of ``x`` rather than by a fixed
            # factor: when H or W is not a multiple of ``scale_factor`` the
            # downsample rounds down, and a fixed-factor upsample would yield
            # a map smaller than ``x`` and break the arithmetic in enhance().
            # With align_corners=True this equals nn.UpsamplingBilinear2d
            # whenever the sizes do divide evenly.
            x_r = F.interpolate(
                x_r, size=x.shape[-2:], mode='bilinear', align_corners=True
            )

        enhance_image = self.enhance(x, x_r)
        return enhance_image, x_r