Spaces:
Runtime error
Runtime error
import argparse | |
import mmcv | |
import numpy as np | |
import torch | |
import torch._C | |
import torch.serialization | |
from mmcv.runner import load_checkpoint | |
from torch import nn | |
from mmseg.models import build_segmentor | |
# Seed torch's global RNG so torch-side randomness in this script (e.g.
# parameter initialisation when no checkpoint is loaded) is reproducible.
torch.manual_seed(seed=3)
def digit_version(version_str):
    """Convert a version string into a list of ints for ordered comparison.

    Dot-separated numeric components are kept as ints. A release-candidate
    component ``'XrcY'`` becomes ``[X - 1, Y]`` so that e.g. ``'1.8.0rc1'``
    compares below ``'1.8.0'``. A PEP 440 local version label (the part after
    ``'+'``, e.g. ``'+cu111'``) is ignored. Any other non-numeric component
    is skipped.

    Args:
        version_str (str): Version string such as ``torch.__version__``.

    Returns:
        list[int]: Numeric version components, comparable with ``<``/``>=``.
    """
    # Strip the local version label first. Otherwise '1.8.0+cu111' keeps only
    # [1, 8] (its last component '0+cu111' is not numeric) and wrongly
    # compares lower than [1, 8, 0].
    public_version = version_str.split('+', 1)[0]
    parts = []
    for component in public_version.split('.'):
        if component.isdigit():
            parts.append(int(component))
        elif 'rc' in component:
            patch_version = component.split('rc')
            # Subtract one so a release candidate sorts below its final release.
            parts.append(int(patch_version[0]) - 1)
            parts.append(int(patch_version[1]))
    return parts
def check_torch_version(torch_minimum_version='1.8.0'):
    """Ensure the installed PyTorch is new enough for TorchScript export.

    Args:
        torch_minimum_version (str): Lowest acceptable ``torch`` version.
            Default: '1.8.0'.

    Raises:
        RuntimeError: If ``torch.__version__`` is older than
            ``torch_minimum_version``.
    """
    torch_version = digit_version(torch.__version__)
    # Use an explicit raise rather than ``assert`` so the check still runs
    # when Python is started with ``-O``.
    if torch_version < digit_version(torch_minimum_version):
        raise RuntimeError(
            f'Torch=={torch.__version__} is not supported for converting to '
            f'torchscript. Please install pytorch>={torch_minimum_version}.')
def _convert_batchnorm(module): | |
module_output = module | |
if isinstance(module, torch.nn.SyncBatchNorm): | |
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, | |
module.momentum, module.affine, | |
module.track_running_stats) | |
if module.affine: | |
module_output.weight.data = module.weight.data.clone().detach() | |
module_output.bias.data = module.bias.data.clone().detach() | |
# keep requires_grad unchanged | |
module_output.weight.requires_grad = module.weight.requires_grad | |
module_output.bias.requires_grad = module.bias.requires_grad | |
module_output.running_mean = module.running_mean | |
module_output.running_var = module.running_var | |
module_output.num_batches_tracked = module.num_batches_tracked | |
for name, child in module.named_children(): | |
module_output.add_module(name, _convert_batchnorm(child)) | |
del module | |
return module_output | |
def _demo_mm_inputs(input_shape, num_classes): | |
"""Create a superset of inputs needed to run test or train batches. | |
Args: | |
input_shape (tuple): | |
input batch dimensions | |
num_classes (int): | |
number of semantic classes | |
""" | |
(N, C, H, W) = input_shape | |
rng = np.random.RandomState(0) | |
imgs = rng.rand(*input_shape) | |
segs = rng.randint( | |
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) | |
img_metas = [{ | |
'img_shape': (H, W, C), | |
'ori_shape': (H, W, C), | |
'pad_shape': (H, W, C), | |
'filename': '<demo>.png', | |
'scale_factor': 1.0, | |
'flip': False, | |
} for _ in range(N)] | |
mm_inputs = { | |
'imgs': torch.FloatTensor(imgs).requires_grad_(True), | |
'img_metas': img_metas, | |
'gt_semantic_seg': torch.LongTensor(segs) | |
} | |
return mm_inputs | |
def pytorch2libtorch(model,
                     input_shape,
                     show=False,
                     output_file='tmp.pt',
                     verify=False):
    """Trace a segmentor into TorchScript and save it to disk.

    Side effects on *model*: ``model.forward`` is replaced with
    ``model.forward_dummy``, and the model is switched to eval mode.

    Args:
        model (nn.Module): PyTorch segmentor to export. It must expose
            ``decode_head`` (a head or an ``nn.ModuleList`` of heads, each
            with ``num_classes``) and ``forward_dummy``.
        input_shape (tuple): Shape of the dummy image batch used for tracing.
        show (bool): If True, print the traced computation graph.
            Default: False.
        output_file (str): Path of the saved TorchScript model.
            Default: `tmp.pt`.
        verify (bool): Passed as ``check_trace`` to ``torch.jit.trace``,
            which re-runs the inputs through the traced code and checks the
            outputs match. Default: False.
    """
    head = model.decode_head
    # With several decode heads, the class count is read from the last one.
    if isinstance(head, nn.ModuleList):
        head = head[-1]
    num_classes = head.num_classes

    dummy_inputs = _demo_mm_inputs(input_shape, num_classes)
    imgs = dummy_inputs.pop('imgs')

    # Route forward() to forward_dummy so the trace needs only the image
    # tensor as its example input.
    model.forward = model.forward_dummy
    model.eval()

    traced_model = torch.jit.trace(
        model, example_inputs=imgs, check_trace=verify)
    if show:
        print(traced_model.graph)
    traced_model.save(output_file)
    print(f'Successfully exported TorchScript model: {output_file}')
def parse_args():
    """Parse the command-line options of the TorchScript export script.

    Returns:
        argparse.Namespace: ``config``, ``checkpoint``, ``show``, ``verify``,
        ``output_file`` and ``shape`` (a list of one or two ints;
        defaults to ``[512, 512]``).
    """
    parser = argparse.ArgumentParser(
        description='Convert MMSeg to TorchScript')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--show', help='show TorchScript graph', action='store_true')
    parser.add_argument(
        '--verify', help='verify the TorchScript model', action='store_true')
    parser.add_argument('--output-file', default='tmp.pt', type=str)
    parser.add_argument(
        '--shape',
        help='input image size (height, width)',
        default=[512, 512],
        nargs='+',
        type=int)
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    check_torch_version()

    # --shape is either one side length (square input) or a height/width
    # pair; the batch size (1) and channel count (3) are fixed here.
    shape_len = len(args.shape)
    if shape_len == 1:
        (side, ) = args.shape
        input_shape = (1, 3, side, side)
    elif shape_len == 2:
        height, width = args.shape
        input_shape = (1, 3, height, width)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    # Clear the pretrained and train_cfg entries before building the model;
    # weights, if any, are loaded from --checkpoint below.
    cfg.model.pretrained = None
    cfg.model.train_cfg = None
    built = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # Swap SyncBN layers for plain BN before loading weights and tracing.
    segmentor = _convert_batchnorm(built)
    if args.checkpoint:
        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')

    # Trace the PyTorch model and save it as a TorchScript (LibTorch) model.
    pytorch2libtorch(
        segmentor,
        input_shape,
        verify=args.verify,
        show=args.show,
        output_file=args.output_file,
    )