|
from . import common |
|
|
|
import torch.nn as nn |
|
|
|
|
|
def make_model(args, parent=False):
    """Build and return a ``BASIC`` network configured from ``args``.

    Args:
        args: Options object forwarded unchanged to ``BASIC``.
        parent: Accepted for signature compatibility; not used here.

    Returns:
        A newly constructed ``BASIC`` instance.
    """
    model = BASIC(args)
    return model
|
|
|
|
|
class BASIC(nn.Module):
    """Three-stage network: head conv, residual body, and tail.

    The stages are assembled from building blocks in ``common``:

    * ``head`` -- a single ``conv`` layer from ``args.n_colors`` to
      ``args.n_feats``.
    * ``body`` -- ``args.n_resblocks`` instances of ``common.ResBlock``
      followed by one ``conv`` layer; its output is added to the head
      output (skip connection around the whole body).
    * ``tail`` -- ``common.Upsampler`` followed by a ``conv`` layer from
      ``args.n_feats`` back to ``args.n_colors``.
    """

    def __init__(self, args, conv=common.default_conv):
        """Construct the head, body and tail stages.

        Args:
            args: Options object. The attributes read are ``n_resblocks``,
                ``n_feats``, ``scale`` (only ``scale[0]`` is used),
                ``n_colors`` and ``res_scale``.
            conv: Layer factory, invoked as ``conv(a, b, kernel_size)`` and
                also handed to ``common.ResBlock`` / ``common.Upsampler``.
                Defaults to ``common.default_conv``.
                NOTE(review): ``a``/``b`` are presumably input/output
                channel counts -- confirm against ``common.default_conv``.
        """
        super(BASIC, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        # A single ReLU module instance is shared by every ResBlock.
        act = nn.ReLU(True)

        # Head: one conv from the image channels to the feature width.
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # Body: the residual blocks, closed by one extra conv.
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            )
            for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # Tail: upsampler followed by a conv back to the image channels.
        m_tail = [
            common.Upsampler(conv, scale, n_feats),
            conv(n_feats, args.n_colors, kernel_size),
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        """Apply head, skip-connected body, then tail.

        Args:
            x: Input passed to ``self.head``.

        Returns:
            The output of ``self.tail`` applied to ``body(head(x)) + head(x)``.
        """
        x = self.head(x)

        # Skip connection around the whole body: res = body(x) + x.
        res = self.body(x)
        res += x

        # The tail must consume ``res`` (not ``x``); otherwise the body's
        # output never reaches the result and its parameters are unused.
        return self.tail(res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|