import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    # Standard convolution that preserves spatial size for odd kernel sizes.
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    # Fixed 1x1 convolution that subtracts (sign=-1) or adds back (sign=1)
    # the dataset RGB mean, with per-channel normalization by the RGB std.
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the parameters themselves; assigning `self.requires_grad`
        # on the module object has no effect on its weights.
        for p in self.parameters():
            p.requires_grad = False

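# Usage sketch (illustrative, not from the original file): models in the
# EDSR family typically wrap the network body with a pair of MeanShift
# layers, e.g. with DIV2K statistics:
#
#   rgb_mean, rgb_std = (0.4488, 0.4371, 0.4040), (1.0, 1.0, 1.0)
#   sub_mean = MeanShift(255, rgb_mean, rgb_std, sign=-1)  # normalize input
#   add_mean = MeanShift(255, rgb_mean, rgb_std, sign=1)   # restore output
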
class BasicBlock(nn.Sequential):
    # conv -> (optional BatchNorm) -> (optional activation).
    # Note: `stride` is accepted for interface compatibility but is not
    # forwarded by `default_conv`.
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
        bn=False, act=nn.ReLU(True)):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    # EDSR-style residual block: two convs with a single activation in
    # between, no BatchNorm by default, plus a residual scaling factor.
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        # Scale the residual branch before adding the identity path;
        # res_scale < 1 stabilizes training of very wide/deep models.
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    # Sub-pixel upsampler: each stage expands channels r^2-fold with a conv,
    # then rearranges them into an r-times larger feature map (PixelShuffle).
    # Note: here `act` is a module class/factory (it is called as act()),
    # unlike BasicBlock/ResBlock, where it is a module instance.
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # scale is a power of two: 2, 4, 8, ...
            for _ in range(int(math.log2(scale))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)

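# Usage sketch (illustrative): a x4 Upsampler stacks two conv + PixelShuffle(2)
# stages, so the channel count is preserved while height and width grow 4x:
#
#   up = Upsampler(default_conv, scale=4, n_feat=64)
#   y = up(torch.randn(1, 64, 24, 24))  # y.shape == (1, 64, 96, 96)
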
class DownBlock(nn.Module):
    # Space-to-depth downsampling: folds each scale x scale spatial patch
    # into the channel dimension, so (n, c, h, w) becomes
    # (n, c * scale^2, h // scale, w // scale). Requires h and w to be
    # divisible by scale.
    def __init__(self, scale):
        super().__init__()

        self.scale = scale

    def forward(self, x):
        n, c, h, w = x.size()
        x = x.view(n, c, h // self.scale, self.scale, w // self.scale, self.scale)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
        x = x.view(n, c * (self.scale ** 2), h // self.scale, w // self.scale)
        return x

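# Shape sketch (illustrative): DownBlock(2) maps (1, 3, 8, 8) to (1, 12, 4, 4),
# the inverse shape transform of PixelShuffle(2). It is the same space-to-depth
# idea as F.pixel_unshuffle, though the channel ordering differs: here the
# scale x scale sub-pixel offsets vary slowest, while pixel_unshuffle groups
# the offsets per input channel.
#
#   y = DownBlock(2)(torch.randn(1, 3, 8, 8))  # y.shape == (1, 12, 4, 4)
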
class NonLocalBlock2D(nn.Module):
    # Embedded-Gaussian non-local block (Wang et al., 2018): every spatial
    # position attends to every other position via a softmax over pairwise
    # dot products in an embedded space.
    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(in_channels=in_channels, out_channels=inter_channels,
                           kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=inter_channels, out_channels=in_channels,
                           kernel_size=1, stride=1, padding=0)
        # Zero-init W so the block starts out as an identity mapping and can
        # be inserted into a pretrained network without disturbing it.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        batch_size = x.size(0)

        # Flatten spatial dims: g_x is (B, HW, C'), the "values".
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x is (B, HW, C') and phi_x is (B, C', HW): "queries" and "keys".
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        # Pairwise affinities between all HW positions: (B, HW, HW).
        f = torch.matmul(theta_x, phi_x)

        # Softmax over the last dim turns each row into attention weights.
        f_div_C = F.softmax(f, dim=-1)

        # Aggregate values, restore the (B, C', H, W) layout, project back
        # to in_channels, and add the residual connection.
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z
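

# Minimal smoke test (illustrative sketch, not part of the original file):
# verifies the shape contracts documented in the comments above.
if __name__ == "__main__":
    x = torch.randn(2, 64, 16, 16)

    assert ResBlock(default_conv, 64, 3)(x).shape == x.shape
    assert Upsampler(default_conv, 2, 64)(x).shape == (2, 64, 32, 32)
    assert DownBlock(2)(x).shape == (2, 256, 8, 8)
    assert NonLocalBlock2D(64, 32)(x).shape == x.shape
    print("all shape checks passed")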