File size: 2,689 Bytes
26e1a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from model.SRMNet import SRMNet
from utils import util_calculate_psnr_ssim as util


def save_img(filepath, img):
    """Write an RGB image array to *filepath* using OpenCV.

    OpenCV writes in BGR channel order, so *img* is converted from RGB first.
    OpenCV picks the output format from the file extension of *filepath*.
    NOTE(review): *img* is presumably an (H, W, 3) uint8 array — confirm at callers.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written
            (it signals failure by returning False rather than raising).
    """
    ok = cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    if not ok:
        raise OSError(f"Failed to write image to {filepath}")


def load_checkpoint(model, weights, map_location="cpu"):
    """Load the ``"state_dict"`` entry of a saved checkpoint into *model*.

    If loading the state dict as-is fails and its keys carry a leading
    ``module.`` prefix (as produced when a model is saved while wrapped in
    ``torch.nn.DataParallel``), the load is retried with that prefix removed.

    Args:
        model: the ``torch.nn.Module`` to populate in place.
        weights: path to a checkpoint file holding a dict with a
            ``"state_dict"`` key.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            GPU-saved checkpoints also load on machines without CUDA;
            ``load_state_dict`` copies the tensors onto the model's own device.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the state dict does not match *model*, either as-is
            or after stripping the ``module.`` prefix.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        # Retry only if the prefix is actually present; otherwise the original
        # mismatch error is the informative one and is re-raised unchanged.
        if not any(key.startswith(prefix) for key in state_dict):
            raise
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped)


def _list_images(inp_dir, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
    """Return the sorted paths of files in *inp_dir* whose extension is in *extensions*.

    The extension test is case-insensitive, so ``x.png`` and ``x.PNG`` are both
    found on any file system (a bare ``*.PNG`` glob misses lower-case names on
    case-sensitive file systems).
    """
    candidates = glob.glob(os.path.join(inp_dir, '*'))
    return sorted(
        path for path in candidates
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in extensions
    )


def _to_uint8_image(tensor):
    """Convert a (1, C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy array."""
    array = tensor[0].permute(1, 2, 0).cpu().numpy()
    # Round-to-nearest scaling by 255; replaces skimage's img_as_ubyte, which
    # this script previously called without ever importing it.
    return np.clip(np.rint(array * 255.0), 0, 255).astype(np.uint8)


def main():
    """Denoise every image in ``--input_dir`` with SRMNet and save PNGs to ``--result_dir``.

    Each output keeps its input's file stem and gets a ``.png`` extension.

    Raises:
        Exception: if no supported image files are found in the input directory.
    """
    parser = argparse.ArgumentParser(description='Demo Image Denoising')
    parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
    parser.add_argument('--result_dir', default='./result/', type=str, help='Directory for results')
    parser.add_argument('--weights',
                        default='./checkpoints/SRMNet_real_denoise/models/model_bestPSNR.pth', type=str,
                        help='Path to weights')

    args = parser.parse_args()

    inp_dir = args.input_dir
    out_dir = args.result_dir

    os.makedirs(out_dir, exist_ok=True)

    files = _list_images(inp_dir)
    if not files:
        raise Exception(f"No files found at {inp_dir}")

    # Fall back to CPU so the demo still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model architecture and weights
    model = SRMNet()
    model.to(device)

    load_checkpoint(model, args.weights)
    model.eval()

    # Input H and W are padded up to a multiple of `mul`.
    # NOTE(review): presumably the network's total down-sampling factor — confirm.
    mul = 16
    for file_ in files:
        # `with` closes the underlying file handle once the pixels are read.
        with Image.open(file_) as img:
            input_ = TF.to_tensor(img.convert('RGB')).unsqueeze(0).to(device)

        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % mul
        padw = (-w) % mul
        # Reflect padding requires each pad to be smaller than the padded
        # dimension; tiny images fall back to edge replication instead.
        mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), mode)

        with torch.no_grad():
            restored = model(input_)
            # Clamp to the valid range and crop the padding back off.
            restored = torch.clamp(restored, 0, 1)[:, :, :h, :w]

        stem = os.path.splitext(os.path.basename(file_))[0]
        save_img(os.path.join(out_dir, stem + '.png'), _to_uint8_image(restored))



# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()