#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright 2023 Advanced Micro Devices, Inc. on behalf of itself and its subsidiaries and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Megvii, Inc. and its affiliates.
"""Run a single image through an ONNX classification model and print the class id.

Usage:
    python <script> --image_path IMG [--onnx_path MODEL] [--ipu]
                    [--provider_config CONFIG]
"""

import argparse

import onnxruntime
import torchvision.transforms as transforms
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument(
    '--onnx_path',
    type=str,
    default="EfficientNet_int.onnx",
    required=False,
)
parser.add_argument('--image_path', type=str, required=True)
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    help="Path of the config file for setting provider_options.",
)
# Parsed at module level (as in the original script) because read_image()
# and main() both read the resulting ``args`` global.
args = parser.parse_args()


def read_image():
    """Load ``args.image_path`` and turn it into a batched model input.

    The image is forced to 3-channel RGB, converted to a tensor, resized to
    224x224 and normalized per channel, then given a leading batch axis.

    Returns:
        A NumPy array holding the single preprocessed image; with the RGB
        conversion and the 224x224 resize below this is (1, 3, 224, 224).
    """
    # The context manager closes the underlying file; convert() forces the
    # pixel data to be loaded before that happens. Converting to RGB keeps
    # grayscale / palette / RGBA inputs compatible with the 3-channel
    # Normalize below, which would otherwise fail on a channel mismatch.
    with Image.open(args.image_path) as image:
        rgb_image = image.convert("RGB")

    # Per-channel statistics (presumably the ImageNet mean/std the model
    # was trained with -- confirm against the model's training recipe).
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # NOTE: Resize runs on the tensor (after ToTensor), not on the PIL
    # image; the order is kept as-is so preprocessing numerics are unchanged.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        normalize,
    ])
    img_tensor = transform(rgb_image).unsqueeze(0)
    return img_tensor.numpy()


def main():
    """Create the ONNX Runtime session, run one inference, print the argmax.

    With ``--ipu`` the Vitis AI execution provider is requested and given
    ``--provider_config`` as its ``config_file`` option; otherwise the CUDA
    and CPU providers are requested, in that order, with no options.
    """
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None

    ort_session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )

    # The preprocessed image is bound to the model's first declared input.
    input_name = ort_session.get_inputs()[0].name
    ort_inputs = {input_name: read_image()}

    # Only the first model output is used; its first row is the single
    # image in the batch.
    output = ort_session.run(None, ort_inputs)[0]
    print("class id =", output[0].argmax())


if __name__ == "__main__":
    main()