"""Export an image preprocessing module (resize + normalize) to ONNX."""

from typing import List

import torch
import torch.nn as nn


class Preprocess(nn.Module):
    """Resize a batch of images and normalize it channel-wise.

    The forward pass resizes the input to the spatial size taken from
    ``input_shape[2:]``, divides by 255 and then applies a per-channel
    ``(x - mean) / std`` normalization.

    Args:
        input_shape: Target shape; only the entries from index 2 onward
            (the spatial size) are used by ``forward``.
    """

    def __init__(self, input_shape: List[int]):
        super().__init__()
        self.input_shape = tuple(input_shape)
        # Registered as buffers so that ``.to(device)`` / ``.double()`` move
        # and cast them together with the module. ``persistent=False`` keeps
        # ``state_dict()`` empty, exactly as when these were plain attributes.
        # NOTE(review): values look like rounded CLIP normalization
        # statistics -- confirm against the downstream model.
        self.register_buffer(
            "mean",
            torch.tensor([0.4815, 0.4578, 0.4082]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.2686, 0.2613, 0.2758]).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the resized and normalized version of ``x``.

        Args:
            x: Input batch. The ``(1, 3, 1, 1)`` statistics broadcast against
                a 4-D tensor with 3 channels on axis 1; presumably pixel
                values are on a 0-255 scale (they are divided by 255).

        Returns:
            Tensor with spatial size ``input_shape[2:]``.
        """
        # No ``mode`` is passed, so interpolate uses its default mode.
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
        )
        x = x / 255.0
        return (x - self.mean) / self.std


def main() -> None:
    """Export ``Preprocess`` to ``preprocessing.onnx`` and simplify it in place.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed
            its validation check; the unsimplified export is left on disk.
    """
    # Export-only dependencies are imported here so that importing this
    # module for ``Preprocess`` alone does not require the ONNX toolchain.
    import onnx
    from onnxsim import simplify

    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"

    model = Preprocess(input_shape=input_shape).eval()

    torch.onnx.export(
        model,
        (torch.randn(input_shape),),
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
            # The output's spatial size is fixed by the resize; only its
            # batch dimension follows the input.
            "output_preprocessing": {
                0: "batch_size",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    model_simplified, check_ok = simplify(model_onnx)
    if not check_ok:
        # Do not overwrite the valid export with an unvalidated graph.
        raise RuntimeError(
            f"Simplified ONNX model could not be validated: {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)


if __name__ == "__main__":
    main()