# PAN / eval_onnx.py
# Last commit 90e4acb by Tellll: "Update code and model to support NHWC input format"
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
import torch
from tqdm import tqdm
import utility
import data
from option import args
import onnxruntime
from data.data_tiling import tiling_inference
def test_model(session, loader, scales=(2,)):
    """Evaluate an ONNX Runtime session on every test loader; report PSNR/SSIM.

    Each low-resolution batch is run through ``tiling_inference`` with the
    given session. The output is quantized with ``utility.quantize(..., 255)``.
    Both prediction and ground truth are then permuted from NHWC to NCHW
    before ``utility.calc_psnr`` / ``utility.calc_ssim`` are computed.

    Args:
        session: the inference session handed to ``tiling_inference``
            (``main`` passes an ``onnxruntime.InferenceSession``).
        loader: object exposing ``loader_test``, an iterable of data loaders
            whose ``dataset`` provides ``set_scale(index)``.
        scales: upscaling factors to evaluate. The *position* of each factor
            is passed to ``set_scale``, so the order presumably has to match
            the dataset's configured scale list (confirm against ``data``).
            Defaults to ``(2,)``, the factor that used to be hard-coded.

    Returns:
        ``(mean_psnr, mean_ssim)`` for the last evaluated dataset/scale pair.
        The means divide by ``len(d)``, i.e. the number of batches. This
        equals a per-image mean only when the test batch size is 1.

    Raises:
        ValueError: if every test loader was empty, so no metric exists.
    """
    mean_psnr = mean_ssim = None
    # Inference happens inside ONNX Runtime; torch is only used for
    # post-processing and metrics, so no autograd graph is needed. A context
    # manager (rather than torch.set_grad_enabled(False)) keeps the setting
    # from leaking to whoever called this function.
    with torch.no_grad():
        for d in loader.loader_test:
            for idx_scale, scale in enumerate(scales):
                d.dataset.set_scale(idx_scale)
                num_batches = len(d)
                if num_batches == 0:
                    # Nothing to average; avoids ZeroDivisionError below.
                    print("skipping empty test loader at scale %s" % scale)
                    continue
                # Reset per (dataset, scale) so metrics from different
                # scales are never summed together.
                eval_psnr = 0
                eval_ssim = 0
                for lr, hr, _filename in tqdm(d, ncols=80):
                    # Tiled inference on the numpy batch. lr is presumably
                    # NHWC like hr below. NOTE(review): the meaning of 8 and
                    # (56, 56) is defined by tiling_inference - confirm there.
                    sr = tiling_inference(session, lr.numpy(), 8, (56, 56))
                    sr = utility.quantize(torch.from_numpy(sr), 255)
                    # Transform from NHWC to NCHW to calculate metric
                    sr = sr.permute((0, 3, 1, 2))
                    hr = hr.permute((0, 3, 1, 2))
                    eval_psnr += utility.calc_psnr(
                        sr, hr, scale, 255, benchmark=d)
                    eval_ssim += utility.calc_ssim(
                        sr, hr, scale, 255, dataset=d)
                mean_psnr = eval_psnr / num_batches
                mean_ssim = eval_ssim / num_batches
                print("psnr: %s, ssim: %s" % (mean_psnr, mean_ssim))
    if mean_psnr is None:
        raise ValueError(
            "no test data: every loader in loader.loader_test was empty")
    return mean_psnr, mean_ssim
def main():
    """Evaluate the ONNX model at ``args.onnx_path`` on the test datasets.

    The test data is built from ``args`` first. Then an ONNX Runtime session
    is opened: when ``args.ipu`` is truthy it uses the Vitis AI execution
    provider configured by ``args.provider_config``, otherwise the CPU
    execution provider. The session and data are handed to ``test_model``.
    """
    test_data = data.Data(args)
    model_path = args.onnx_path
    if not args.ipu:
        session_kwargs = {
            "providers": ["CPUExecutionProvider"],
            "provider_options": None,
        }
    else:
        session_kwargs = {
            "providers": ["VitisAIExecutionProvider"],
            "provider_options": [{"config_file": args.provider_config}],
        }
    session = onnxruntime.InferenceSession(model_path, **session_kwargs)
    test_model(session, test_data)


if __name__ == "__main__":
    main()