PAN / infer_onnx.py
Tellll's picture
Update code and model to support NHWC input format
90e4acb
raw
history blame
1.94 kB
import sys
import pathlib
# Directory containing this script. It is appended to sys.path so that the
# local `data.data_tiling` import below resolves from this script's folder.
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
import onnxruntime
import cv2
import numpy as np
from data.data_tiling import tiling_inference
import argparse
def main(args):
    """Run tiled PAN super-resolution on a single image with ONNX Runtime.

    Loads the image at ``args.image_path`` with OpenCV and runs the ONNX model
    at ``args.onnx_path`` over it via ``tiling_inference``. The upscaled
    result is written to ``args.output_path``.

    Args:
        args: Namespace providing ``onnx_path``, ``image_path``,
            ``output_path``, ``provider_config`` and ``ipu`` (as produced by
            this script's command-line parser).

    Raises:
        FileNotFoundError: If the input image cannot be read.
        OSError: If OpenCV reports that the output image could not be written.
    """
    onnx_file_name = args.onnx_path
    image_path = args.image_path
    output_path = args.output_path

    if args.ipu:
        # Run on the IPU through the Vitis AI execution provider, which takes
        # its settings from a separate config file.
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        onnx_file_name, providers=providers, provider_options=provider_options)

    # cv2.imread does not raise for a missing/unreadable file; it returns None.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read input image: {image_path}")
    # Prepend a batch axis: (H, W, C) -> (1, H, W, C).
    lr = image[np.newaxis, :, :, :].astype(np.float32)

    # Tiled inference with (56, 56) tiles.
    # NOTE(review): the meaning of the `8` argument is defined by
    # data.data_tiling.tiling_inference (not visible here) - confirm there.
    sr = tiling_inference(ort_session, lr, 8, (56, 56))

    # Round before the uint8 cast: astype() truncates toward zero, which would
    # bias every pixel down by up to one intensity level.
    sr = np.clip(np.rint(sr), 0, 255)
    # squeeze() drops the batch axis (and any other size-1 axis).
    sr = sr.squeeze().astype(np.uint8)

    # cv2.imwrite can signal failure by returning False instead of raising.
    if not cv2.imwrite(output_path, sr):
        raise OSError(f"Could not write output image: {output_path}")
def parse_args(argv=None):
    """Build this script's command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as ``argparse.ArgumentParser.parse_args`` does.

    Returns:
        argparse.Namespace with ``onnx_path``, ``image_path``,
        ``output_path``, ``provider_config`` and ``ipu`` attributes.
    """
    parser = argparse.ArgumentParser(description='PAN')
    parser.add_argument('--onnx_path',
                        type=str,
                        default='PAN_int8.onnx',
                        help='Path to onnx model')
    parser.add_argument('--image_path',
                        type=str,
                        default='test_data/test.png',
                        help='Path to your low resolution input image.')
    parser.add_argument('--output_path',
                        type=str,
                        default='test_data/sr.png',
                        help='Path to your upscaled output image.')
    parser.add_argument('--provider_config',
                        type=str,
                        default="vaip_config.json",
                        help="Path of the config file for setting provider_options.")
    parser.add_argument('--ipu', action='store_true',
                        help='Use IPU for inference.')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())