File size: 1,992 Bytes
37aeb5b cb68e0f 37aeb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import onnxruntime
import torch
# ONNX Runtime execution providers handed to every InferenceSession that
# load_onnx creates. Each entry is a (provider_name, provider_options) pair,
# listed in order of preference as ONNX Runtime's `providers` argument expects.
#
# To try TensorRT first, prepend an entry such as:
#     ("TensorrtExecutionProvider", dict(
#         device_id=0,
#         trt_max_workspace_size=8 * 1024 ** 3,
#         trt_fp16_enable=True,
#         trt_engine_cache_enable=True,
#     )),
providers = [
    (
        "CUDAExecutionProvider",
        dict(
            device_id=0,
            arena_extend_strategy="kSameAsRequested",
            gpu_mem_limit=8 * 1024 ** 3,  # 8 GiB
            cudnn_conv_algo_search="HEURISTIC",
        ),
    ),
]
def load_onnx(file_path: str):
    """Create an ONNX Runtime inference session for the model at *file_path*.

    The session is configured with the module-level ``providers`` list.

    Args:
        file_path: Path to a serialized ONNX model; must end with ``.onnx``.

    Returns:
        onnxruntime.InferenceSession: The loaded session.

    Raises:
        AssertionError: If *file_path* does not end with ``.onnx``.
    """
    assert file_path.endswith(".onnx"), f"Expected a .onnx model file, got {file_path!r}"
    sess_opt = onnxruntime.SessionOptions()
    # The keyword is ``sess_options``. A misspelled keyword such as ``sess_opt=``
    # is swallowed by InferenceSession's **kwargs, and the options are silently
    # ignored.
    return onnxruntime.InferenceSession(
        file_path,
        sess_options=sess_opt,
        providers=providers,
    )
def load_onnx_caller(file_path: str, single_output=False):
    """Load an ONNX model and wrap it in a plain Python callable.

    The returned ``caller(*args)`` binds its positional arguments to the
    model's inputs in declaration order and runs the session.

    If the first argument is a ``torch.Tensor``, every argument must be a
    tensor sharing its dtype and device. The tensors are detached, moved to
    the CPU and cast to float32 before being passed to ONNX Runtime. Every
    output is then converted back to a tensor with that same dtype and
    device. Otherwise the arguments are passed through unchanged and the raw
    ONNX Runtime outputs are returned.

    Args:
        file_path: Path to the ``.onnx`` model, loaded with ``load_onnx``.
        single_output: If True, ``caller`` returns only the first output
            instead of the list of all outputs.

    Returns:
        The ``caller`` function described above.
    """
    ort_session = load_onnx(file_path)
    # Input names are fixed for the session's lifetime, so look them up once
    # instead of querying get_inputs() for every argument on every call.
    input_names = [node.name for node in ort_session.get_inputs()]

    def caller(*args):
        """Run the wrapped ONNX session on *args* (see load_onnx_caller)."""
        if len(args) > len(input_names):
            raise TypeError(
                f"Model expects at most {len(input_names)} inputs "
                f"({input_names}), got {len(args)}"
            )
        # bool(args) guards the zero-argument call, which would otherwise
        # raise IndexError on args[0].
        torch_input = bool(args) and isinstance(args[0], torch.Tensor)
        if torch_input:
            torch_input_dtype = args[0].dtype
            torch_input_device = args[0].device
            # check all are torch.Tensor and have same dtype and device
            assert all(isinstance(arg, torch.Tensor) for arg in args), "All inputs should be torch.Tensor, if first input is torch.Tensor"
            assert all(arg.dtype == torch_input_dtype for arg in args), "All inputs should have same dtype, if first input is torch.Tensor"
            assert all(arg.device == torch_input_device for arg in args), "All inputs should have same device, if first input is torch.Tensor"
            # detach(): Tensor.numpy() refuses tensors that require grad.
            # float(): upcast to float32 (bfloat16 has no NumPy equivalent).
            # NOTE(review): this assumes every model input is float32;
            # integer inputs (e.g. token ids) would be converted too.
            args = [arg.detach().cpu().float().numpy() for arg in args]
        ort_inputs = dict(zip(input_names, args))
        ort_outs = ort_session.run(None, ort_inputs)
        if torch_input:
            # NOTE: every output, including integer ones, is cast to the
            # inputs' dtype and placed on the inputs' device.
            ort_outs = [
                torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device)
                for ort_out in ort_outs
            ]
        if single_output:
            return ort_outs[0]
        return ort_outs

    return caller
|