import warnings
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import onnxruntime
import torch

try:
    import tensorrt as trt
except Exception:
    trt = None

warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class TRTWrapper(torch.nn.Module):
    # Maps TensorRT dtypes to torch dtypes; filled once tensorrt is known to be importable.
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        assert trt is not None, 'tensorrt is required to use TRTWrapper'
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })
    def __init_engine(self):
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        # Deserialize the engine and create an execution context.
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())
        context = model.create_execution_context()
        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        num_inputs, num_outputs = 0, 0
        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1
        # A -1 in the first input binding's shape marks a dynamic-shape engine.
        self.is_dynamic = -1 in model.get_binding_shape(0)
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        self.bindings: List[int] = [0] * self.num_bindings
    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        # Record name/dtype/shape for every input binding...
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            inputs_info.append(Binding(name, dtype, shape))
        # ...and for every output binding, whose indices follow the inputs.
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            outputs_info.append(Binding(name, dtype, shape))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        # Static-shape engines can pre-allocate their output tensors once.
        if not self.is_dynamic:
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]
    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]
        # Bind input device pointers; dynamic engines also need the runtime shapes set.
        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
        # create output tensors
        outputs: List[torch.Tensor] = []
        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()
        # Run inference asynchronously on the wrapper's CUDA stream, then wait for it.
        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()
        return tuple(outputs)
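
A minimal usage sketch for TRTWrapper follows. The engine path and the 1x3x640x640 input are placeholders chosen for illustration; they are not taken from the listing above and must match whatever engine you actually built.

# Hedged sketch: assumes a pre-built TensorRT engine file and a single image-like input.
trt_model = TRTWrapper('yolov8n.engine', device=torch.device('cuda:0'))  # placeholder path
dummy = torch.randn(1, 3, 640, 640, device='cuda:0')                     # placeholder shape
outputs = trt_model(dummy)  # tuple of torch.Tensor, already on the GPU
for info, out in zip(trt_model.outputs_info, outputs):
    print(info.name, tuple(out.shape), out.dtype)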
class ORTWrapper(torch.nn.Module):

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        # Prefer the CUDA execution provider when a CUDA device was requested.
        providers = ['CPUExecutionProvider']
        if 'cuda' in self.device.type:
            providers.insert(0, 'CUDAExecutionProvider')
        session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)
        self.session = session
    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        self.is_dynamic = False
        for tensor in self.session.get_inputs():
            # Non-integer dimensions (e.g. named axes) mark a dynamic-shape model.
            if any(not isinstance(dim, int) for dim in tensor.shape):
                self.is_dynamic = True
            inputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
        for tensor in self.session.get_outputs():
            outputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)
    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[np.ndarray] = [
            i.contiguous().cpu().numpy() for i in inputs
        ]
        if not self.is_dynamic:
            # For static models, check each input shape against the graph's declared shape.
            for i in range(self.num_inputs):
                assert contiguous_inputs[i].shape == self.inputs_info[i].shape
        outputs = self.session.run([o.name for o in self.outputs_info], {
            j.name: contiguous_inputs[i]
            for i, j in enumerate(self.inputs_info)
        })
        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
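
A matching sketch for ORTWrapper, assuming an exported ONNX file; as before, the path and input shape are illustrative placeholders rather than values from the listing.

# Hedged sketch: assumes an exported ONNX model; path and shape are placeholders.
ort_model = ORTWrapper('yolov8n.onnx', device=torch.device('cpu'))  # placeholder path
dummy = torch.randn(1, 3, 640, 640)                                 # placeholder shape
outputs = ort_model(dummy)  # tuple of torch.Tensor moved onto ort_model.device
for info, out in zip(ort_model.outputs_info, outputs):
    print(info.name, tuple(out.shape), out.dtype)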