File size: 1,824 Bytes
3135a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90e4acb
 
 
 
3135a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import pathlib
# Append this script's own directory to sys.path so that the local modules
# imported below (presumably utility, data, option and data.data_tiling live
# next to this file) resolve even when the script is launched from a different
# working directory.
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))

import torch
from tqdm import tqdm
import utility
import data
from option import args
import onnxruntime
from data.data_tiling import tiling_inference


def test_model(session, loader, scales=(2,)):
    """Evaluate an ONNX session with tiled inference on every test dataset.

    For each DataLoader in ``loader.loader_test`` and each entry of
    ``scales``, every low-resolution batch is run through
    ``tiling_inference``. The output is passed through
    ``utility.quantize(sr, 255)``, and PSNR/SSIM are accumulated against the
    ground-truth batch. The mean for each (dataset, scale) pair is printed.

    Args:
        session: onnxruntime ``InferenceSession`` handed to ``tiling_inference``.
        loader: object exposing ``loader_test``, an iterable of DataLoaders
            whose ``dataset`` provides ``set_scale(idx)``.
        scales: scale factors to evaluate. The index of each entry is passed
            to ``dataset.set_scale`` and the value to the metric functions.
            Defaults to ``(2,)``, the previously hard-coded behaviour.

    Returns:
        ``(mean_psnr, mean_ssim)`` of the last dataset/scale pair evaluated.

    Raises:
        ValueError: if no dataset/scale pair yielded any batches, so there is
            no mean to report.
    """
    torch.set_grad_enabled(False)
    mean_psnr = mean_ssim = None
    for d in loader.loader_test:
        for idx_scale, scale in enumerate(scales):
            d.dataset.set_scale(idx_scale)
            n_batches = len(d)
            if n_batches == 0:
                # Nothing to average for this loader; skip rather than divide by zero.
                continue
            # Reset per (dataset, scale) so sums from different scales never mix.
            eval_psnr = 0
            eval_ssim = 0
            for lr, hr, _filename in tqdm(d, ncols=80):

                # Tiled inference
                sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
                sr = torch.from_numpy(sr)
                sr = utility.quantize(sr, 255)

                # Transform from NHWC to NCHW to calculate metric
                sr = sr.permute((0, 3, 1, 2))
                hr = hr.permute((0, 3, 1, 2))
                eval_psnr += utility.calc_psnr(
                    sr, hr, scale, 255, benchmark=d)
                eval_ssim += utility.calc_ssim(
                    sr, hr, scale, 255, dataset=d)
            mean_ssim = eval_ssim / n_batches
            mean_psnr = eval_psnr / n_batches
            print("psnr: %s, ssim: %s" % (mean_psnr, mean_ssim))
    if mean_psnr is None:
        raise ValueError(
            "no test batches were evaluated; check loader.loader_test and scales")
    return mean_psnr, mean_ssim

def main():
    """Build the test loader, open the ONNX model and run the evaluation.

    Everything is configured from the command-line ``args``. When
    ``args.ipu`` is truthy, the session uses the Vitis AI execution provider
    with ``args.provider_config`` as its config file. Otherwise it runs on
    the CPU execution provider. The metrics returned by ``test_model`` are
    discarded; they are only printed.
    """
    test_loader = data.Data(args)
    model_path = args.onnx_path
    session_kwargs = (
        {
            "providers": ["VitisAIExecutionProvider"],
            "provider_options": [{"config_file": args.provider_config}],
        }
        if args.ipu
        else {"providers": ['CPUExecutionProvider'], "provider_options": None}
    )
    session = onnxruntime.InferenceSession(model_path, **session_kwargs)
    test_model(session, test_loader)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()