# rcan / infer_onnx.py
# Last commit by Chan: "change onnx to NHWC" (9f86314)
# File size: 1.68 kB
import torch
import onnxruntime
import cv2
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
import numpy as np
from data.data_tiling import tiling_inference
import argparse
def main(args):
    """Run the ONNX model on one image and write the result to disk.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``onnx_path``, ``image_path``, ``output_path``, ``ipu`` and
            ``provider_config``.

    Raises:
        OSError: If the input image cannot be read or decoded, or if the
            output image cannot be written.
    """
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        # ONNX Runtime tries the providers in the order listed.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None

    onnx_file_name = args.onnx_path
    image_path = args.image_path
    output_path = args.output_path

    ort_session = onnxruntime.InferenceSession(
        onnx_file_name,
        providers=providers,
        provider_options=provider_options,
    )

    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly rather than failing later with an opaque TypeError.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"could not read or decode image: {image_path}")

    # Add a leading batch axis and move channels ahead of the spatial axes:
    # (H, W, C) -> (1, C, H, W), as float32.
    lr = image[np.newaxis, :, :, :].transpose((0, 3, 1, 2)).astype(np.float32)

    # NOTE(review): the meaning of the `8` and `(56, 56)` arguments is defined
    # by data.data_tiling.tiling_inference (presumably an overlap/scale value
    # and the tile size) - confirm against that helper.
    sr = tiling_inference(ort_session, lr, 8, (56, 56))

    # Clamp to the 8-bit range before the cast so out-of-range values
    # saturate instead of wrapping around.
    sr = np.clip(sr, 0, 255)
    # Drop singleton axes and move channels last for OpenCV.
    sr = sr.squeeze().transpose((1, 2, 0)).astype(np.uint8)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_path, sr):
        raise OSError(f"could not write image: {output_path}")
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description='RCAN SISR')
    parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
                        help='onnx path')
    parser.add_argument('--image_path', type=str, default='test_data/test.png',
                        help='path of your image')
    # The help text previously duplicated the --image_path description.
    parser.add_argument('--output_path', type=str, default='test_data/sr.png',
                        help='path where the output image is written')
    parser.add_argument('--ipu', action='store_true',
                        help='use ipu')
    # Passed to the execution provider as "config_file" when --ipu is set.
    parser.add_argument('--provider_config', type=str, default=None,
                        help='provider config path')
    args = parser.parse_args()
    main(args)