# Source: S3Diff / basicsr/models/swinir_model.py
# (uploaded by zhangap in "Upload 213 files", commit 36d9761, 1.12 kB)
import torch
from torch.nn import functional as F
from basicsr.utils.registry import MODEL_REGISTRY
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class SwinIRModel(SRModel):
    """SRModel variant whose inference pads the input to a multiple of ``window_size``.

    Only ``test`` is overridden; training and everything else is inherited from
    ``SRModel``. The window size is read from ``self.opt['network_g']['window_size']``.
    """

    def test(self):
        """Run inference on ``self.lq`` and store the cropped result in ``self.output``.

        ``self.lq`` is unpacked as a 4-D tensor whose last two dims are height and
        width. It is padded on the bottom/right so both are multiples of
        ``window_size``. The padded image is run through ``net_g_ema`` if that
        attribute exists, otherwise through ``net_g`` (under ``torch.no_grad()``).
        The padded region is then cropped from the output, scaled by
        ``self.opt.get('scale', 1)``.

        Side effects: sets ``self.output``. When ``net_g`` is used, it is switched
        to eval for the forward pass and back to train mode afterwards, even if
        the forward pass raises.
        """
        window_size = self.opt['network_g']['window_size']
        scale = self.opt.get('scale', 1)
        _, _, h, w = self.lq.size()

        # Smallest bottom/right padding that makes h and w divisible by
        # window_size (0 when the size is already a multiple).
        mod_pad_h = (window_size - h % window_size) % window_size
        mod_pad_w = (window_size - w % window_size) % window_size

        # 'reflect' padding requires each pad to be strictly smaller than the
        # padded dimension. For inputs smaller than that, fall back to
        # 'replicate' instead of raising.
        pad_mode = 'reflect' if (mod_pad_h < h and mod_pad_w < w) else 'replicate'
        img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), pad_mode)

        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output = self.net_g_ema(img)
        else:
            self.net_g.eval()
            try:
                with torch.no_grad():
                    self.output = self.net_g(img)
            finally:
                # Put net_g back into training mode even if inference fails, so an
                # interrupted validation does not leave the network in eval mode.
                self.net_g.train()

        # The network output is assumed to be `scale` times the padded input size;
        # crop off the padded bottom/right border accordingly.
        _, _, out_h, out_w = self.output.size()
        self.output = self.output[:, :, :out_h - mod_pad_h * scale, :out_w - mod_pad_w * scale]