mart9992's picture
m
2cd560a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
from segment_anything.utils.onnx import SamOnnxModel
import argparse
import warnings
# onnxruntime is an optional dependency. The flag below records whether it
# could be imported; the rest of the script consults it before touching the
# ``onnxruntime`` name.
try:
    import onnxruntime  # type: ignore
except ImportError:
    onnxruntime_exists = False
else:
    onnxruntime_exists = True
# Command-line interface of the export script. Arguments are registered in
# the order they should appear in ``--help``.
parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

# --- Required input / output locations -------------------------------------
parser.add_argument(
    "--checkpoint",
    required=True,
    type=str,
    help="The path to the SAM model checkpoint.",
)
parser.add_argument(
    "--output",
    required=True,
    type=str,
    help="The filename to save the ONNX model to.",
)

# --- Which model to export and what it returns -----------------------------
parser.add_argument(
    "--model-type",
    default="default",
    type=str,
    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
)
parser.add_argument(
    "--return-single-mask",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
    action="store_true",
)

# --- ONNX export / post-processing knobs -----------------------------------
parser.add_argument(
    "--opset",
    default=17,
    type=int,
    help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
    "--quantize-out",
    default=None,
    type=str,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

# --- Boolean switches forwarded to the exported model ----------------------
parser.add_argument(
    "--gelu-approximate",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
    action="store_true",
)
parser.add_argument(
    "--use-stability-score",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
    action="store_true",
)
parser.add_argument(
    "--return-extra-metrics",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
    action="store_true",
)
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Load a SAM checkpoint and export it to an ONNX file.

    The model is wrapped in ``SamOnnxModel``, run once on random dummy
    inputs, and traced with ``torch.onnx.export``. If onnxruntime is
    importable, the written file is loaded and run once on the same dummy
    inputs as a sanity check.

    Args:
        model_type: ``"vit_b"`` or ``"vit_l"`` select the matching builder;
            any other value (including ``"default"``) uses ``build_sam``.
        checkpoint: Path handed to the model builder.
        output: Path of the ONNX file to write.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` in the wrapped
            model is switched to its ``"tanh"`` approximation before export.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``; also selects the
            five-entry list of output names instead of the three-entry one.
    """
    print("Loading model...")
    if model_type == "vit_b":
        sam = build_sam_vit_b(checkpoint)
    elif model_type == "vit_l":
        sam = build_sam_vit_l(checkpoint)
    else:
        sam = build_sam(checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the number of prompt points (axis 1) is left dynamic in the graph.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Eager dry run so that shape/dtype problems surface before tracing.
    _ = onnx_model(**dummy_inputs)

    # Name every output the wrapped model produces. With extra metrics the
    # model returns five results, so a three-entry list would attach
    # "low_res_masks" to the third result and leave the last two unnamed.
    # NOTE(review): the five-entry order follows the --return-extra-metrics
    # help text; confirm against SamOnnxModel.forward.
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # onnxruntime builds that ship more than one execution provider
        # require an explicit provider list; the CPU provider is enough for
        # this sanity check on CPU dummy inputs.
        ort_session = onnxruntime.InferenceSession(
            output, providers=["CPUExecutionProvider"]
        )
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor: torch.Tensor):
    """Return ``tensor`` as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, so tensors that require grad can be converted as well.
    """
    return tensor.detach().cpu().numpy()
if __name__ == "__main__":
    # Script entry point: parse the CLI, export the model, then optionally
    # write a dynamically quantized copy of the exported file.
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Raise explicitly rather than assert: asserts are stripped when
        # Python runs with -O, which would turn this into a NameError later.
        if not onnxruntime_exists:
            raise ImportError("onnxruntime is required to quantize the model.")

        import inspect

        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_kwargs = {
            "model_input": args.output,
            "model_output": args.quantize_out,
            "per_channel": False,
            "reduce_range": False,
            "weight_type": QuantType.QUInt8,
        }
        # Not every onnxruntime release accepts ``optimize_model``; pass it
        # only when the installed quantize_dynamic still declares it, so the
        # call does not fail with an unexpected-keyword TypeError.
        if "optimize_model" in inspect.signature(quantize_dynamic).parameters:
            quantize_kwargs["optimize_model"] = True
        quantize_dynamic(**quantize_kwargs)
        print("Done!")