# Converting PyTorch to ONNX

In [8]:
import torch
from torch import nn
from torch.nn import functional as F

print(torch.cuda.is_available())

True


## Defining the model

In [9]:
class BasicBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride= 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride= stride, padding= 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride= 1, padding= 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace= True)

        self.downsample = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1))
        
        if stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride),
                nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.downsample(x)

        out = self.relu(out)

        return out

class ResNet34(nn.Module):
    def __init__(self, in_channels, num_classes) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, 7, stride= 2, padding= 3),
            nn.BatchNorm2d(64), nn.MaxPool2d(2, stride= 1), nn.ReLU(inplace= True))

        self.layer0 = self._make_layer(64, 64, 3, 1)
        self.layer1 = self._make_layer(64, 128, 4, 2)
        self.layer2 = self._make_layer(128, 256, 6, 2)
        self.layer3 = self._make_layer(256, 512, 3, 2)

        self.avg_pool = nn.AvgPool2d(4)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        
        layers = []
        layers.append(BasicBlock(in_channels, out_channels, stride))

        for i in range(num_blocks - 1):
            layers.append(BasicBlock(out_channels, out_channels, 1))

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        
        out = self.layer0(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out

## Loading the model

In [10]:
model = ResNet34(3, 10)
model.load_state_dict(torch.load('../models/model.pth'))

dummy_input = torch.randn(1, 3, 64, 64)

model.eval()
torch_out = model(dummy_input)

## Converting to ONNX

In [11]:
onnx_path = '../models/model.onnx'

torch.onnx.export(model,
                dummy_input,
                onnx_path,
                verbose=True,
                input_names = ['input'],   # the model's input names
                output_names = ['output'], # the model's output names
                dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                              'output' : {0 : 'batch_size'}})

Exported graph: graph(%input : Float(*, 3, 64, 64, strides=[12288, 4096, 64, 1], requires_grad=0, device=cpu),
      %layer0.0.downsample.0.weight : Float(64, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
      %layer0.0.downsample.0.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %layer0.1.downsample.0.weight : Float(64, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
      %layer0.1.downsample.0.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %layer0.2.downsample.0.weight : Float(64, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
      %layer0.2.downsample.0.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %layer1.1.downsample.0.weight : Float(128, 128, 1, 1, strides=[128, 1, 1, 1], requires_grad=1, device=cpu),
      %layer1.1.downsample.0.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
      %layer1.2.downsample.0.weight : Float(128, 128, 1, 1, strides=[128, 1, 1, 1], r

## Verifying the ONNX model

In [12]:
import onnx

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

## Comparing ONNX Runtime and PyTorch results

In [13]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession(onnx_path)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")


Exported model has been tested with ONNXRuntime, and the result looks good!
