from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import register


def make_model(args, parent=False):
    """Build a :class:`CNN7` network from ``args``.

    Args:
        args: Namespace-like object; ``args.n_colors`` is read by ``CNN7``.
        parent: Accepted for call-site compatibility; unused here.

    Returns:
        A new ``CNN7`` instance.
    """
    # NOTE(review): this returns CNN7 rather than LGCNET -- confirm with the
    # caller that this is the intended default model for this module.
    return CNN7(args)


# The factory used to be named ``LGCNET`` as well, which the class below
# immediately shadowed; it only worked because the name is resolved at call
# time.  The registry key is unchanged, only the Python name differs.
@register('LGCNET')
def make_lgcnet(scale_ratio, rgb_range=1):
    """Registry factory that builds an :class:`LGCNET` for RGB input.

    Args:
        scale_ratio: Stored as ``args.scale = [scale_ratio]``.
        rgb_range: Stored as ``args.rgb_range``.

    Returns:
        A new ``LGCNET`` instance configured with ``n_colors = 3``.
    """
    args = Namespace()
    # ``scale`` and ``rgb_range`` are recorded on the namespace but are not
    # read by the LGCNET class defined in this file.
    args.scale = [scale_ratio]
    args.n_colors = 3
    args.rgb_range = rgb_range
    return LGCNET(args)


def _copy_matching_params(module, state_dict, strict):
    """Copy entries of ``state_dict`` into ``module`` by parameter name.

    Shared implementation behind ``LGCNET.load_state_dict`` and
    ``CNN7.load_state_dict``.

    Args:
        module: The ``nn.Module`` whose parameters/buffers are overwritten.
        state_dict: Mapping from parameter name to tensor (or Parameter).
        strict: When true, unexpected keys (other than ones containing
            ``'tail'``) and missing keys raise ``KeyError``.

    Raises:
        RuntimeError: If copying a non-``'tail'`` entry fails.
        KeyError: In strict mode, on unexpected or missing keys.
    """
    own_state = module.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            # Unknown keys are tolerated unless strict; 'tail' keys are
            # always tolerated.
            if strict and 'tail' not in name:
                raise KeyError('unexpected key "{}" in state_dict'
                               .format(name))
            continue

        if isinstance(param, nn.Parameter):
            param = param.data
        try:
            own_state[name].copy_(param)
        except Exception as err:
            if 'tail' in name:
                # A failed copy for a 'tail' entry is reported and skipped.
                print('Replace pre-trained upsampler to new one...')
            else:
                raise RuntimeError('While copying the parameter named {}, '
                                   'whose dimensions in the model are {} and '
                                   'whose dimensions in the checkpoint are {}.'
                                   .format(name, own_state[name].size(),
                                           param.size())) from err

    if strict:
        missing = set(own_state) - set(state_dict)
        if missing:
            raise KeyError('missing keys in state_dict: "{}"'.format(missing))


class LGCNET(nn.Module):
    """Seven-convolution residual network with multi-level feature fusion.

    The input is first resized to ``out_size`` with bicubic interpolation;
    the convolutional stack then predicts a correction that is added to the
    resized image.  The outputs of the third, fourth and fifth convolutions
    are concatenated along the channel axis before ``conv6``.

    Args:
        args: Namespace-like object; only ``args.n_colors`` is read.
        nfeats: Number of feature channels in the hidden convolutions.
    """

    def __init__(self, args, nfeats=32):
        super().__init__()
        self.conv1 = nn.Conv2d(args.n_colors, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        # conv6 consumes the concatenation of three nfeats-channel maps.
        self.conv6 = nn.Conv2d(nfeats * 3, nfeats * 2, kernel_size=5,
                               stride=1, padding=2, bias=True)
        # Output channels follow the input channel count so the residual
        # addition in forward() is shape-exact (previously hard-coded to 3).
        self.conv7 = nn.Conv2d(nfeats * 2, args.n_colors, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x, out_size):
        """Resize ``x`` to ``out_size`` and add the predicted residual.

        Args:
            x: Input tensor passed to ``F.interpolate``.
            out_size: Target size forwarded to ``F.interpolate``.

        Returns:
            The bicubically resized input plus the output of ``conv7``.
        """
        x = F.interpolate(x, out_size, mode='bicubic')
        residual = x
        im1 = self.relu(self.conv1(x))
        im2 = self.relu(self.conv2(im1))
        im3 = self.relu(self.conv3(im2))
        im4 = self.relu(self.conv4(im3))
        im5 = self.relu(self.conv5(im4))
        fused = torch.cat((im3, im4, im5), dim=1)
        out = self.relu(self.conv6(fused))
        return self.conv7(out) + residual

    def load_state_dict(self, state_dict, strict=False):
        """Load parameters by name; see :func:`_copy_matching_params`.

        Unlike ``nn.Module.load_state_dict`` this defaults to
        ``strict=False`` and returns ``None``.
        """
        _copy_matching_params(self, state_dict, strict)


class CNN7(nn.Module):
    """Plain seven-convolution residual network.

    Seven 3x3 convolutions applied in sequence (ReLU after the first six);
    the result is added to the unmodified input.

    Args:
        args: Namespace-like object; only ``args.n_colors`` is read.
        nfeats: Number of feature channels in the hidden convolutions.
    """

    def __init__(self, args, nfeats=32):
        super().__init__()
        self.conv1 = nn.Conv2d(args.n_colors, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.conv6 = nn.Conv2d(nfeats, nfeats, kernel_size=3,
                               stride=1, padding=1, bias=True)
        # Output channels follow the input channel count so the residual
        # addition in forward() is shape-exact (previously hard-coded to 3).
        self.conv7 = nn.Conv2d(nfeats, args.n_colors, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the conv stack."""
        residual = x
        feat = x
        # The first six convolutions are each followed by a ReLU.
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            feat = self.relu(conv(feat))
        return self.conv7(feat) + residual

    def load_state_dict(self, state_dict, strict=False):
        """Load parameters by name; see :func:`_copy_matching_params`.

        Unlike ``nn.Module.load_state_dict`` this defaults to
        ``strict=False`` and returns ``None``.
        """
        _copy_matching_params(self, state_dict, strict)