"""Export the Real-ESRGAN x2plus PyTorch checkpoint to an ONNX file."""
import torch
import torch.onnx as onnx
from basicsr.archs.rrdbnet_arch import RRDBNet

CHECKPOINT_PATH = 'Real-ESRGAN_x2plus.pth'
ONNX_PATH = 'Real-ESRGAN_x2plus.onnx'

# Build the network and keep everything on the CPU for the export.
device = torch.device('cpu')
model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=2,
)

# Read the checkpoint from disk.
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# Prefer the 'params_ema' entry; fall back to 'params' for checkpoints that
# do not carry a 'params_ema' entry instead of failing with a bare KeyError.
if 'params_ema' in state_dict:
    weights = state_dict['params_ema']
elif 'params' in state_dict:
    weights = state_dict['params']
else:
    raise KeyError(
        f"{CHECKPOINT_PATH!r} contains neither 'params_ema' nor 'params'; "
        f"found keys: {sorted(state_dict)}"
    )
model.load_state_dict(weights)

# Switch to evaluation mode (equivalent to model.train(False)).
model.eval()

# Dummy input used to trace the model: batch_size, channels, height, width.
input_shape = (1, 3, 64, 64)
dummy_input = torch.randn(input_shape)

# Convert the model to ONNX. Only the batch axis is declared dynamic, so the
# remaining axes keep the sizes of the dummy input in the exported graph.
# Gradients are not needed for tracing, so autograd is disabled.
with torch.no_grad():
    onnx.export(
        model,
        dummy_input,
        ONNX_PATH,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )