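# FunSR / models / cnn_models / __init__.py
# Last commit: 02c5426 ("add") by KyanChen; file size 6.73 kB.
# (Header captured from a web file viewer; "raw", "history" and "blame" were viewer links.)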
from .transenet import *
from .srcnn import SRCNN
from .fsrcnn import FSRCNN
from .lgcnet import LGCNET
from .dcm import DIM
from .vdsr import VDSR
# import os
# from importlib import import_module
#
# import torch
# import torch.nn as nn
#
#
# NOTE: the Model class below is inactive (commented out) and kept for reference only.
# class Model(nn.Module):
#     def __init__(self, args, ckp):
#         super(Model, self).__init__()
#         print('Making model...')
#
#         self.scale = args.scale
#         self.idx_scale = 0
#         self.self_ensemble = args.self_ensemble
#         self.chop = args.chop
#         self.precision = args.precision
#         self.cpu = args.cpu
#         self.device = torch.device('cpu' if args.cpu else 'cuda')
#         self.n_GPUs = args.n_GPUs
#         self.save_models = args.save_models
#
#         # NOTE(review): resolves models from a top-level 'model' package;
#         # confirm that path exists in this repo before re-enabling.
#         module = import_module('model.' + args.model.lower())
#         self.model = module.make_model(args).to(self.device)
#         if args.precision == 'half': self.model.half()
#
#         if not args.cpu and args.n_GPUs > 1:
#             self.model = nn.DataParallel(self.model, range(args.n_GPUs))
#
#         self.load(
#             ckp.dir,
#             pre_train=args.pre_train,
#             resume=args.resume,
#             cpu=args.cpu
#         )
#         if args.print_model: print(self.model)
#
#     def forward(self, x):
#         # NOTE(review): `target` is never used below.
#         target = self.get_model()
#
#         if self.self_ensemble and not self.training:
#             if self.chop:
#                 forward_function = self.forward_chop
#             else:
#                 forward_function = self.model.forward
#
#             return self.forward_x8(x, forward_function)
#         elif self.chop and not self.training:
#             return self.forward_chop(x)
#         else:
#             return self.model(x)
#
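#     def get_model(self):
#         # NOTE(review): __init__ wraps self.model in nn.DataParallel only when
#         # `not args.cpu and args.n_GPUs > 1`, so `.module` fails whenever
#         # n_GPUs != 1 but the model was not wrapped (e.g. CPU runs); consider
#         # checking isinstance(self.model, nn.DataParallel) instead.
#         if self.n_GPUs == 1:
#             return self.model
#         else:
#             return self.model.module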
#
#     def state_dict(self, **kwargs):
#         target = self.get_model()
#         return target.state_dict(**kwargs)
#
#     def save(self, apath, epoch, is_best=False):
#         target = self.get_model()
#         torch.save(
#             target.state_dict(),
#             os.path.join(apath, 'model', 'model_latest.pt')
#         )
#         if is_best:
#             torch.save(
#                 target.state_dict(),
#                 os.path.join(apath, 'model', 'model_best.pt')
#             )
#
#         if self.save_models:
#             torch.save(
#                 target.state_dict(),
#                 os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
#             )
#
#     def load(self, apath, pre_train='.', resume=-1, cpu=False):
#         if cpu:
#             kwargs = {'map_location': lambda storage, loc: storage}
#         else:
#             kwargs = {}
#
#         if resume == 1: # loading model from model_latest.pt file
#             print('loading model from the model_latest.pt file...')
#             self.get_model().load_state_dict(
#                 torch.load(
#                     os.path.join(apath, 'model', 'model_latest.pt'),
#                     **kwargs
#                 ),
#                 strict=False
#             )
#         elif resume == 0: # loading model from a pre-trained model file ...
#             if pre_train != '.':
#                 print('Loading model from {}'.format(pre_train))
#                 self.get_model().load_state_dict(
#                     torch.load(pre_train, **kwargs),
#                     strict=False
#                 )
#         else:
#             # NOTE(review): the default resume=-1 lands in this branch and
#             # tries to load 'model_-1.pt'; confirm that is intended.
#             self.get_model().load_state_dict(
#                 torch.load(
#                     os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
#                     **kwargs
#                 ),
#                 strict=False
#             )
#
#     def forward_chop(self, x, shave=10, min_size=160000):
#         scale = self.scale[self.idx_scale]
#         # NOTE(review): with n_GPUs == 0 the range step below is 0 and
#         # raises ValueError.
#         n_GPUs = min(self.n_GPUs, 4)
#         b, c, h, w = x.size()
#         h_half, w_half = h // 2, w // 2
#         h_size, w_size = h_half + shave, w_half + shave
#         lr_list = [
#             x[:, :, 0:h_size, 0:w_size],
#             x[:, :, 0:h_size, (w - w_size):w],
#             x[:, :, (h - h_size):h, 0:w_size],
#             x[:, :, (h - h_size):h, (w - w_size):w]]
#
#         if w_size * h_size < min_size:
#             sr_list = []
#             for i in range(0, 4, n_GPUs):
#                 lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
#                 sr_batch = self.model(lr_batch)
#                 sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
#         else:
#             sr_list = [
#                 self.forward_chop(patch, shave=shave, min_size=min_size) \
#                 for patch in lr_list
#             ]
#
#         h, w = scale * h, scale * w
#         h_half, w_half = scale * h_half, scale * w_half
#         h_size, w_size = scale * h_size, scale * w_size
#         shave *= scale
#
#         # NOTE(review): the output buffer reuses the input channel count `c`;
#         # this assumes the model preserves the number of channels.
#         output = x.new(b, c, h, w)
#         output[:, :, 0:h_half, 0:w_half] \
#             = sr_list[0][:, :, 0:h_half, 0:w_half]
#         output[:, :, 0:h_half, w_half:w] \
#             = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
#         output[:, :, h_half:h, 0:w_half] \
#             = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
#         output[:, :, h_half:h, w_half:w] \
#             = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
#
#         return output
#
#     def forward_x8(self, x, forward_function):
#         def _transform(v, op):
#             if self.precision != 'single': v = v.float()
#
#             v2np = v.data.cpu().numpy()
#             if op == 'v':
#                 tfnp = v2np[:, :, :, ::-1].copy()
#             elif op == 'h':
#                 tfnp = v2np[:, :, ::-1, :].copy()
#             elif op == 't':
#                 tfnp = v2np.transpose((0, 1, 3, 2)).copy()
#
#             ret = torch.Tensor(tfnp).to(self.device)
#             if self.precision == 'half': ret = ret.half()
#
#             return ret
#
#         lr_list = [x]
#         for tf in 'v', 'h', 't':
#             lr_list.extend([_transform(t, tf) for t in lr_list])
#
#         sr_list = [forward_function(aug) for aug in lr_list]
#         for i in range(len(sr_list)):
#             if i > 3:
#                 sr_list[i] = _transform(sr_list[i], 't')
#             if i % 4 > 1:
#                 sr_list[i] = _transform(sr_list[i], 'h')
#             if (i % 4) % 2 == 1:
#                 sr_list[i] = _transform(sr_list[i], 'v')
#
#         output_cat = torch.cat(sr_list, dim=0)
#         output = output_cat.mean(dim=0, keepdim=True)
#
#         return output
#
#
#
#
#
#
#
#
#
#
#
#
#
#