# model1 / trt_convert_vit.py
# Uploaded by multitensor ("Upload folder using huggingface_hub"), commit bbfa6f6 (verified).
import logging
import os
import onnx
import tensorrt as trt
from typing import List
from collections import OrderedDict
from onnx import shape_inference
def _build_serialized_engine(builder, network, config):
    """Build the engine and return its serialized form, or None on failure.

    Prefers ``Builder.build_serialized_network`` (TensorRT >= 8). It falls
    back to the legacy ``build_engine`` + ``serialize`` path only on builders
    that lack the newer API; ``build_engine`` was removed in TensorRT 10.
    """
    if hasattr(builder, "build_serialized_network"):
        return builder.build_serialized_network(network, config)
    engine = builder.build_engine(network, config)
    return None if engine is None else engine.serialize()


def vit_tagging_t2t(
    input_path="simple_model.onnx",
    output_path="vit.trt",
    input_name="input",
    min_batch=1,
    opt_batch=70,
    max_batch=70,
):
    """Convert an ONNX ViT model into an FP16 TensorRT engine file.

    The ONNX file is parsed into an explicit-batch network. FP16 kernels are
    enabled, and one optimization profile covering a dynamic batch dimension
    is attached to ``input_name``. The serialized engine is written to
    ``output_path``.

    Args:
        input_path: Path of the ONNX model to convert.
        output_path: Destination path for the serialized TensorRT engine.
        input_name: Name of the network input the optimization profile
            applies to.
        min_batch: Smallest batch size the profile accepts.
        opt_batch: Batch size TensorRT tunes kernels for.
        max_batch: Largest batch size the profile accepts.
            Every profile shape is ``(batch, 3, 224, 224)``.

    Raises:
        RuntimeError: If the ONNX file cannot be parsed or the engine build
            fails.
    """
    image_shape = (3, 224, 224)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_logger = trt.Logger()
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, trt_logger) as parser:
        config.set_flag(trt.BuilderFlag.FP16)

        with open(input_path, "rb") as f:
            if not parser.parse(f.read()):
                for idx in range(parser.num_errors):
                    print(parser.get_error(idx))
                raise RuntimeError("Failed to parse the ONNX file.")

        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            min=(min_batch, *image_shape),
            opt=(opt_batch, *image_shape),
            max=(max_batch, *image_shape),
        )
        config.add_optimization_profile(profile)

        # NOTE: an earlier draft pinned ReduceMean/Pow layers to FP32 via
        # per-layer precision; it was disabled and has been removed here.

        # FP16 is allowed internally, but the network's first input and
        # first output tensors are declared as float32.
        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        serialized = _build_serialized_engine(builder, network, config)
        # A failed build returns None; fail loudly instead of crashing on
        # a None attribute access or writing an empty file.
        if serialized is None:
            raise RuntimeError("Failed to build the TensorRT engine.")
        with open(output_path, "wb") as f:
            f.write(serialized)
if __name__ == "__main__":
    # Script entry point: convert using the default input/output paths.
    vit_tagging_t2t()