Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from segment_anything import sam_model_registry | |
from segment_anything.utils.onnx import SamOnnxModel | |
import argparse | |
import warnings | |
try: | |
import onnxruntime # type: ignore | |
onnxruntime_exists = True | |
except ImportError: | |
onnxruntime_exists = False | |
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface for the ONNX export script.

    Options are registered in a fixed order so the generated ``--help``
    output lists them the same way every time.
    """
    cli = argparse.ArgumentParser(
        description="Export the SAM prompt encoder and mask decoder to an ONNX model."
    )

    # Mandatory string-valued options.
    mandatory = (
        ("--checkpoint", "The path to the SAM model checkpoint."),
        ("--output", "The filename to save the ONNX model to."),
        (
            "--model-type",
            "In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
        ),
    )
    for flag, text in mandatory:
        cli.add_argument(flag, type=str, required=True, help=text)

    cli.add_argument(
        "--return-single-mask",
        action="store_true",
        help=(
            "If true, the exported ONNX model will only return the best mask, "
            "instead of returning multiple masks. For high resolution images "
            "this can improve runtime when upscaling masks is expensive."
        ),
    )

    # Optional valued options with defaults.
    cli.add_argument(
        "--opset",
        type=int,
        default=17,
        help="The ONNX opset version to use. Must be >=11",
    )
    cli.add_argument(
        "--quantize-out",
        type=str,
        default=None,
        help=(
            "If set, will quantize the model and save it with this name. "
            "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
        ),
    )

    # Remaining boolean switches; each defaults to False when omitted.
    switches = (
        (
            "--gelu-approximate",
            "Replace GELU operations with approximations using tanh. Useful "
            "for some runtimes that have slow or unimplemented erf ops, used in GELU.",
        ),
        (
            "--use-stability-score",
            "Replaces the model's predicted mask quality score with the stability "
            "score calculated on the low resolution masks using an offset of 1.0. ",
        ),
        (
            "--return-extra-metrics",
            "The model will return five results: (masks, scores, stability_scores, "
            "areas, low_res_logits) instead of the usual three. This can be "
            "significantly slower for high resolution outputs.",
        ),
    )
    for flag, text in switches:
        cli.add_argument(flag, action="store_true", help=text)

    return cli


parser = _build_parser()
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Export the SAM prompt encoder and mask decoder to an ONNX file.

    The model is built from ``sam_model_registry[model_type]`` with the given
    checkpoint, wrapped in ``SamOnnxModel``, run once eagerly on dummy inputs,
    and then exported with ``torch.onnx.export``. If ``onnxruntime`` could be
    imported, the exported file is loaded and run once as a sanity check.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the SAM variant.
        checkpoint: Path to the SAM checkpoint passed to the registry builder.
        output: Path the ONNX model is written to.
        opset: ONNX opset version handed to ``torch.onnx.export``.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` module in the
            wrapped model is switched to the ``"tanh"`` approximation.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``; also selects the
            five-name output list so that the extra outputs are labelled.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the number of prompt points is left dynamic in the exported graph.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The mask prompt is fed at 4x the image-embedding resolution.
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Eager forward pass before tracing, so failures surface outside the exporter.
    _ = onnx_model(**dummy_inputs)

    # With extra metrics the model returns five results in the order
    # (masks, scores, stability_scores, areas, low_res_logits); naming only
    # three of them would attach "low_res_masks" to the stability scores.
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # Name the provider explicitly: ONNX Runtime builds that ship several
        # execution providers refuse to create a session without one.
        ort_session = onnxruntime.InferenceSession(
            output, providers=["CPUExecutionProvider"]
        )
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor):
    """Return the tensor's data as a NumPy array on the CPU.

    The tensor is detached first so that tensors which require grad can be
    converted as well; ``Tensor.numpy()`` raises on those otherwise.
    """
    return tensor.detach().cpu().numpy()
if __name__ == "__main__":
    # Script entry point: export the model, then optionally quantize it.
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Explicit raise instead of assert so the check survives `python -O`.
        if not onnxruntime_exists:
            raise RuntimeError("onnxruntime is required to quantize the model.")

        import inspect

        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_kwargs = {
            "model_input": args.output,
            "model_output": args.quantize_out,
            "per_channel": False,
            "reduce_range": False,
            "weight_type": QuantType.QUInt8,
        }
        # Newer onnxruntime releases dropped the `optimize_model` keyword;
        # pass it only when the installed version still accepts it.
        if "optimize_model" in inspect.signature(quantize_dynamic).parameters:
            quantize_kwargs["optimize_model"] = True
        quantize_dynamic(**quantize_kwargs)
        print("Done!")